66from .adaptive_avgmax_pool import *
77from timm .data import IMAGENET_INCEPTION_MEAN , IMAGENET_INCEPTION_STD
88
9- _models = ['inception_resnet_v2' ]
9+ _models = ['inception_resnet_v2' , 'ens_adv_inception_resnet_v2' ]
1010__all__ = ['InceptionResnetV2' ] + _models
1111
1212default_cfgs = {
13+ # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
1314 'inception_resnet_v2' : {
14- 'url' : 'http ://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4 .pth' ,
15+ 'url' : 'https ://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6 .pth' ,
1516 'num_classes' : 1001 , 'input_size' : (3 , 299 , 299 ), 'pool_size' : (8 , 8 ),
1617 'crop_pct' : 0.8975 , 'interpolation' : 'bicubic' ,
1718 'mean' : IMAGENET_INCEPTION_MEAN , 'std' : IMAGENET_INCEPTION_STD ,
18- 'first_conv' : 'conv2d_1a.conv' , 'classifier' : 'last_linear' ,
19+ 'first_conv' : 'conv2d_1a.conv' , 'classifier' : 'classif' ,
20+ },
21+ # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
22+ 'ens_adv_inception_resnet_v2' : {
23+ 'url' : 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth' ,
24+ 'num_classes' : 1001 , 'input_size' : (3 , 299 , 299 ), 'pool_size' : (8 , 8 ),
25+ 'crop_pct' : 0.8975 , 'interpolation' : 'bicubic' ,
26+ 'mean' : IMAGENET_INCEPTION_MEAN , 'std' : IMAGENET_INCEPTION_STD ,
27+ 'first_conv' : 'conv2d_1a.conv' , 'classifier' : 'classif' ,
1928 }
2029}
2130
@@ -274,19 +283,20 @@ def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'
274283 )
275284 self .block8 = Block8 (noReLU = True )
276285 self .conv2d_7b = BasicConv2d (2080 , self .num_features , kernel_size = 1 , stride = 1 )
277- self .last_linear = nn .Linear (self .num_features , num_classes )
286+ # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
287+ self .classif = nn .Linear (self .num_features , num_classes )
278288
279289 def get_classifier (self ):
280- return self .last_linear
290+ return self .classif
281291
282292 def reset_classifier (self , num_classes , global_pool = 'avg' ):
283293 self .global_pool = global_pool
284294 self .num_classes = num_classes
285- del self .last_linear
295+ del self .classif
286296 if num_classes :
287- self .last_linear = torch .nn .Linear (self .num_features , num_classes )
297+ self .classif = torch .nn .Linear (self .num_features , num_classes )
288298 else :
289- self .last_linear = None
299+ self .classif = None
290300
291301 def forward_features (self , x , pool = True ):
292302 x = self .conv2d_1a (x )
@@ -314,13 +324,13 @@ def forward(self, x):
314324 x = self .forward_features (x , pool = True )
315325 if self .drop_rate > 0 :
316326 x = F .dropout (x , p = self .drop_rate , training = self .training )
317- x = self .last_linear (x )
327+ x = self .classif (x )
318328 return x
319329
320330
321331def inception_resnet_v2 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
322332 r"""InceptionResnetV2 model architecture from the
323- `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
333+ `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
324334 """
325335 default_cfg = default_cfgs ['inception_resnet_v2' ]
326336 model = InceptionResnetV2 (num_classes = num_classes , in_chans = in_chans , ** kwargs )
@@ -330,3 +340,16 @@ def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs
330340
331341 return model
332342
343+
344+ def ens_adv_inception_resnet_v2 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
345+ r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
346+ As per https://arxiv.org/abs/1705.07204 and
347+ https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
348+ """
349+ default_cfg = default_cfgs ['ens_adv_inception_resnet_v2' ]
350+ model = InceptionResnetV2 (num_classes = num_classes , in_chans = in_chans , ** kwargs )
351+ model .default_cfg = default_cfg
352+ if pretrained :
353+ load_pretrained (model , default_cfg , num_classes , in_chans )
354+
355+ return model
0 commit comments