diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index d23ce8ac..b43a1155 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -39,7 +39,7 @@ import qonnx.util.basic as util import qonnx.util.onnx as onnxutil from qonnx.core.datatype import DataType -from qonnx.custom_op.registry import is_custom_op +from qonnx.custom_op.registry import getCustomOp from qonnx.transformation.double_to_single_float import DoubleToSingleFloat from qonnx.transformation.general import ( RemoveStaticGraphInputs, @@ -632,11 +632,11 @@ def get_nodes_by_op_type(self, op_type): def get_finn_nodes(self): """Returns a list of nodes where domain == 'qonnx.*'.""" - return list(filter(lambda x: is_custom_op(x.domain), self.graph.node)) + return list(filter(lambda x: util.is_finn_op(x.domain), self.graph.node)) def get_non_finn_nodes(self): """Returns a list of nodes where domain != 'qonnx.*'.""" - return list(filter(lambda x: not is_custom_op(x.domain), self.graph.node)) + return list(filter(lambda x: not util.is_finn_op(x.domain), self.graph.node)) def get_node_index(self, node): """Returns current index of given node, or None if not found.""" @@ -738,3 +738,44 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict): qa.tensor_name = tensor_name qa.quant_parameter_tensor_names.append(dt) qnt_annotations.append(qa) + + def get_opset_imports(self): + """Returns a list of imported opsets as a {domain, version} dictionary.""" + return {opset.domain: opset.version for opset in self._model_proto.opset_import} + + def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferred_qonnx_opset()): + """Return CustomOp instance for given node, respecting the + imported opset version in the model protobuf. If the node's domain + is not found in the model's opset imports, fallback_customop_version + will be used.""" + opset_imports = self.get_opset_imports() + try: + opset_import = opset_imports[node.domain] + return getCustomOp(node, onnx_opset_version=opset_import) + except KeyError: + # domain not found in imports, use fallback version + warnings.warn( + f"Domain {node.domain} not found in model opset imports, " + f"using fallback_customop_version={fallback_customop_version}" + ) + return getCustomOp(node, onnx_opset_version=fallback_customop_version) + + def set_opset_import(self, domain, version): + """Sets the opset version for a given domain in the model's opset imports. + If the domain already exists, its version will be updated. If not, a new + opset import will be added. + + Args: + domain (str): The domain name (e.g. "qonnx.custom_op.general") + version (int): The opset version number + """ + # find if domain already exists in opset imports + for opset in self._model_proto.opset_import: + if opset.domain == domain: + opset.version = version + return + # domain not found, add new opset import + new_opset = onnx.OperatorSetIdProto() + new_opset.domain = domain + new_opset.version = version + self._model_proto.opset_import.append(new_opset) diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 4e300dd1..e756366d 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -51,11 +51,19 @@ def get_preferred_onnx_opset(): return 11 +def get_preferred_qonnx_opset(): + "Return preferred ONNX opset version for QONNX" + return 1 + + def qonnx_make_model(graph_proto, **kwargs): "Wrapper around ONNX make_model with preferred qonnx opset version" opset_imports = kwargs.pop("opset_imports", None) if opset_imports is None: - opset_imports = [make_opsetid("", get_preferred_onnx_opset())] + opset_imports = [ + make_opsetid("", get_preferred_onnx_opset()), + make_opsetid("qonnx.custom_op.general", get_preferred_qonnx_opset()), + ] kwargs["opset_imports"] = opset_imports else: kwargs["opset_imports"] = opset_imports @@ -63,21 +71,8 @@ def qonnx_make_model(graph_proto, **kwargs): def is_finn_op(op_type): - """Deprecated: Use is_custom_op from qonnx.custom_op.registry instead. - - Return whether given op_type string is a QONNX or FINN custom op. - This function uses hard-coded string matching and will be removed in QONNX v1.0. - Use the registry-based is_custom_op for better accuracy and extensibility. - """ - import warnings - warnings.warn( - "is_finn_op is deprecated and will be removed in QONNX v1.0. " - "Use 'from qonnx.custom_op.registry import is_custom_op' instead.", - DeprecationWarning, - stacklevel=2 - ) - from qonnx.custom_op.registry import is_custom_op - return is_custom_op(op_type) + "Return whether given op_type string is a QONNX or FINN custom op" + return op_type.startswith("finn") or op_type.startswith("qonnx.custom_op") or op_type.startswith("onnx.brevitas") def get_num_default_workers(): diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index 722f0fb1..fb26e420 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -33,7 +33,7 @@ import qonnx.core.data_layout as DataLayout from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper -from qonnx.util.basic import qonnx_make_model +from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model def test_modelwrapper(): @@ -68,6 +68,7 @@ def test_modelwrapper(): inp_sparsity = {"dw": {"kernel_shape": [3, 3]}} model.set_tensor_sparsity(first_conv_iname, inp_sparsity) assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity + assert model.get_opset_imports() == {"": 8} def test_modelwrapper_set_get_rm_initializer(): @@ -230,3 +231,31 @@ def test_modelwrapper_set_tensor_shape_multiple_inputs(): # check that order of inputs is preserved assert model.graph.input[0].name == "in1" assert model.graph.input[1].name == "in2" + + +def test_modelwrapper_set_opset_import(): + # Create a simple model + in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4]) + out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4]) + node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["out1"]) + graph = onnx.helper.make_graph( + nodes=[node], + name="single_node_graph", + inputs=[in1], + outputs=[out1], + ) + onnx_model = qonnx_make_model(graph, producer_name="opset-test-model") + model = ModelWrapper(onnx_model) + + # Test setting new domain + model.set_opset_import("qonnx.custom_op.general", 1) + preferred_onnx_opset = get_preferred_onnx_opset() + assert model.get_opset_imports() == {"": preferred_onnx_opset, "qonnx.custom_op.general": 1} + + # Test updating existing domain + model.set_opset_import("qonnx.custom_op.general", 2) + assert model.get_opset_imports() == {"": preferred_onnx_opset, "qonnx.custom_op.general": 2} + + # Test setting ONNX main domain + model.set_opset_import("", 13) + assert model.get_opset_imports() == {"": 13, "qonnx.custom_op.general": 2}