33from enum import IntEnum
44import numpy as np
55from onnx import ModelProto , TensorProto , ValueInfoProto , load
6+ from onnx .reference import ReferenceEvaluator
67from onnx .helper import tensor_dtype_to_np_dtype
78from onnx .shape_inference import infer_shapes
89from . import to_array_extended
@@ -166,9 +167,9 @@ def enumerate_results(
166167 Returns:
167168 iterator on tuple(result kind, name, value, node.op_type or None)
168169 """
169- assert isinstance (self .evaluator , ExtendedReferenceEvaluator ), (
170+ assert isinstance (self .evaluator , ReferenceEvaluator ), (
170171 f"This implementation only works with "
171- f"ExtendedReferenceEvaluator not { type (self .evaluator )} "
172+ f"ReferenceEvaluator not { type (self .evaluator )} "
172173 )
173174 attributes = {}
174175 if output_names is None :
@@ -595,6 +596,7 @@ def compare_onnx_execution(
595596 raise_exc : bool = True ,
596597 mode : str = "execute" ,
597598 keep_tensor : bool = False ,
599+ cls : Optional [type [ReferenceEvaluator ]] = None ,
598600) -> Tuple [List [ResultExecution ], List [ResultExecution ], List [Tuple [int , int ]]]:
599601 """
600602 Compares the execution of two onnx models.
@@ -611,6 +613,7 @@ def compare_onnx_execution(
611613 :param mode: the model should be executed but the function can be executed
612614 but the comparison may append on nodes only
613615 :param keep_tensor: keeps the tensor in order to compute a precise distance
616+ :param cls: evaluator class to use
614617 :return: four results, a sequence of results
615618 for the first model and the second model,
616619 the alignment between the two, DistanceExecution
@@ -634,15 +637,15 @@ def compare_onnx_execution(
634637 print (f"[compare_onnx_execution] execute with { len (inputs )} inputs" )
635638 print ("[compare_onnx_execution] execute first model" )
636639 res1 = list (
637- YieldEvaluator (model1 ).enumerate_summarized (
640+ YieldEvaluator (model1 , cls = cls ).enumerate_summarized (
638641 None , feeds1 , raise_exc = raise_exc , keep_tensor = keep_tensor
639642 )
640643 )
641644 if verbose :
642645 print (f"[compare_onnx_execution] got { len (res1 )} results" )
643646 print ("[compare_onnx_execution] execute second model" )
644647 res2 = list (
645- YieldEvaluator (model2 ).enumerate_summarized (
648+ YieldEvaluator (model2 , cls = cls ).enumerate_summarized (
646649 None , feeds2 , raise_exc = raise_exc , keep_tensor = keep_tensor
647650 )
648651 )
0 commit comments