@@ -818,6 +818,7 @@ def __init__(
         manual_ema_update = False,
         learnable_codebook = False,
         in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook
+        manual_in_place_optimizer_update = False,
         affine_param = False,
         affine_param_batch_decay = 0.99,
         affine_param_codebook_decay = 0.9,
@@ -913,6 +914,7 @@ def __init__(
         self._codebook = codebook_class(**codebook_kwargs)

         self.in_place_codebook_optimizer = in_place_codebook_optimizer(self._codebook.parameters()) if exists(in_place_codebook_optimizer) else None
+        self.manual_in_place_optimizer_update = manual_in_place_optimizer_update

         self.codebook_size = codebook_size

@@ -966,6 +968,13 @@ def get_output_from_indices(self, indices):
         codes = self.get_codes_from_indices(indices)
         return self.project_out(codes)

+    def update_in_place_optimizer(self):
+        if not exists(self.in_place_codebook_optimizer):
+            return
+
+        self.in_place_codebook_optimizer.step()
+        self.in_place_codebook_optimizer.zero_grad()
+
     def forward(
         self,
         x,
@@ -1057,8 +1066,9 @@ def forward(
             loss = F.mse_loss(quantize, x.detach())

             loss.backward()
-            self.in_place_codebook_optimizer.step()
-            self.in_place_codebook_optimizer.zero_grad()
+
+            if not self.manual_in_place_optimizer_update:
+                self.update_in_place_optimizer()

         inplace_optimize_loss = loss

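For context, a minimal usage sketch of the new flag. It assumes this commit targets the `VectorQuantize` class of `vector-quantize-pytorch` (neither the class nor the module name appears in the hunks themselves), and the loop and hyperparameters are illustrative only. With `manual_in_place_optimizer_update = True`, `forward` still backpropagates the in-place codebook loss but skips the optimizer step, so gradients can accumulate across several forward passes until the caller invokes the new `update_in_place_optimizer` method:

```python
import torch
from torch.optim import Adam
from vector_quantize_pytorch import VectorQuantize  # assumed import path

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    learnable_codebook = True,                # codebook updated by gradients, not EMA
    in_place_codebook_optimizer = Adam,       # factory receiving the codebook parameters
    manual_in_place_optimizer_update = True   # defer .step() / .zero_grad() to the caller
)

# each forward pass calls loss.backward() on the in-place codebook loss,
# accumulating gradients in the codebook parameters
for _ in range(4):
    x = torch.randn(1, 1024, 256)
    quantized, indices, commit_loss = vq(x)

# apply the accumulated codebook updates in a single optimizer step
vq.update_in_place_optimizer()  # no-op when in_place_codebook_optimizer is None
```

Because `update_in_place_optimizer` returns early when no in-place optimizer exists, it is safe to call unconditionally, and gating the automatic call behind `not self.manual_in_place_optimizer_update` preserves the previous behavior as the default.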