Skip to content

Commit 22a0375

Browse files
committed
handle expiration of codes for residual vq with shared codebooks, addressing #162
1 parent 6105a32 commit 22a0375

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.17.6"
3+
version = "1.17.7"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/residual_vq.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,10 @@ def forward(
315315
if self.implicit_neural_codebook:
316316
maybe_code_transforms = (None, *self.mlps)
317317

318+
# save all inputs across layers, for use during expiration at end under shared codebook setting
319+
320+
all_residuals = []
321+
318322
# go through the layers
319323

320324
for quantizer_index, (vq, maybe_mlp) in enumerate(zip(self.layers, maybe_code_transforms)):
@@ -333,6 +337,10 @@ def forward(
333337
if exists(maybe_mlp):
334338
maybe_mlp = partial(maybe_mlp, condition = quantized_out)
335339

340+
# save for expiration
341+
342+
all_residuals.append(residual)
343+
336344
# vector quantize forward
337345

338346
quantized, *rest = vq(
@@ -360,8 +368,10 @@ def forward(
360368
# if shared codebook, update ema only at end
361369

362370
if self.shared_codebook:
363-
first(self.layers)._codebook.update_ema()
364-
first(self.layers).update_in_place_optimizer()
371+
shared_layer = first(self.layers)
372+
shared_layer._codebook.update_ema()
373+
shared_layer.update_in_place_optimizer()
374+
shared_layer.expire_codes_(torch.cat(all_residuals, dim = -2))
365375

366376
# project out, if needed
367377

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -975,6 +975,18 @@ def update_in_place_optimizer(self):
975975
self.in_place_codebook_optimizer.step()
976976
self.in_place_codebook_optimizer.zero_grad()
977977

978+
def maybe_split_heads_from_input(self, x):
979+
if self.heads == 1:
980+
return x
981+
982+
ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
983+
return rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = self.heads)
984+
985+
def expire_codes_(self, x):
986+
x = self._codebook.transform_input(x)
987+
x = self.maybe_split_heads_from_input(x)
988+
self._codebook.expire_codes_(x)
989+
978990
def forward(
979991
self,
980992
x,
@@ -1024,9 +1036,7 @@ def forward(
10241036

10251037
# handle multi-headed separate codebooks
10261038

1027-
if is_multiheaded:
1028-
ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
1029-
x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads)
1039+
x = self.maybe_split_heads_from_input(x)
10301040

10311041
# l2norm for cosine sim, otherwise identity
10321042

0 commit comments

Comments (0)