1818
1919from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
2020from timm .layers import trunc_normal_tf_ , DropPath , LayerNorm2d , Mlp , SelectAdaptivePool2d , create_conv2d , \
21- use_fused_attn
21+ use_fused_attn , NormMlpClassifierHead , ClassifierHead
2222from ._builder import build_model_with_cfg
2323from ._features_fx import register_notrace_module
2424from ._manipulate import named_apply , checkpoint_seq
@@ -375,13 +375,23 @@ def __init__(
375375 self .stages = nn .Sequential (* stages )
376376
377377 self .num_features = dims [- 1 ]
378- self .norm_pre = norm_layer (self .num_features ) if head_norm_first else nn .Identity ()
379- self .head = nn .Sequential (OrderedDict ([
380- ('global_pool' , SelectAdaptivePool2d (pool_type = global_pool )),
381- ('norm' , nn .Identity () if head_norm_first else norm_layer (self .num_features )),
382- ('flatten' , nn .Flatten (1 ) if global_pool else nn .Identity ()),
383- ('drop' , nn .Dropout (self .drop_rate )),
384- ('fc' , nn .Linear (self .num_features , num_classes ) if num_classes > 0 else nn .Identity ())]))
378+ if head_norm_first :
379+ self .norm_pre = norm_layer (self .num_features )
380+ self .head = ClassifierHead (
381+ self .num_features ,
382+ num_classes ,
383+ pool_type = global_pool ,
384+ drop_rate = self .drop_rate ,
385+ )
386+ else :
387+ self .norm_pre = nn .Identity ()
388+ self .head = NormMlpClassifierHead (
389+ self .num_features ,
390+ num_classes ,
391+ pool_type = global_pool ,
392+ drop_rate = self .drop_rate ,
393+ norm_layer = norm_layer ,
394+ )
385395
386396 named_apply (partial (_init_weights , head_init_scale = head_init_scale ), self )
387397
def get_classifier(self):
    """Return the classifier module (the final fully-connected layer of the head).

    Returns:
        The ``fc`` submodule of ``self.head`` (an ``nn.Linear`` when
        ``num_classes > 0``; otherwise an identity — per head construction).
    """
    return self.head.fc
407417
def reset_classifier(self, num_classes=0, global_pool=None):
    """Reset the classification head for a new number of classes / pooling.

    Delegates to ``self.head.reset`` (the head module owns its own pool,
    norm, flatten and fc layers — see head construction in ``__init__``),
    replacing the earlier inline rebuild of those submodules.

    Args:
        num_classes: New number of output classes; 0 removes the classifier.
        global_pool: New global pooling type, or None to keep the current one.
    """
    self.head.reset(num_classes, global_pool)
413420
414421 def forward_features (self , x ):
415422 x = self .stem (x )
@@ -418,12 +425,7 @@ def forward_features(self, x):
418425 return x
419426
def forward_head(self, x, pre_logits: bool = False):
    """Run the classification head on feature maps.

    The head module itself supports ``pre_logits``, so the previous
    manual decomposition (global_pool -> norm -> flatten -> drop -> fc,
    kept separate for torchscript) is no longer needed.

    Args:
        x: Feature tensor from ``forward_features``.
        pre_logits: If True, return pooled/normalized features before the
            final classifier layer instead of logits.

    Returns:
        Logits, or pre-logit features when ``pre_logits`` is True.
    """
    return self.head(x, pre_logits=True) if pre_logits else self.head(x)
427429
428430 def forward (self , x ):
429431 x = self .forward_features (x )
0 commit comments