 import argparse
 import os
 import re
+import shutil
 import sys
 import tarfile
 import tempfile
@@ -187,14 +188,17 @@ def __init__(self, url, local, input_func, input_names, output_names,
                  check_only_shape=False, model_type="frozen", force_input_shape=False,
                  skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None,
                  skip_conversion=False, converted_model=None, signature_def=None, concrete_function=None,
-                 large_model=False, structured_outputs=None, run_tf_frozen=None, use_custom_ops=False):
+                 large_model=False, structured_outputs=None, run_tf_frozen=None, use_custom_ops=False,
+                 ort_profile=None, tf_profile=None):
         self.url = url
         self.input_func = input_func
         self.local = local
         self.input_names = input_names
         self.output_names = output_names
         self.disabled = disabled
         self.large_model = large_model
+        self.ort_profile = ort_profile
+        self.tf_profile = tf_profile
         self.use_custom_ops = use_custom_ops
         if run_tf_frozen is None:
             run_tf_frozen = not self.large_model
@@ -324,13 +328,14 @@ def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_st
                                            as_text=utils.is_debug_mode(),
                                            external_tensor_storage=external_tensor_storage)
         logger.info("Model saved to %s", model_path)
+        opt = rt.SessionOptions()
         if self.use_custom_ops:
             from ortcustomops import get_library_path
-            opt = rt.SessionOptions()
             opt.register_custom_ops_library(get_library_path())
             m = rt.InferenceSession(model_path, opt)
+        if self.ort_profile is not None:
+            opt.enable_profiling = True
+        m = rt.InferenceSession(model_path, opt)
         results = m.run(outputs, inputs)
         if self.perf:
             n = 0
@@ -342,6 +347,9 @@ def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_st
                 n += PERF_STEP
             self.onnx_runtime = 1000 * (time.time() - start) / n
             logger.info("ORT perf {:.2f}ms/inference, n={}".format(self.onnx_runtime, n))
+        if self.ort_profile is not None:
+            tmp_path = m.end_profiling()
+            shutil.move(tmp_path, self.ort_profile)
         return results

     @staticmethod
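
Note: the profiling flow added to `run_onnxruntime` relies on two onnxruntime APIs: `SessionOptions.enable_profiling`, which makes the session record a JSON trace while it runs, and `InferenceSession.end_profiling()`, which finalizes that trace and returns its file name. A minimal sketch of the same flow, with the model path, input shape, and destination file used only as placeholders:

```python
import shutil
import numpy as np
import onnxruntime as rt

opt = rt.SessionOptions()
opt.enable_profiling = True                        # session writes a JSON trace while it runs
sess = rt.InferenceSession("model.onnx", opt)      # placeholder model path
feed = {sess.get_inputs()[0].name: np.zeros((1, 3, 224, 224), np.float32)}  # placeholder input
sess.run(None, feed)                               # run at least one inference to produce trace events
trace_file = sess.end_profiling()                  # returns the auto-generated trace file name
shutil.move(trace_file, "ort_profile.json")        # keep it under a predictable name, as the diff does
```
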
@@ -449,10 +457,14 @@ def run_tflite():
                 n = 0
                 start = time.time()
                 stop = start + PERF_TIME
+                if self.tf_profile is not None:
+                    tf.profiler.experimental.start(self.tf_profile)
                 while time.time() < stop:
                     for _ in range(PERF_STEP):
                         _ = concrete_func(**inputs)
                     n += PERF_STEP
+                if self.tf_profile is not None:
+                    tf.profiler.experimental.stop()
                 self.tf_runtime = 1000 * (time.time() - start) / n
                 logger.info("TF perf {:.2f}ms/inference, n={}".format(self.tf_runtime, n))
             logger.info("TensorFlow OK")
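
Note: the TensorFlow side uses the `tf.profiler.experimental.start(logdir)` / `stop()` pair, which brackets a region of execution and writes a TensorBoard-readable trace into the given log directory. A small sketch, with the log directory and the traced function chosen only for illustration:

```python
import tensorflow as tf

@tf.function
def step(x):
    return tf.matmul(x, x)            # stand-in for the model's forward pass

x = tf.random.normal((256, 256))
tf.profiler.experimental.start("tf_profile_logdir")   # placeholder logdir (the tf_profile setting)
for _ in range(10):
    step(x)
tf.profiler.experimental.stop()                        # finalize and write the trace
```
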
@@ -497,7 +509,11 @@ def run_tflite():
             if self.skip_tensorflow:
                 logger.info("TensorFlow SKIPPED")
             elif self.run_tf_frozen:
+                if self.tf_profile is not None:
+                    tf.profiler.experimental.start(self.tf_profile)
                 tf_results = self.run_tensorflow(sess, inputs)
+                if self.tf_profile is not None:
+                    tf.profiler.experimental.stop()
                 logger.info("TensorFlow OK")
             tf_graph = sess.graph

@@ -690,7 +706,7 @@ def load_tests_from_yaml(path):
         for kw in ["rtol", "atol", "ptol", "disabled", "check_only_shape", "model_type", "concrete_function",
                    "skip_tensorflow", "force_input_shape", "tf_min_version", "tag", "skip_conversion",
                    "converted_model", "signature_def", "large_model", "structured_outputs", "run_tf_frozen",
-                   "use_custom_ops", "dequantize"]:
+                   "use_custom_ops", "dequantize", "ort_profile", "tf_profile"]:
             if settings.get(kw) is not None:
                 kwargs[kw] = settings[kw]

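
Note: since these keyword settings are read straight from the YAML test definitions, the new options can be set per test. A hypothetical entry (the surrounding keys mirror the usual run_pretrained_models.yaml layout; every name and path here is a placeholder):

```yaml
example-model:
  model: models/example/frozen.pb
  input_get: get_ramp
  inputs:
    "input:0": [1, 224, 224, 3]
  outputs:
    - output:0
  ort_profile: /tmp/example_ort_profile.json   # destination for the ORT JSON trace
  tf_profile: /tmp/example_tf_profile          # logdir passed to tf.profiler.experimental.start
```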