@@ -28,8 +28,11 @@ def noop(*args, **kwargs):
def identity(t):
    """Return the input unchanged (no-op transform placeholder)."""
    return t
3030
def l2norm(t, dim = -1, eps = 1e-6):
    """L2-normalize `t` along `dim` (last dim by default).

    `eps` guards against division by zero for near-zero-norm vectors.
    """
    return F.normalize(t, p = 2, dim = dim, eps = eps)
33+
def safe_div(num, den, eps = 1e-6):
    """Divide `num` by tensor `den`, clamping the denominator to >= `eps`
    so a zero (or tiny) denominator cannot produce inf/nan."""
    clamped = den.clamp(min = eps)
    return num / clamped
3336
3437def Sequential (* modules ):
3538 modules = [* filter (exists , modules )]
@@ -73,6 +76,19 @@ def lens_to_mask(lens, max_length):
7376 seq = torch .arange (max_length , device = lens .device )
7477 return seq < lens [:, None ]
7578
def efficient_rotation_trick_transform(u, q, e):
    """
    4.2 in https://arxiv.org/abs/2410.06424
    """
    # view each row of e as a 1 x d matrix so the batched matmuls below compose
    e = e.unsqueeze(1)

    # direction bisecting u and q; gradients are blocked through it
    w = l2norm(u + q, dim = 1).detach()

    # e - 2 (e . w) w + 2 (e . u) q, with u and q detached
    reflection = 2 * (e @ w.unsqueeze(-1) @ w.unsqueeze(1))
    rotation = 2 * (e @ u.unsqueeze(-1).detach() @ q.unsqueeze(1).detach())

    return e - reflection + rotation
91+
7692def uniform_init (* shape ):
7793 t = torch .empty (shape )
7894 nn .init .kaiming_uniform_ (t )
@@ -811,7 +827,7 @@ def __init__(
811827 stochastic_sample_codes = False ,
812828 sample_codebook_temp = 1. ,
813829 straight_through = False ,
814- rotation_trick = True , # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424.
830+ rotation_trick = True , # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424 by @cfifty
815831 reinmax = False , # using reinmax for improved straight-through, assuming straight through helps at all
816832 sync_codebook = None ,
817833 sync_affine_param = False ,
@@ -946,13 +962,6 @@ def codebook(self, codes):
946962
947963 self ._codebook .embed .copy_ (codes )
948964
949- @staticmethod
950- def rotation_trick_transform (u , q , e ):
951- w = ((u + q ) / torch .norm (u + q , dim = 1 , keepdim = True )).detach ()
952- e = e - 2 * torch .bmm (torch .bmm (e , w .unsqueeze (- 1 )), w .unsqueeze (1 )) + 2 * torch .bmm (
953- torch .bmm (e , u .unsqueeze (- 1 ).detach ()), q .unsqueeze (1 ).detach ())
954- return e
955-
956965 def get_codes_from_indices (self , indices ):
957966 codebook = self .codebook
958967 is_multiheaded = codebook .ndim > 2
@@ -1103,23 +1112,25 @@ def forward(
11031112
11041113 commit_quantize = maybe_detach (quantize )
11051114
1106- # Use the rotation trick (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
11071115 if self .rotation_trick :
1108- init_shape = x .shape
1109- x = x .reshape (- 1 , init_shape [- 1 ])
1110- quantize = quantize .reshape (- 1 , init_shape [- 1 ])
1111-
1112- eps = 1e-6 # For numerical stability if any vector is close to 0 norm.
1113- rot_quantize = self .rotation_trick_transform (
1114- x / (torch .norm (x , dim = 1 , keepdim = True ) + eps ),
1115- quantize / (torch .norm (quantize , dim = 1 , keepdim = True ) + eps ),
1116- x .unsqueeze (1 )).squeeze ()
1117- quantize = rot_quantize * (torch .norm (quantize , dim = 1 , keepdim = True )
1118- / (torch .norm (x , dim = 1 , keepdim = True ) + 1e-6 )).detach ()
1119-
1120- x = x .reshape (init_shape )
1121- quantize = quantize .reshape (init_shape )
1122- else : # Use STE to get gradients through VQ layer.
1116+ # rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
1117+ x , inverse = pack_one (x , '* d' )
1118+ quantize , _ = pack_one (quantize , '* d' )
1119+
1120+ norm_x = x .norm (dim = - 1 , keepdim = True )
1121+ norm_quantize = quantize .norm (dim = - 1 , keepdim = True )
1122+
1123+ rot_quantize = efficient_rotation_trick_transform (
1124+ safe_div (x , norm_x ),
1125+ safe_div (quantize , norm_quantize ),
1126+ x
1127+ ).squeeze ()
1128+
1129+ quantize = rot_quantize * safe_div (norm_quantize , norm_x ).detach ()
1130+
1131+ x , quantize = inverse (x ), inverse (quantize )
1132+ else :
1133+ # standard STE to get gradients through VQ layer.
11231134 quantize = x + (quantize - x ).detach ()
11241135
11251136 if self .sync_update_v > 0. :
0 commit comments