77"""
88
99from math import log2 , ceil
10+ from functools import partial
1011from collections import namedtuple
1112
1213import torch
@@ -51,14 +52,20 @@ def entropy(prob):
5152# cosine sim linear
5253
class CosineSimLinear(Module):
    """Bias-free linear projection that outputs scaled cosine similarities.

    The input is L2-normalized along its last dimension and the weight
    matrix, of shape ``(dim_in, dim_out)``, is L2-normalized along dim 0
    (one unit-norm column per output feature). Their matrix product is
    therefore the cosine similarity between each input vector and each
    weight column, which is then multiplied by ``scale``.

    Args:
        dim_in: size of the last dimension of the input.
        dim_out: number of output features.
        scale: constant multiplier applied to the cosine similarities.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        scale = 1.
    ):
        super().__init__()
        self.scale = scale
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))

    def forward(self, x):
        """Return ``scale * cos_sim(x, weight columns)``; last dim becomes ``dim_out``."""
        # normalize both operands so the matmul yields cosine similarities
        x = F.normalize(x, dim = -1)
        w = F.normalize(self.weight, dim = 0)
        return (x @ w) * self.scale
6269
6370# class
6471
@@ -79,7 +86,8 @@ def __init__(
7986 use_code_agnostic_commit_loss = False ,
8087 projection_has_bias = True ,
8188 soft_clamp_input_value = None ,
82- cosine_sim_project_in = False
89+ cosine_sim_project_in = False ,
90+ cosine_sim_project_in_scale = None
8391 ):
8492 super ().__init__ ()
8593
@@ -96,8 +104,13 @@ def __init__(
96104
97105 has_projections = dim != codebook_dims
98106
99- project_in_klass = CosineSimLinear if cosine_sim_project_in else nn .Linear
100- self .project_in = project_in_klass (dim , codebook_dims , bias = projection_has_bias ) if has_projections else nn .Identity ()
107+ if cosine_sim_project_in :
108+ cosine_sim_project_in = default (cosine_sim_project_in_scale , codebook_scale )
109+ project_in_klass = partial (CosineSimLinear , scale = cosine_sim_project_in )
110+ else :
111+ project_in_klass = partial (nn .Linear , bias = projection_has_bias )
112+
113+ self .project_in = project_in_klass (dim , codebook_dims ) if has_projections else nn .Identity ()
101114 self .project_out = nn .Linear (codebook_dims , dim , bias = projection_has_bias ) if has_projections else nn .Identity ()
102115 self .has_projections = has_projections
103116
0 commit comments