@@ -62,8 +62,9 @@ def __init__(
         straight_through_activation = nn.Identity(),
         num_codebooks = 1,
         keep_num_codebooks_dim = None,
-        codebook_scale = 1.,                    # for residual LFQ, codebook scaled down by 2x at each layer
-        frac_per_sample_entropy = 1.            # make less than 1. to only use a random fraction of the probs for per sample entropy
+        codebook_scale = 1.,                    # for residual LFQ, codebook scaled down by 2x at each layer
+        frac_per_sample_entropy = 1.,           # make less than 1. to only use a random fraction of the probs for per sample entropy
+        use_code_agnostic_commit_loss = False
     ):
         super().__init__()
 
@@ -110,6 +111,7 @@ def __init__(
         # commitment loss
 
         self.commitment_loss_weight = commitment_loss_weight
+        self.use_code_agnostic_commit_loss = use_code_agnostic_commit_loss
 
         # for no auxiliary loss, during inference
 
@@ -259,8 +261,19 @@ def forward(
 
         # commit loss
 
-        if self.training:
-            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
+        if self.training and self.commitment_loss_weight > 0.:
+
+            if self.use_code_agnostic_commit_loss:
+                # credit goes to @MattMcPartlon for sharing this in https://github.com/lucidrains/vector-quantize-pytorch/issues/120#issuecomment-2095089337
+
+                commit_loss = F.mse_loss(
+                    original_input ** 2,
+                    codebook_value ** 2,
+                    reduction = 'none'
+                )
+
+            else:
+                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
 
             if exists(mask):
                 commit_loss = commit_loss[mask]
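
In short, the new use_code_agnostic_commit_loss flag swaps the standard commitment loss (MSE between the input and the detached quantized output) for one that compares squared magnitudes, original_input ** 2 against codebook_value ** 2, so the penalty no longer depends on which binary code the input snapped to; the commit loss branch is also skipped entirely whenever commitment_loss_weight is zero. Below is a minimal usage sketch, assuming the LFQ constructor and forward signature documented in the repo README; the shapes and weights are illustrative only and not part of this diff.

import torch
from vector_quantize_pytorch import LFQ

# sketch only — assumes the README's LFQ API; shapes and weights are illustrative
quantizer = LFQ(
    codebook_size = 65536,                  # must be a power of 2
    dim = 16,
    commitment_loss_weight = 0.25,          # commit loss is computed only during training and when this is > 0.
    use_code_agnostic_commit_loss = True    # the flag added by this diff
)

image_feats = torch.randn(1, 16, 32, 32)

# during training, the returned auxiliary loss includes the (code agnostic) commitment term
quantized, indices, aux_loss = quantizer(image_feats)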