Commit 56f20dc

cleanup rotation trick
1 parent 089011b commit 56f20dc

3 files changed: 43 additions, 29 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.18.0"
+version = "1.18.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_readme.py

Lines changed: 5 additions & 2 deletions
@@ -5,8 +5,10 @@ def exists(v):
     return v is not None
 
 @pytest.mark.parametrize('use_cosine_sim', (True, False))
+@pytest.mark.parametrize('rotation_trick', (True, False))
 def test_vq(
-    use_cosine_sim
+    use_cosine_sim,
+    rotation_trick
 ):
     from vector_quantize_pytorch import VectorQuantize
 
@@ -15,7 +17,8 @@ def test_vq(
         codebook_size = 512,     # codebook size
         decay = 0.8,             # the exponential moving average decay, lower means the dictionary will change faster
         commitment_weight = 1.,  # the weight on the commitment loss
-        use_cosine_sim = use_cosine_sim
+        use_cosine_sim = use_cosine_sim,
+        rotation_trick = rotation_trick
     )
 
     x = torch.randn(1, 1024, 256)
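For context, a minimal usage sketch of the new flag, mirroring the parametrized test above. This is not part of the commit; dim = 256 is assumed to match the last dimension of the test input, and the three return values follow the README usage of VectorQuantize.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,                 # assumed to match x.shape[-1] below
    codebook_size = 512,
    decay = 0.8,
    commitment_weight = 1.,
    use_cosine_sim = False,
    rotation_trick = True      # gradients flow through a rotation instead of the plain straight-through copy
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)   # quantized: (1, 1024, 256), indices: (1, 1024), commit_loss: scalar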

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 37 additions & 26 deletions
@@ -28,8 +28,11 @@ def noop(*args, **kwargs):
 def identity(t):
     return t
 
-def l2norm(t):
-    return F.normalize(t, p = 2, dim = -1)
+def l2norm(t, dim = -1, eps = 1e-6):
+    return F.normalize(t, p = 2, dim = dim, eps = eps)
+
+def safe_div(num, den, eps = 1e-6):
+    return num / den.clamp(min = eps)
 
 def Sequential(*modules):
     modules = [*filter(exists, modules)]
@@ -73,6 +76,19 @@ def lens_to_mask(lens, max_length):
     seq = torch.arange(max_length, device = lens.device)
     return seq < lens[:, None]
 
+def efficient_rotation_trick_transform(u, q, e):
+    """
+    4.2 in https://arxiv.org/abs/2410.06424
+    """
+    e = rearrange(e, 'b d -> b 1 d')
+    w = l2norm(u + q, dim = 1).detach()
+
+    return (
+        e -
+        2 * (e @ rearrange(w, 'b d -> b d 1') @ rearrange(w, 'b d -> b 1 d')) +
+        2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach())
+    )
+
 def uniform_init(*shape):
     t = torch.empty(shape)
     nn.init.kaiming_uniform_(t)
@@ -811,7 +827,7 @@ def __init__(
         stochastic_sample_codes = False,
         sample_codebook_temp = 1.,
         straight_through = False,
-        rotation_trick = True,  # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424.
+        rotation_trick = True,  # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424 by @cfifty
         reinmax = False,        # using reinmax for improved straight-through, assuming straight through helps at all
         sync_codebook = None,
         sync_affine_param = False,
@@ -946,13 +962,6 @@ def codebook(self, codes):
 
         self._codebook.embed.copy_(codes)
 
-    @staticmethod
-    def rotation_trick_transform(u, q, e):
-        w = ((u + q) / torch.norm(u + q, dim=1, keepdim=True)).detach()
-        e = e - 2 * torch.bmm(torch.bmm(e, w.unsqueeze(-1)), w.unsqueeze(1)) + 2 * torch.bmm(
-            torch.bmm(e, u.unsqueeze(-1).detach()), q.unsqueeze(1).detach())
-        return e
-
     def get_codes_from_indices(self, indices):
         codebook = self.codebook
         is_multiheaded = codebook.ndim > 2
@@ -1103,23 +1112,25 @@ def forward(
 
         commit_quantize = maybe_detach(quantize)
 
-        # Use the rotation trick (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
         if self.rotation_trick:
-            init_shape = x.shape
-            x = x.reshape(-1, init_shape[-1])
-            quantize = quantize.reshape(-1, init_shape[-1])
-
-            eps = 1e-6  # For numerical stability if any vector is close to 0 norm.
-            rot_quantize = self.rotation_trick_transform(
-                x / (torch.norm(x, dim=1, keepdim=True) + eps),
-                quantize / (torch.norm(quantize, dim=1, keepdim=True) + eps),
-                x.unsqueeze(1)).squeeze()
-            quantize = rot_quantize * (torch.norm(quantize, dim=1, keepdim=True)
-                                       / (torch.norm(x, dim=1, keepdim=True) + 1e-6)).detach()
-
-            x = x.reshape(init_shape)
-            quantize = quantize.reshape(init_shape)
-        else: # Use STE to get gradients through VQ layer.
+            # rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
+            x, inverse = pack_one(x, '* d')
+            quantize, _ = pack_one(quantize, '* d')
+
+            norm_x = x.norm(dim = -1, keepdim = True)
+            norm_quantize = quantize.norm(dim = -1, keepdim = True)
+
+            rot_quantize = efficient_rotation_trick_transform(
+                safe_div(x, norm_x),
+                safe_div(quantize, norm_quantize),
+                x
+            ).squeeze()
+
+            quantize = rot_quantize * safe_div(norm_quantize, norm_x).detach()
+
+            x, quantize = inverse(x), inverse(quantize)
+        else:
+            # standard STE to get gradients through VQ layer.
            quantize = x + (quantize - x).detach()
 
         if self.sync_update_v > 0.:
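As a sanity check on the rewritten path, here is a standalone sketch (not part of the commit; the helpers are copied from the diff above, and the codebook vectors are stand-ins). It verifies that the rotation-trick output still equals the selected codebook vector in the forward pass, so only the backward behaviour differs from the straight-through estimator.

import torch
import torch.nn.functional as F
from einops import rearrange

def l2norm(t, dim = -1, eps = 1e-6):
    return F.normalize(t, p = 2, dim = dim, eps = eps)

def safe_div(num, den, eps = 1e-6):
    return num / den.clamp(min = eps)

def efficient_rotation_trick_transform(u, q, e):
    # section 4.2 in https://arxiv.org/abs/2410.06424
    e = rearrange(e, 'b d -> b 1 d')
    w = l2norm(u + q, dim = 1).detach()

    return (
        e -
        2 * (e @ rearrange(w, 'b d -> b d 1') @ rearrange(w, 'b d -> b 1 d')) +
        2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach())
    )

x = torch.randn(8, 256, requires_grad = True)   # encoder outputs, flattened to (batch, dim)
quantize = torch.randn(8, 256)                  # stand-in for the selected codebook vectors

norm_x = x.norm(dim = -1, keepdim = True)
norm_quantize = quantize.norm(dim = -1, keepdim = True)

# rotate x onto the direction of quantize, then rescale by the (detached) norm ratio
rot_quantize = efficient_rotation_trick_transform(
    safe_div(x, norm_x),
    safe_div(quantize, norm_quantize),
    x
).squeeze()

out = rot_quantize * safe_div(norm_quantize, norm_x).detach()

# forward value matches the codebook vector ...
assert torch.allclose(out, quantize, atol = 1e-4)

# ... while gradients reach x through the detached rotation and rescaling
out.sum().backward()
assert x.grad is not None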
