@@ -20,6 +20,18 @@ def divisible_by(val, d):
2020
2121# helper classes
2222
23+ class ChanLayerNorm (nn .Module ):
24+ def __init__ (self , dim , eps = 1e-5 ):
25+ super ().__init__ ()
26+ self .eps = eps
27+ self .g = nn .Parameter (torch .ones (1 , dim , 1 , 1 ))
28+ self .b = nn .Parameter (torch .zeros (1 , dim , 1 , 1 ))
29+
30+ def forward (self , x ):
31+ var = torch .var (x , dim = 1 , unbiased = False , keepdim = True )
32+ mean = torch .mean (x , dim = 1 , keepdim = True )
33+ return (x - mean ) / (var + self .eps ).sqrt () * self .g + self .b
34+
2335class Downsample (nn .Module ):
2436 def __init__ (self , dim_in , dim_out ):
2537 super ().__init__ ()
@@ -212,10 +224,10 @@ def __init__(
212224 if tokenize_local_3_conv :
213225 self .local_encoder = nn .Sequential (
214226 nn .Conv2d (3 , init_dim , 3 , 2 , 1 ),
215- nn . LayerNorm (init_dim ),
227+ ChanLayerNorm (init_dim ),
216228 nn .GELU (),
217229 nn .Conv2d (init_dim , init_dim , 3 , 2 , 1 ),
218- nn . LayerNorm (init_dim ),
230+ ChanLayerNorm (init_dim ),
219231 nn .GELU (),
220232 nn .Conv2d (init_dim , init_dim , 3 , 1 , 1 )
221233 )
0 commit comments