66import torch .nn .functional as F
77
88from einx import get_at
9- from einops import einsum , rearrange , repeat , reduce , pack , unpack
9+ from einops import rearrange , pack , unpack
10+
11+ from vector_quantize_pytorch .vector_quantize_pytorch import rotate_from_to
1012
1113# helper functions
1214
@@ -37,7 +39,9 @@ def __init__(
3739 dim ,
3840 codebook_size ,
3941 init_fn : Callable = identity ,
40- accept_image_fmap = False
42+ accept_image_fmap = False ,
43+ rotation_trick = True , # works even better with rotation trick turned on, with no asymmetric commit loss or straight through
44+ commit_loss_input_to_quantize_weight = 0.25 ,
4145 ):
4246 super ().__init__ ()
4347 self .accept_image_fmap = accept_image_fmap
@@ -50,6 +54,17 @@ def __init__(
5054 self .codebook_to_codes = nn .Linear (dim , dim , bias = False )
5155 self .register_buffer ('codebook' , codebook )
5256
57+
58+ # whether to use rotation trick from Fifty et al.
59+ # https://arxiv.org/abs/2410.06424
60+
61+ self .rotation_trick = rotation_trick
62+ self .register_buffer ('zero' , torch .tensor (0. ), persistent = False )
63+
64+ # commit loss weighting - weighing input to quantize a bit less is crucial for it to work
65+
66+ self .commit_loss_input_to_quantize_weight = commit_loss_input_to_quantize_weight
67+
5368 def forward (
5469 self ,
5570 x
@@ -68,14 +83,21 @@ def forward(
6883
6984 quantized = get_at ('[c] d, b n -> b n d' , implicit_codebook , indices )
7085
71- # commit loss
86+ if self .rotation_trick :
87+ # rotation trick from @cfifty
88+
89+ quantized = rotate_from_to (quantized , x )
90+
91+ commit_loss = self .zero
92+ else :
93+ # commit loss and straight through, as was done in the paper
7294
73- commit_loss = (
74- 0.25 * F .mse_loss (x , quantized .detach ()) +
75- F .mse_loss (x .detach (), quantized )
76- )
95+ commit_loss = (
96+ F .mse_loss (x , quantized .detach ()) * self . commit_loss_input_to_quantize_weight +
97+ F .mse_loss (x .detach (), quantized )
98+ )
7799
78- quantized = (quantized - x ).detach () + x
100+ quantized = (quantized - x ).detach () + x
79101
80102 if self .accept_image_fmap :
81103 quantized = inverse_pack (quantized )
0 commit comments