Skip to content

Commit 3937aa2

Browse files
committed
handle expiration of codes for rvq beam search scenario
1 parent 8b36a03 commit 3937aa2

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.27.0"
3+
version = "1.27.1"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/residual_vq.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -545,13 +545,19 @@ def forward(
545545

546546
# handle updating ema
547547

548-
if self.vq_is_ema_updating:
548+
if self.training:
549549
for vq, layer_input, indices in zip(self.layers, all_residuals.unbind(dim = -2), all_indices.unbind(dim = -1)): # in the case of quantize dropout, zip will terminate with the shorter sequence, which should be all_residuals
550-
vq.update_ema_indices(layer_input, indices, mask = mask)
550+
551+
if self.vq_is_ema_updating:
552+
vq.update_ema_indices(layer_input, indices, mask = mask)
553+
554+
batch_samples = layer_input[mask] if exists(mask) else layer_input
555+
vq.expire_codes_(batch_samples)
551556

552557
# if shared codebook, update ema only at end
553558

554-
if self.training and self.shared_codebook and not is_beam_search:
559+
if self.training and self.shared_codebook:
560+
555561
shared_layer = first(self.layers)
556562
shared_layer._codebook.update_ema()
557563
shared_layer.update_in_place_optimizer()

0 commit comments

Comments
 (0)