@@ -48,6 +48,18 @@ def log(t, eps = 1e-5):
 def entropy(prob):
     return (-prob * log(prob)).sum(dim = -1)
 
+# cosine sim linear
+
+class CosineSimLinear(Module):
+    def __init__(self, dim_in, dim_out, **kwargs):
+        super().__init__()
+        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
+
+    def forward(self, x):
+        x = F.normalize(x, dim = -1)
+        w = F.normalize(self.weight, dim = 0)
+        return x @ w
+
 # class
 
 class LFQ(Module):
@@ -66,7 +78,8 @@ def __init__(
         frac_per_sample_entropy = 1., # make less than 1. to only use a random fraction of the probs for per sample entropy
         use_code_agnostic_commit_loss = False,
         projection_has_bias = True,
-        soft_clamp_input_value = None
+        soft_clamp_input_value = None,
+        cosine_sim_project_in = False
     ):
         super().__init__()
 
@@ -82,7 +95,9 @@ def __init__(
         dim = default(dim, codebook_dims)
 
         has_projections = dim != codebook_dims
-        self.project_in = nn.Linear(dim, codebook_dims, bias = projection_has_bias) if has_projections else nn.Identity()
+
+        project_in_klass = CosineSimLinear if cosine_sim_project_in else nn.Linear
+        self.project_in = project_in_klass(dim, codebook_dims, bias = projection_has_bias) if has_projections else nn.Identity()
         self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
         self.has_projections = has_projections
 
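
Usage note (not part of the diff): when cosine_sim_project_in = True and dim != codebook_dims, project_in becomes a CosineSimLinear instead of nn.Linear; its **kwargs absorb the bias = projection_has_bias keyword, which it ignores. Below is a minimal sketch of constructing LFQ with the new flag. The import path, dim, and codebook_size values are illustrative assumptions; the only argument introduced by this commit is cosine_sim_project_in.

import torch
from vector_quantize_pytorch import LFQ  # assumed import path for the LFQ class patched above

quantizer = LFQ(
    dim = 256,                       # assumed input feature dimension
    codebook_size = 65536,           # assumed codebook size; dim != codebook dims, so project_in is used
    cosine_sim_project_in = True     # new flag from this commit: use CosineSimLinear for project_in
)

x = torch.randn(1, 1024, 256)        # (batch, seq, dim)
quantized, indices, aux_loss = quantizer(x)  # return structure assumed from the surrounding LFQ implementation

Since both the input and each weight column are L2-normalized in CosineSimLinear.forward, every coordinate of the projected input is a cosine similarity bounded in [-1, 1], which keeps the pre-quantization activations in a fixed range.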