@@ -29,6 +29,7 @@ def __init__(
             pos_embed: str = '',
             pool_type: str = 'token',
             norm_layer: Optional[nn.Module] = None,
+            act_layer: Optional[nn.Module] = nn.GELU,
             drop: float = 0.0,
     ):
         super().__init__()
@@ -54,13 +55,18 @@ def __init__(
 
         self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
         self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        if qk_norm:
+            qk_norm_layer = norm_layer or nn.LayerNorm
+            self.q_norm = qk_norm_layer(self.head_dim)
+            self.k_norm = qk_norm_layer(self.head_dim)
+        else:
+            self.q_norm = nn.Identity()
+            self.k_norm = nn.Identity()
         self.proj = nn.Linear(embed_dim, embed_dim)
         self.proj_drop = nn.Dropout(drop)
 
         self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
-        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
+        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer)
 
         self.init_weights()
 
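
For reference, a minimal usage sketch of the constructor after this change: with qk_norm=True and no norm_layer, the Q/K norms now fall back to nn.LayerNorm, and the new act_layer argument is forwarded to the internal Mlp. The class name, import path, and the in_features parameter are not shown in this hunk and are assumptions (timm's AttentionPoolLatent is assumed here).

import torch
import torch.nn as nn
from timm.layers import AttentionPoolLatent  # assumed class / import path

# qk_norm=True with norm_layer=None now uses nn.LayerNorm for q_norm/k_norm;
# act_layer is passed through to the internal Mlp instead of the Mlp default.
pool = AttentionPoolLatent(
    in_features=768,   # assumed parameter name, not visible in this hunk
    num_heads=12,
    qk_norm=True,
    act_layer=nn.SiLU,
)

x = torch.randn(2, 196, 768)  # (batch, num_tokens, dim)
out = pool(x)                 # pooled output, expected (batch, out_features) for 'token' pooling
print(out.shape)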