@@ -488,7 +488,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
488488
489489def _gen_mobilenet_v1 (
490490 variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 ,
491- fix_stem_head = False , head_conv = False , pretrained = False , ** kwargs ):
491+ group_size = None , fix_stem_head = False , head_conv = False , pretrained = False , ** kwargs
492+ ):
492493 """
493494 Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
494495 Paper: https://arxiv.org/abs/1801.04381
@@ -503,7 +504,12 @@ def _gen_mobilenet_v1(
503504 round_chs_fn = partial (round_channels , multiplier = channel_multiplier )
504505 head_features = (1024 if fix_stem_head else max (1024 , round_chs_fn (1024 ))) if head_conv else 0
505506 model_kwargs = dict (
506- block_args = decode_arch_def (arch_def , depth_multiplier = depth_multiplier , fix_first_last = fix_stem_head ),
507+ block_args = decode_arch_def (
508+ arch_def ,
509+ depth_multiplier = depth_multiplier ,
510+ fix_first_last = fix_stem_head ,
511+ group_size = group_size ,
512+ ),
507513 num_features = head_features ,
508514 stem_size = 32 ,
509515 fix_stem = fix_stem_head ,
@@ -517,7 +523,9 @@ def _gen_mobilenet_v1(
517523
518524
519525def _gen_mobilenet_v2 (
520- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , fix_stem_head = False , pretrained = False , ** kwargs ):
526+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 ,
527+ group_size = None , fix_stem_head = False , pretrained = False , ** kwargs
528+ ):
521529 """ Generate MobileNet-V2 network
522530 Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
523531 Paper: https://arxiv.org/abs/1801.04381
@@ -533,7 +541,12 @@ def _gen_mobilenet_v2(
533541 ]
534542 round_chs_fn = partial (round_channels , multiplier = channel_multiplier )
535543 model_kwargs = dict (
536- block_args = decode_arch_def (arch_def , depth_multiplier = depth_multiplier , fix_first_last = fix_stem_head ),
544+ block_args = decode_arch_def (
545+ arch_def ,
546+ depth_multiplier = depth_multiplier ,
547+ fix_first_last = fix_stem_head ,
548+ group_size = group_size ,
549+ ),
537550 num_features = 1280 if fix_stem_head else max (1280 , round_chs_fn (1280 )),
538551 stem_size = 32 ,
539552 fix_stem = fix_stem_head ,
@@ -613,7 +626,8 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
613626
614627def _gen_efficientnet (
615628 variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , channel_divisor = 8 ,
616- group_size = None , pretrained = False , ** kwargs ):
629+ group_size = None , pretrained = False , ** kwargs
630+ ):
617631 """Creates an EfficientNet model.
618632
619633 Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
@@ -661,7 +675,8 @@ def _gen_efficientnet(
661675
662676
663677def _gen_efficientnet_edge (
664- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs ):
678+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs
679+ ):
665680 """ Creates an EfficientNet-EdgeTPU model
666681
667682 Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
@@ -692,7 +707,8 @@ def _gen_efficientnet_edge(
692707
693708
694709def _gen_efficientnet_condconv (
695- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , experts_multiplier = 1 , pretrained = False , ** kwargs ):
710+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , experts_multiplier = 1 , pretrained = False , ** kwargs
711+ ):
696712 """Creates an EfficientNet-CondConv model.
697713
698714 Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
@@ -764,7 +780,8 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
764780
765781
766782def _gen_efficientnetv2_base (
767- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
783+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs
784+ ):
768785 """ Creates an EfficientNet-V2 base model
769786
770787 Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -780,7 +797,7 @@ def _gen_efficientnetv2_base(
780797 ]
781798 round_chs_fn = partial (round_channels , multiplier = channel_multiplier , round_limit = 0. )
782799 model_kwargs = dict (
783- block_args = decode_arch_def (arch_def , depth_multiplier ),
800+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
784801 num_features = round_chs_fn (1280 ),
785802 stem_size = 32 ,
786803 round_chs_fn = round_chs_fn ,
@@ -793,7 +810,8 @@ def _gen_efficientnetv2_base(
793810
794811
795812def _gen_efficientnetv2_s (
796- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , rw = False , pretrained = False , ** kwargs ):
813+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , rw = False , pretrained = False , ** kwargs
814+ ):
797815 """ Creates an EfficientNet-V2 Small model
798816
799817 Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -831,7 +849,9 @@ def _gen_efficientnetv2_s(
831849 return model
832850
833851
834- def _gen_efficientnetv2_m (variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
852+ def _gen_efficientnetv2_m (
853+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs
854+ ):
835855 """ Creates an EfficientNet-V2 Medium model
836856
837857 Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -849,7 +869,7 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
849869 ]
850870
851871 model_kwargs = dict (
852- block_args = decode_arch_def (arch_def , depth_multiplier ),
872+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
853873 num_features = 1280 ,
854874 stem_size = 24 ,
855875 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
@@ -861,7 +881,9 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
861881 return model
862882
863883
864- def _gen_efficientnetv2_l (variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
884+ def _gen_efficientnetv2_l (
885+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs
886+ ):
865887 """ Creates an EfficientNet-V2 Large model
866888
867889 Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -879,7 +901,7 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
879901 ]
880902
881903 model_kwargs = dict (
882- block_args = decode_arch_def (arch_def , depth_multiplier ),
904+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
883905 num_features = 1280 ,
884906 stem_size = 32 ,
885907 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
@@ -891,7 +913,9 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
891913 return model
892914
893915
894- def _gen_efficientnetv2_xl (variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
916+ def _gen_efficientnetv2_xl (
917+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs
918+ ):
895919 """ Creates an EfficientNet-V2 Xtra-Large model
896920
897921 Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -909,7 +933,7 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
909933 ]
910934
911935 model_kwargs = dict (
912- block_args = decode_arch_def (arch_def , depth_multiplier ),
936+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
913937 num_features = 1280 ,
914938 stem_size = 32 ,
915939 round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
@@ -923,7 +947,8 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
923947
924948def _gen_efficientnet_x (
925949 variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , channel_divisor = 8 ,
926- group_size = None , version = 1 , pretrained = False , ** kwargs ):
950+ group_size = None , version = 1 , pretrained = False , ** kwargs
951+ ):
927952 """Creates an EfficientNet model.
928953
929954 Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
@@ -1069,9 +1094,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
10691094 return model
10701095
10711096
1072- def _gen_tinynet (
1073- variant , model_width = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs
1074- ):
1097+ def _gen_tinynet (variant , model_width = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
10751098 """Creates a TinyNet model.
10761099 """
10771100 arch_def = [
@@ -1183,8 +1206,7 @@ def _arch_def(chs: List[int], group_size: int):
11831206 return model
11841207
11851208
1186- def _gen_test_efficientnet (
1187- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
1209+ def _gen_test_efficientnet (variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
11881210 """ Minimal test EfficientNet generator.
11891211 """
11901212 arch_def = [
0 commit comments