2121from typing import Callable , List , Tuple , Union
2222
2323
24- from timm .models import is_model , list_models
24+ from timm .models import is_model , list_models , get_pretrained_cfg
2525
2626
2727parser = argparse .ArgumentParser (description = 'Per-model process launcher' )
@@ -98,16 +98,32 @@ def main():
9898 cmd , cmd_args = cmd_from_args (args )
9999
100100 model_cfgs = []
101- model_names = []
102101 if args .model_list == 'all' :
103- # NOTE should make this config, for validation / benchmark runs the focus is 1k models,
104- # so we filter out 21/22k and some other unusable heads. This will change in the future...
105- exclude_model_filters = ['*in21k' , '*in22k' , '*dino' , '*_22k' ]
106102 model_names = list_models (
107103 pretrained = args .pretrained , # only include models w/ pretrained checkpoints if set
108- exclude_filters = exclude_model_filters
109104 )
110105 model_cfgs = [(n , None ) for n in model_names ]
106+ elif args .model_list == 'all_in1k' :
107+ model_names = list_models (pretrained = True )
108+ model_cfgs = []
109+ for n in model_names :
110+ pt_cfg = get_pretrained_cfg (n )
111+ if getattr (pt_cfg , 'num_classes' , 0 ) == 1000 :
112+ print (n , pt_cfg .num_classes )
113+ model_cfgs .append ((n , None ))
114+ elif args .model_list == 'all_res' :
115+ model_names = list_models ()
116+ model_names += [n .split ('.' )[0 ] for n in list_models (pretrained = True )]
117+ model_cfgs = set ()
118+ for n in model_names :
119+ pt_cfg = get_pretrained_cfg (n )
120+ if pt_cfg is None :
121+ print (f'Model { n } is missing pretrained cfg, skipping.' )
122+ continue
123+ model_cfgs .add ((n , pt_cfg .input_size [- 1 ]))
124+ if pt_cfg .test_input_size is not None :
125+ model_cfgs .add ((n , pt_cfg .test_input_size [- 1 ]))
126+ model_cfgs = [(n , {'img-size' : r }) for n , r in sorted (model_cfgs )]
111127 elif not is_model (args .model_list ):
112128 # model name doesn't exist, try as wildcard filter
113129 model_names = list_models (args .model_list )
@@ -122,7 +138,8 @@ def main():
122138 results_file = args .results_file or './results.csv'
123139 results = []
124140 errors = []
125- print ('Running script on these models: {}' .format (', ' .join (model_names )))
141+ model_strings = '\n ' .join ([f'{ x [0 ]} , { x [1 ]} ' for x in model_cfgs ])
142+ print (f"Running script on these models:\n { model_strings } " )
126143 if not args .sort_key :
127144 if 'benchmark' in args .script :
128145 if any (['train' in a for a in args .script_args ]):
@@ -136,10 +153,14 @@ def main():
136153 print (f'Script: { args .script } , Args: { args .script_args } , Sort key: { sort_key } ' )
137154
138155 try :
139- for m , _ in model_cfgs :
156+ for m , ax in model_cfgs :
140157 if not m :
141158 continue
142159 args_str = (cmd , * [str (e ) for e in cmd_args ], '--model' , m )
160+ if ax is not None :
161+ extra_args = [(f'--{ k } ' , str (v )) for k , v in ax .items ()]
162+ extra_args = [i for t in extra_args for i in t ]
163+ args_str += tuple (extra_args )
143164 try :
144165 o = subprocess .check_output (args = args_str ).decode ('utf-8' ).split ('--result' )[- 1 ]
145166 r = json .loads (o )
@@ -157,7 +178,11 @@ def main():
157178 if errors :
158179 print (f'{ len (errors )} models had errors during run.' )
159180 for e in errors :
160- print (f"\t { e ['model' ]} ({ e .get ('error' , 'Unknown' )} )" )
181+ if 'model' in e :
182+ print (f"\t { e ['model' ]} ({ e .get ('error' , 'Unknown' )} )" )
183+ else :
184+ print (e )
185+
161186 results = list (filter (lambda x : 'error' not in x , results ))
162187
163188 no_sortkey = list (filter (lambda x : sort_key not in x , results ))
0 commit comments