Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions src/qonnx/core/modelwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import qonnx.util.basic as util
import qonnx.util.onnx as onnxutil
from qonnx.core.datatype import DataType
from qonnx.custom_op.registry import is_custom_op
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
from qonnx.transformation.general import (
RemoveStaticGraphInputs,
Expand Down Expand Up @@ -632,11 +632,11 @@ def get_nodes_by_op_type(self, op_type):

def get_finn_nodes(self):
"""Returns a list of nodes where domain == 'qonnx.*'."""
return list(filter(lambda x: is_custom_op(x.domain), self.graph.node))
return list(filter(lambda x: util.is_finn_op(x.domain), self.graph.node))

def get_non_finn_nodes(self):
"""Returns a list of nodes where domain != 'qonnx.*'."""
return list(filter(lambda x: not is_custom_op(x.domain), self.graph.node))
return list(filter(lambda x: not util.is_finn_op(x.domain), self.graph.node))

def get_node_index(self, node):
"""Returns current index of given node, or None if not found."""
Expand Down Expand Up @@ -738,3 +738,44 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict):
qa.tensor_name = tensor_name
qa.quant_parameter_tensor_names.append(dt)
qnt_annotations.append(qa)

def get_opset_imports(self):
"""Returns a list of imported opsets as a {domain, version} dictionary."""
return {opset.domain: opset.version for opset in self._model_proto.opset_import}

def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferred_qonnx_opset()):
"""Return CustomOp instance for given node, respecting the
imported opset version in the model protobuf. If the node's domain
is not found in the model's opset imports, fallback_customop_version
will be used."""
opset_imports = self.get_opset_imports()
try:
opset_import = opset_imports[node.domain]
return getCustomOp(node, onnx_opset_version=opset_import)
except KeyError:
# domain not found in imports, use fallback version
warnings.warn(
f"Domain {node.domain} not found in model opset imports, "
f"using fallback_customop_version={fallback_customop_version}"
)
return getCustomOp(node, onnx_opset_version=fallback_customop_version)

def set_opset_import(self, domain, version):
"""Sets the opset version for a given domain in the model's opset imports.
If the domain already exists, its version will be updated. If not, a new
opset import will be added.

Args:
domain (str): The domain name (e.g. "qonnx.custom_op.general")
version (int): The opset version number
"""
# find if domain already exists in opset imports
for opset in self._model_proto.opset_import:
if opset.domain == domain:
opset.version = version
return
# domain not found, add new opset import
new_opset = onnx.OperatorSetIdProto()
new_opset.domain = domain
new_opset.version = version
self._model_proto.opset_import.append(new_opset)
27 changes: 11 additions & 16 deletions src/qonnx/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,33 +51,28 @@ def get_preferred_onnx_opset():
return 11


def get_preferred_qonnx_opset():
"Return preferred ONNX opset version for QONNX"
return 1


def qonnx_make_model(graph_proto, **kwargs):
"Wrapper around ONNX make_model with preferred qonnx opset version"
opset_imports = kwargs.pop("opset_imports", None)
if opset_imports is None:
opset_imports = [make_opsetid("", get_preferred_onnx_opset())]
opset_imports = [
make_opsetid("", get_preferred_onnx_opset()),
make_opsetid("qonnx.custom_op.general", get_preferred_qonnx_opset()),
]
kwargs["opset_imports"] = opset_imports
else:
kwargs["opset_imports"] = opset_imports
return make_model(graph_proto, **kwargs)


def is_finn_op(op_type):
"""Deprecated: Use is_custom_op from qonnx.custom_op.registry instead.

Return whether given op_type string is a QONNX or FINN custom op.
This function uses hard-coded string matching and will be removed in QONNX v1.0.
Use the registry-based is_custom_op for better accuracy and extensibility.
"""
import warnings
warnings.warn(
"is_finn_op is deprecated and will be removed in QONNX v1.0. "
"Use 'from qonnx.custom_op.registry import is_custom_op' instead.",
DeprecationWarning,
stacklevel=2
)
from qonnx.custom_op.registry import is_custom_op
return is_custom_op(op_type)
"Return whether given op_type string is a QONNX or FINN custom op"
return op_type.startswith("finn") or op_type.startswith("qonnx.custom_op") or op_type.startswith("onnx.brevitas")


def get_num_default_workers():
Expand Down
31 changes: 30 additions & 1 deletion tests/core/test_modelwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import qonnx.core.data_layout as DataLayout
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.basic import qonnx_make_model
from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model


def test_modelwrapper():
Expand Down Expand Up @@ -68,6 +68,7 @@ def test_modelwrapper():
inp_sparsity = {"dw": {"kernel_shape": [3, 3]}}
model.set_tensor_sparsity(first_conv_iname, inp_sparsity)
assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity
assert model.get_opset_imports() == {"": 8}


def test_modelwrapper_set_get_rm_initializer():
Expand Down Expand Up @@ -230,3 +231,31 @@ def test_modelwrapper_set_tensor_shape_multiple_inputs():
# check that order of inputs is preserved
assert model.graph.input[0].name == "in1"
assert model.graph.input[1].name == "in2"


def test_modelwrapper_set_opset_import():
# Create a simple model
in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4])
out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4])
node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["out1"])
graph = onnx.helper.make_graph(
nodes=[node],
name="single_node_graph",
inputs=[in1],
outputs=[out1],
)
onnx_model = qonnx_make_model(graph, producer_name="opset-test-model")
model = ModelWrapper(onnx_model)

# Test setting new domain
model.set_opset_import("qonnx.custom_op.general", 1)
preferred_onnx_opset = get_preferred_onnx_opset()
assert model.get_opset_imports() == {"": preferred_onnx_opset, "qonnx.custom_op.general": 1}

# Test updating existing domain
model.set_opset_import("qonnx.custom_op.general", 2)
assert model.get_opset_imports() == {"": preferred_onnx_opset, "qonnx.custom_op.general": 2}

# Test setting ONNX main domain
model.set_opset_import("", 13)
assert model.get_opset_imports() == {"": 13, "qonnx.custom_op.general": 2}
Loading