File tree Expand file tree Collapse file tree 3 files changed +17
-8
lines changed
Expand file tree Collapse file tree 3 files changed +17
-8
lines changed Original file line number Diff line number Diff line change 11[project ]
22name = " vector-quantize-pytorch"
3- version = " 1.21.0 "
3+ version = " 1.21.1 "
44description = " Vector Quantization - Pytorch"
55authors = [
66 { name = " Phil Wang" , email = " lucidrains@gmail.com" }
Original file line number Diff line number Diff line change @@ -6,9 +6,11 @@ def exists(v):
66
77@pytest .mark .parametrize ('use_cosine_sim' , (True , False ))
88@pytest .mark .parametrize ('rotation_trick' , (True , False ))
9+ @pytest .mark .parametrize ('input_requires_grad' , (True , False ))
910def test_vq (
1011 use_cosine_sim ,
11- rotation_trick
12+ rotation_trick ,
13+ input_requires_grad
1214):
1315 from vector_quantize_pytorch import VectorQuantize
1416
@@ -22,6 +24,10 @@ def test_vq(
2224 )
2325
2426 x = torch .randn (1 , 1024 , 256 )
27+
28+ if input_requires_grad :
29+ x .requires_grad_ ()
30+
2531 quantized , indices , commit_loss = vq (x )
2632
2733def test_vq_eval ():
Original file line number Diff line number Diff line change @@ -1023,7 +1023,7 @@ def forward(
10231023 return_loss_breakdown = False ,
10241024 codebook_transform_fn : Callable | None = None
10251025 ):
1026- orig_input = x
1026+ orig_input , input_requires_grad = x , x . requires_grad
10271027
10281028 # handle masking, either passed in as `mask` or `lens`
10291029
@@ -1117,11 +1117,14 @@ def forward(
11171117
11181118 commit_quantize = maybe_detach (quantize )
11191119
1120- if self .rotation_trick :
1121- quantize = rotate_to (x , quantize )
1122- else :
1123- # standard STE to get gradients through VQ layer.
1124- quantize = x + (quantize - x ).detach ()
1120+ # spare rotation trick calculation if inputs do not need gradients
1121+
1122+ if input_requires_grad :
1123+ if self .rotation_trick :
1124+ quantize = rotate_to (x , quantize )
1125+ else :
1126+ # standard STE to get gradients through VQ layer.
1127+ quantize = x + (quantize - x ).detach ()
11251128
11261129 if self .sync_update_v > 0. :
11271130 # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
You can’t perform that action at this time.
0 commit comments