Commit 43a6e56

address #233
1 parent 32425e5 commit 43a6e56

4 files changed, +18 −8 lines


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.27.6"
+version = "1.27.7"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_beam.py

Lines changed: 9 additions & 2 deletions
@@ -1,3 +1,6 @@
+import pytest
+param = pytest.mark.parametrize
+
 import torch
 from vector_quantize_pytorch import VectorQuantize

@@ -43,20 +46,24 @@ def test_topk_and_manual_ema_update():
     assert torch.allclose(vq1._codebook.embed_avg, vq2._codebook.embed_avg)
     assert torch.allclose(vq1.codebook, vq2.codebook)
 
-def test_beam_search():
+@param('codebook_dim', (256, 128))
+def test_beam_search(
+    codebook_dim
+):
     import torch
     from vector_quantize_pytorch import ResidualVQ
 
     residual_vq = ResidualVQ(
         dim = 256,
+        codebook_dim = codebook_dim,
         num_quantizers = 8,       # specify number of quantizers
         codebook_size = 1024,     # codebook size
         quantize_dropout = True,
         beam_size = 2,
         eval_beam_size = 3
     )
 
-    x = torch.randn(1, 1024, 256)
+    x = torch.randn(1, 1024, 256).requires_grad_()
 
     for _ in range(5):
         quantized, indices, commit_loss = residual_vq(x)
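
For context, codebook_dim makes ResidualVQ project the input into a smaller space before quantizing (via the project_in shown in the residual_vq.py diff below), and the parametrization runs the beam-search test both with a projection (128) and without one (256, matching dim). A minimal standalone sketch of the path this exercises, assuming only the constructor arguments visible in this diff; the shape notes are illustrative rather than part of the test:

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    codebook_dim = 128,      # differs from dim, so the internal projection is active
    num_quantizers = 8,
    codebook_size = 1024,
    quantize_dropout = True,
    beam_size = 2,
    eval_beam_size = 3
)

x = torch.randn(1, 1024, 256).requires_grad_()

quantized, indices, commit_loss = residual_vq(x)

print(quantized.shape)   # expected to match x: (1, 1024, 256), projected back out to dim
print(indices.shape)     # typically (batch, seq, num_quantizers)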

vector_quantize_pytorch/residual_vq.py

Lines changed: 3 additions & 5 deletions
@@ -191,6 +191,7 @@ def __init__(
 
         codebook_dim = default(codebook_dim, dim)
         codebook_input_dim = codebook_dim * heads
+        self.codebook_dim = codebook_dim
 
         requires_projection = codebook_input_dim != dim
         self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
@@ -223,6 +224,7 @@ def __init__(
         self.num_quantizers = num_quantizers
 
         self.codebook_sizes = codebook_sizes
+
         self.uniform_codebook_size = len(unique(codebook_sizes)) == 1
 
         # define vq across layers
@@ -287,10 +289,6 @@ def __init__(
     @property
     def codebook_size(self):
         return self.layers[0].codebook_size
-
-    @property
-    def codebook_dim(self):
-        return self.layers[0].codebook_dim
 
     @property
     def codebooks(self):
@@ -423,7 +421,7 @@ def forward(
 
         # save all inputs across layers, for use during expiration at end under shared codebook setting, or ema update during beam search
 
-        all_residuals = torch.empty((*input_shape[:-1], 0, input_shape[-1]), dtype = residual.dtype, device = device)
+        all_residuals = torch.empty((*input_shape[:-1], 0, self.codebook_dim), dtype = residual.dtype, device = device)
 
         # maybe prepare beam search
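
The functional change in this file is the last hunk: all_residuals collects residuals after project_in, i.e. in the projected codebook_dim space, so sizing its last axis from input_shape breaks as soon as codebook_dim differs from dim (which is exactly what the new test parametrization sets up). A rough standalone illustration of that shape mismatch; the names and the concatenation axis below are assumptions for the sketch, not the module's internals:

import torch
import torch.nn as nn

batch, seq, dim, codebook_dim = 1, 1024, 256, 128

project_in = nn.Linear(dim, codebook_dim)   # stand-in for the module's input projection

x = torch.randn(batch, seq, dim)
residual = project_in(x)                    # (1, 1024, 128), lives in codebook space

old_buf = torch.empty(batch, seq, 0, dim)            # last axis taken from the raw input: 256
new_buf = torch.empty(batch, seq, 0, codebook_dim)   # last axis matching the residual: 128

# appending per-layer residuals only works when the trailing axes agree
stacked = torch.cat((new_buf, residual.unsqueeze(-2)), dim = -2)   # (1, 1024, 1, 128)
# torch.cat((old_buf, residual.unsqueeze(-2)), dim = -2) would raise a size-mismatch error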

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 5 additions & 0 deletions
@@ -1134,6 +1134,11 @@ def forward(
 
         commit_quantize = maybe_detach(quantize)
 
+        # maybe expand input if returning topk codes
+
+        if exists(topk):
+            x = repeat(x, '... d -> ... k d', k = topk)
+
         # spare rotation trick calculation if inputs do not need gradients
 
         if input_requires_grad:
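
For reference, repeat(x, '... d -> ... k d', k = topk) is an einops broadcast-copy: it inserts a new length-k axis just before the feature axis, presumably so the input lines up elementwise with the k candidate quantizations returned when topk is set (the surrounding rotation-trick / commitment code is not shown in this hunk). A self-contained sketch of just that reshape, with arbitrary sizes:

import torch
from einops import repeat

topk = 3
x = torch.randn(2, 1024, 256)   # (batch, seq, dim)

x_expanded = repeat(x, '... d -> ... k d', k = topk)

assert x_expanded.shape == (2, 1024, 3, 256)
# every slice along the new axis is an identical copy of the original features
assert torch.equal(x_expanded[:, :, 0], x)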
