@@ -250,21 +250,21 @@ def efficient_rotation_trick_transform(u, q, e):
250250 2 * (e @ rearrange (u , 'b d -> b d 1' ).detach () @ rearrange (q , 'b d -> b 1 d' ).detach ())
251251 )
252252
253- def rotate_from_to (src , tgt ):
253+ def rotate_to (src , tgt ):
254254 # rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
255- tgt , inverse = pack_one (tgt , '* d' )
256- src , _ = pack_one (src , '* d' )
255+ src , inverse = pack_one (src , '* d' )
256+ tgt , _ = pack_one (tgt , '* d' )
257257
258- norm_tgt = tgt .norm (dim = - 1 , keepdim = True )
259258 norm_src = src .norm (dim = - 1 , keepdim = True )
259+ norm_tgt = tgt .norm (dim = - 1 , keepdim = True )
260260
261- rotated_src = efficient_rotation_trick_transform (
262- safe_div (tgt , norm_tgt ),
261+ rotated_tgt = efficient_rotation_trick_transform (
263262 safe_div (src , norm_src ),
264- tgt
263+ safe_div (tgt , norm_tgt ),
264+ src
265265 ).squeeze ()
266266
267- rotated = rotated_src * safe_div (norm_src , norm_tgt ).detach ()
267+ rotated = rotated_tgt * safe_div (norm_tgt , norm_src ).detach ()
268268
269269 return inverse (rotated )
270270
@@ -1118,7 +1118,7 @@ def forward(
11181118 commit_quantize = maybe_detach (quantize )
11191119
11201120 if self .rotation_trick :
1121- quantize = rotate_from_to ( quantize , x )
1121+ quantize = rotate_to ( x , quantize )
11221122 else :
11231123 # standard STE to get gradients through VQ layer.
11241124 quantize = x + (quantize - x ).detach ()
0 commit comments