@@ -40,8 +40,8 @@ def __init__(
         codebook_size,
         init_fn: Callable = identity,
         accept_image_fmap = False,
-        rotation_trick = True, # works even better with rotation trick turned on, with no asymmetric commit loss or straight through
-        commit_loss_input_to_quantize_weight = 0.25,
+        rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
+        input_to_quantize_commit_loss_weight = 0.25,
     ):
         super().__init__()
         self.accept_image_fmap = accept_image_fmap
@@ -59,11 +59,10 @@ def __init__(
         # https://arxiv.org/abs/2410.06424

         self.rotation_trick = rotation_trick
-        self.register_buffer('zero', torch.tensor(0.), persistent = False)

         # commit loss weighting - weighing input to quantize a bit less is crucial for it to work

-        self.commit_loss_input_to_quantize_weight = commit_loss_input_to_quantize_weight
+        self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight

     def forward(
         self,
@@ -83,18 +82,18 @@ def forward(

         quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)

+        # commit loss and straight through, as was done in the paper
+
+        commit_loss = F.mse_loss(x.detach(), quantized)
+
         if self.rotation_trick:
            # rotation trick from @cfifty
-
            quantized = rotate_from_to(quantized, x)
-
-            commit_loss = self.zero
         else:
-            # commit loss and straight through, as was done in the paper

             commit_loss = (
-                F.mse_loss(x, quantized.detach()) * self.commit_loss_input_to_quantize_weight +
-                F.mse_loss(x.detach(), quantized)
+                commit_loss +
+                F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
             )

             quantized = (quantized - x).detach() + x
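
For readers unfamiliar with the rotation trick referenced above: the sketch below illustrates what a `rotate_from_to`-style transform does, following section 4.2 of https://arxiv.org/abs/2410.06424. It is only an illustrative approximation, not this repository's actual `rotate_from_to`; the function name, argument order, and broadcasting are assumptions. The key property is that the forward output equals the codebook vector `quantized`, while gradients reach the encoder output `x` through a rotation and rescale that are treated as constants (detached), in place of the straight-through copy.

```python
import torch
import torch.nn.functional as F

def rotation_trick_sketch(quantized, x, eps = 1e-6):
    # hypothetical stand-in for rotate_from_to(quantized, x)
    # forward value equals `quantized`; gradients flow to `x` through a
    # detached rotation + rescale, per section 4.2 of https://arxiv.org/abs/2410.06424

    x_norm = x.norm(dim = -1, keepdim = True).clamp(min = eps)
    q_norm = quantized.norm(dim = -1, keepdim = True).clamp(min = eps)

    x_hat = (x / x_norm).detach()
    q_hat = (quantized / q_norm).detach()

    # reflection-based construction R = I - 2 w w^T + 2 q_hat x_hat^T maps x_hat onto q_hat
    w = F.normalize(x_hat + q_hat, dim = -1, eps = eps).detach()

    rotated = (
        x
        - 2 * (x * w).sum(dim = -1, keepdim = True) * w
        + 2 * (x * x_hat).sum(dim = -1, keepdim = True) * q_hat
    )

    # rescale so the forward pass reproduces the codebook vector exactly
    return rotated * (q_norm / x_norm).detach()
```

Seen this way, the rest of the diff is consistent: the symmetric `F.mse_loss(x.detach(), quantized)` commit term is now computed before the branch and so also applies on the rotation-trick path (previously `self.zero` there), while the down-weighted input-to-quantize term and the straight-through copy remain only in the non-rotation branch.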