@@ -280,6 +280,7 @@ def __init__(
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
         ema_update = True,
+        manual_ema_update = False,
         affine_param = False,
         sync_affine_param = False,
         affine_param_batch_decay = 0.99,
@@ -290,6 +291,7 @@ def __init__(

         self.decay = decay
         self.ema_update = ema_update
+        self.manual_ema_update = manual_ema_update

         init_fn = uniform_init if not kmeans_init else torch.zeros
         embed = init_fn(num_codebooks, codebook_size, dim)
@@ -458,6 +460,12 @@ def expire_codes_(self, batch_samples):
         batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
         self.replace(batch_samples, batch_mask = expired_codes)

+    def update_ema(self):
+        cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
+
+        embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
+        self.embed.data.copy_(embed_normalized)
+
     @autocast('cuda', enabled = False)
     def forward(
         self,
@@ -551,11 +559,9 @@ def forward(

             ema_inplace(self.embed_avg.data, embed_sum, self.decay)

-            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
-
-            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
-            self.embed.data.copy_(embed_normalized)
-            self.expire_codes_(x)
+            if not self.manual_ema_update:
+                self.update_ema()
+                self.expire_codes_(x)

         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
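
Note on the refactor above: update_ema factors out the EMA codebook write that previously lived inline in forward. It rescales the Laplace-smoothed cluster counts so they still sum to the original total, then divides each code's running embedding sum by its smoothed count. A minimal self-contained sketch of the same computation (laplace_smoothing is written out here under the assumption that it is the usual additive-smoothing helper; assumed shapes are cluster_size: (heads, codes) and embed_avg: (heads, codes, dim)):

import torch
from einops import rearrange

def laplace_smoothing(x, n_categories, eps = 1e-5):
    # additive smoothing: no category's probability is exactly zero
    return (x + eps) / (x.sum(dim = -1, keepdim = True) + n_categories * eps)

def manual_codebook_update(embed, embed_avg, cluster_size, codebook_size, eps = 1e-5):
    # smoothed counts, rescaled to preserve the original total mass
    smoothed = laplace_smoothing(cluster_size, codebook_size, eps) * cluster_size.sum(dim = -1, keepdim = True)
    # each code = running sum of assigned vectors / (smoothed) assignment count
    embed.data.copy_(embed_avg / rearrange(smoothed, '... -> ... 1'))
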
@@ -582,11 +588,14 @@ def __init__(
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
         ema_update = True,
+        manual_ema_update = False
     ):
         super().__init__()
         self.transform_input = l2norm

         self.ema_update = ema_update
+        self.manual_ema_update = manual_ema_update
+
         self.decay = decay

         if not kmeans_init:
@@ -671,6 +680,14 @@ def expire_codes_(self, batch_samples):
         batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
         self.replace(batch_samples, batch_mask = expired_codes)

+    def update_ema(self):
+        cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
+
+        embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
+        embed_normalized = l2norm(embed_normalized)
+
+        self.embed.data.copy_(embed_normalized)
+
     @autocast('cuda', enabled = False)
     def forward(
         self,
@@ -746,13 +763,9 @@ def forward(

             ema_inplace(self.embed_avg.data, embed_sum, self.decay)

-            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
-
-            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
-            embed_normalized = l2norm(embed_normalized)
-
-            self.embed.data.copy_(embed_normalized)
-            self.expire_codes_(x)
+            if not self.manual_ema_update:
+                self.update_ema()
+                self.expire_codes_(x)

         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
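
The cosine-similarity codebook's update_ema is identical except for one extra step: after dividing by the smoothed counts, the codes are re-projected onto the unit hypersphere with l2norm, since nearest-code lookup in this variant uses cosine distance. Assuming l2norm is a thin wrapper over F.normalize (a sketch, not guaranteed to match the repo's exact helper):

import torch.nn.functional as F

def l2norm(t):
    # unit-normalize the last dimension
    return F.normalize(t, p = 2, dim = -1)
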
@@ -802,6 +815,7 @@ def __init__(
         sync_codebook = None,
         sync_affine_param = False,
         ema_update = True,
+        manual_ema_update = False,
         learnable_codebook = False,
         in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook
         affine_param = False,
@@ -881,7 +895,8 @@ def __init__(
             learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
             sample_codebook_temp = sample_codebook_temp,
             gumbel_sample = gumbel_sample_fn,
-            ema_update = ema_update
+            ema_update = ema_update,
+            manual_ema_update = manual_ema_update
         )

         if affine_param:
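
Taken together, the flag defers the codebook write to the caller: with manual_ema_update = True, forward still accumulates cluster_size and embed_avg via ema_inplace, but self.embed is only rewritten (and stale codes only expired) when the user triggers it. A hedged usage sketch, assuming the internal codebook stays reachable as vq._codebook and that VectorQuantize.forward returns (quantized, indices, commit_loss):

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    manual_ema_update = True  # accumulate EMA stats, defer the codebook write
)

x = torch.randn(1, 1024, 256)

# several training-mode forward passes, e.g. under gradient accumulation
for _ in range(4):
    quantized, indices, commit_loss = vq(x)

# apply the deferred write once, at a step boundary of the caller's choosing
vq._codebook.update_ema()

One could also call expire_codes_ at the same point, but it expects inputs already flattened to the codebook's (heads, n, dim) layout, so that shape handling is left out of this sketch.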