@@ -65,7 +65,8 @@ def __init__(
         codebook_scale = 1.,                        # for residual LFQ, codebook scaled down by 2x at each layer
         frac_per_sample_entropy = 1.,               # make less than 1. to only use a random fraction of the probs for per sample entropy
         use_code_agnostic_commit_loss = False,
-        projection_has_bias = True
+        projection_has_bias = True,
+        soft_clamp_input_value = None
     ):
         super().__init__()
 
@@ -114,6 +115,11 @@ def __init__(
         self.commitment_loss_weight = commitment_loss_weight
         self.use_code_agnostic_commit_loss = use_code_agnostic_commit_loss
 
+        # whether to soft clamp the input value from -value to value
+
+        self.soft_clamp_input_value = soft_clamp_input_value
+        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= 1.
+
         # for no auxiliary loss, during inference
 
         self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
@@ -195,6 +201,12 @@ def forward(
 
         x = self.project_in(x)
 
+        # maybe soft clamp
+
+        if exists(self.soft_clamp_input_value):
+            clamp_value = self.soft_clamp_input_value
+            x = (x / clamp_value).tanh() * clamp_value
+
         # split out number of codebooks
 
         x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
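
For reference, a minimal standalone sketch of the soft clamp added in `forward` above. The tensor values and the clamp value of `10.` are arbitrary, chosen only for illustration: dividing by `clamp_value` before `tanh` and multiplying after keeps the transform close to identity for small inputs (since tanh has unit slope at zero) while smoothly bounding the output to (-clamp_value, clamp_value).

```python
import torch

def soft_clamp(x, clamp_value):
    # tanh saturates at ±1, so rescaling bounds the output to (-clamp_value, clamp_value),
    # while inputs much smaller than clamp_value pass through almost unchanged
    return (x / clamp_value).tanh() * clamp_value

x = torch.tensor([-100., -5., -0.1, 0., 0.1, 5., 100.])

out = soft_clamp(x, clamp_value = 10.)
print(out)  # large magnitudes squash toward ±10, small values stay nearly the same

assert out.abs().max() < 10.
```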