
Commit 97b9a87

add SimVQ with or without rotation trick https://arxiv.org/abs/2411.02038
1 parent 01b45eb commit 97b9a87
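
For orientation, a rough sketch of the SimVQ idea referenced by this commit (https://arxiv.org/abs/2411.02038), pieced together from the sim_vq.py diff below: the codebook itself is a frozen buffer, and all codebook learning happens through a single linear layer applied to it. The class name and the nearest-neighbour distance computation here are illustrative assumptions; the distance search is not part of this diff.

```python
import torch
from torch import nn

class TinySimVQSketch(nn.Module):
    # hypothetical minimal sketch, not the library's SimVQ module
    def __init__(self, dim, codebook_size):
        super().__init__()
        self.register_buffer('codebook', torch.randn(codebook_size, dim))  # frozen, never trained
        self.codebook_to_codes = nn.Linear(dim, dim, bias = False)          # the one learned layer

    def forward(self, x):  # x: (batch, n, dim)
        implicit_codebook = self.codebook_to_codes(self.codebook)
        # assumed nearest-neighbour lookup by squared euclidean distance
        dist = ((x.unsqueeze(-2) - implicit_codebook) ** 2).sum(dim = -1)   # (batch, n, codebook_size)
        indices = dist.argmin(dim = -1)
        quantized = implicit_codebook[indices]                              # (batch, n, dim)
        return quantized, indices

sketch = TinySimVQSketch(dim = 32, codebook_size = 256)
quantized, indices = sketch(torch.randn(2, 64, 32))
```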


7 files changed: 45 additions, 18 deletions


README.md

Lines changed: 9 additions & 0 deletions
@@ -714,3 +714,12 @@ assert loss.item() >= 0
     url = {https://api.semanticscholar.org/CorpusID:273229218}
 }
 ```
+
+```bibtex
+@inproceedings{Zhu2024AddressingRC,
+    title = {Addressing Representation Collapse in Vector Quantized Models with One Linear Layer},
+    author = {Yongxin Zhu and Bocheng Li and Yifei Xin and Linli Xu},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273812459}
+}
+```

examples/autoencoder.py

Lines changed: 0 additions & 1 deletion
@@ -71,7 +71,6 @@ def iterate_dataset(data_loader):
     shuffle=True,
 )
 
-print("baseline")
 torch.random.manual_seed(seed)
 
 model = SimpleVQAutoEncoder(

examples/autoencoder_fsq.py

Lines changed: 0 additions & 1 deletion
@@ -76,7 +76,6 @@ def iterate_dataset(data_loader):
     shuffle=True,
 )
 
-print("baseline")
 torch.random.manual_seed(seed)
 model = SimpleFSQAutoEncoder(levels).to(device)
 opt = torch.optim.AdamW(model.parameters(), lr=lr)

examples/autoencoder_lfq.py

Lines changed: 0 additions & 2 deletions
@@ -87,8 +87,6 @@ def iterate_dataset(data_loader):
     shuffle=True,
 )
 
-print("baseline")
-
 torch.random.manual_seed(seed)
 
 model = LFQAutoEncoder(

examples/autoencoder_sim_vq.py

Lines changed: 5 additions & 5 deletions
@@ -22,11 +22,11 @@ def SimVQAutoEncoder(**vq_kwargs):
         nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
         nn.MaxPool2d(kernel_size=2, stride=2),
         nn.GELU(),
-        nn.Conv2d(16, 64, kernel_size=3, stride=1, padding=1),
+        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
         nn.MaxPool2d(kernel_size=2, stride=2),
-        SimVQ(dim=64, accept_image_fmap = True, **vq_kwargs),
+        SimVQ(dim=32, accept_image_fmap = True, **vq_kwargs),
         nn.Upsample(scale_factor=2, mode="nearest"),
-        nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1),
+        nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
         nn.GELU(),
         nn.Upsample(scale_factor=2, mode="nearest"),
         nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),

@@ -73,11 +73,11 @@ def iterate_dataset(data_loader):
     shuffle=True,
 )
 
-print("baseline")
 torch.random.manual_seed(seed)
 
 model = SimVQAutoEncoder(
-    codebook_size=num_codes,
+    codebook_size = num_codes,
+    rotation_trick = rotation_trick
 ).to(device)
 
 opt = torch.optim.AdamW(model.parameters(), lr=lr)
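
For reference, a minimal usage sketch of the flag wired in above, calling SimVQ directly rather than through the example's SimVQAutoEncoder wrapper. It assumes SimVQ is importable from the package top level and returns (quantized, indices, commit_loss) like the repo's other quantizers; both are assumptions, not guaranteed by this diff.

```python
import torch
from vector_quantize_pytorch import SimVQ  # assumed top-level export

sim_vq = SimVQ(
    dim = 32,                    # matches the example's new 32-dim bottleneck
    codebook_size = 256,
    accept_image_fmap = True,
    rotation_trick = True        # new in this commit; False falls back to commit loss + straight-through
)

fmap = torch.randn(1, 32, 7, 7)  # (batch, dim, height, width)

# assumed return signature, mirroring the repo's other quantizers
quantized, indices, commit_loss = sim_vq(fmap)

assert quantized.shape == fmap.shape
```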

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.19.5"
+version = "1.20.0"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/sim_vq.py

Lines changed: 30 additions & 8 deletions
@@ -6,7 +6,9 @@
 import torch.nn.functional as F
 
 from einx import get_at
-from einops import einsum, rearrange, repeat, reduce, pack, unpack
+from einops import rearrange, pack, unpack
+
+from vector_quantize_pytorch.vector_quantize_pytorch import rotate_from_to
 
 # helper functions

@@ -37,7 +39,9 @@ def __init__(
         dim,
         codebook_size,
         init_fn: Callable = identity,
-        accept_image_fmap = False
+        accept_image_fmap = False,
+        rotation_trick = True, # works even better with rotation trick turned on, with no asymmetric commit loss or straight through
+        commit_loss_input_to_quantize_weight = 0.25,
     ):
         super().__init__()
         self.accept_image_fmap = accept_image_fmap

@@ -50,6 +54,17 @@
         self.codebook_to_codes = nn.Linear(dim, dim, bias = False)
         self.register_buffer('codebook', codebook)
 
+
+        # whether to use rotation trick from Fifty et al.
+        # https://arxiv.org/abs/2410.06424
+
+        self.rotation_trick = rotation_trick
+        self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
+        # commit loss weighting - weighing input to quantize a bit less is crucial for it to work
+
+        self.commit_loss_input_to_quantize_weight = commit_loss_input_to_quantize_weight
+
     def forward(
         self,
         x

@@ -68,14 +83,21 @@ def forward(
 
         quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)
 
-        # commit loss
+        if self.rotation_trick:
+            # rotation trick from @cfifty
+
+            quantized = rotate_from_to(quantized, x)
+
+            commit_loss = self.zero
+        else:
+            # commit loss and straight through, as was done in the paper
 
-        commit_loss = (
-            0.25 * F.mse_loss(x, quantized.detach()) +
-            F.mse_loss(x.detach(), quantized)
-        )
+            commit_loss = (
+                F.mse_loss(x, quantized.detach()) * self.commit_loss_input_to_quantize_weight +
+                F.mse_loss(x.detach(), quantized)
+            )
 
-        quantized = (quantized - x).detach() + x
+            quantized = (quantized - x).detach() + x
 
         if self.accept_image_fmap:
             quantized = inverse_pack(quantized)
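
The rotation-trick branch above replaces the straight-through estimator: rotate_from_to maps the encoder output onto its selected code with a rotation-and-rescaling treated as a constant, which is why the commit loss can be dropped (commit_loss = self.zero). Below is a standalone sketch of that transform as described in Fifty et al. (https://arxiv.org/abs/2410.06424), not the library's rotate_from_to; the shapes, eps handling, and scaling step are assumptions.

```python
import torch
import torch.nn.functional as F

def rotation_trick_transform(x, q, eps = 1e-6):
    # x: encoder outputs, q: their nearest codes, both (batch, dim)
    # every factor below is detached, so the transform acts as a constant w.r.t. gradients
    x_hat = F.normalize(x, dim = -1, eps = eps).detach()
    q_hat = F.normalize(q, dim = -1, eps = eps).detach()
    r = F.normalize(x_hat + q_hat, dim = -1, eps = eps)  # Householder direction (antipodal x_hat, q_hat not handled)
    # R = I - 2 r r^T + 2 q_hat x_hat^T maps x_hat onto q_hat (section 4.2 of the paper)
    rotated = (
        x
        - 2 * (x * r).sum(dim = -1, keepdim = True) * r
        + 2 * (x * x_hat).sum(dim = -1, keepdim = True) * q_hat
    )
    # rescale to the code's magnitude, keeping the scale constant w.r.t. gradients
    scale = (q.norm(dim = -1, keepdim = True) / x.norm(dim = -1, keepdim = True).clamp(min = eps)).detach()
    return rotated * scale

x = torch.randn(4, 32, requires_grad = True)
q = torch.randn(4, 32)                      # stand-ins for the selected codebook vectors
out = rotation_trick_transform(x, q)

print(torch.allclose(out, q, atol = 1e-5))  # forward pass lands on the codes
out.sum().backward()                        # but gradients still reach x through the fixed rotation
print(x.grad is not None)
```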
