Skip to content

Commit a4c3741

Browse files
committed
Merge branch 'fix/op_version-main-merge-yaman-changes' into feature/opversion_and_trunc_v2
2 parents 80b6006 + 1b58cf8 commit a4c3741

File tree

20 files changed

+619
-151
lines changed

20 files changed

+619
-151
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ install_requires =
4949
importlib-metadata
5050
attrs>=22.2.0
5151
clize>=5.0.1
52-
protobuf==3.20.3
52+
protobuf>=3.20.3
5353
bitstring>=3.1.7
5454
numpy>=1.24.1
5555
onnx>=1.13.0

src/qonnx/core/modelwrapper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
import qonnx.util.basic as util
4040
import qonnx.util.onnx as onnxutil
4141
from qonnx.core.datatype import DataType
42-
from qonnx.custom_op.registry import getCustomOp
42+
from qonnx.custom_op.registry import getCustomOp, is_custom_op
4343
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
4444
from qonnx.transformation.general import (
4545
RemoveStaticGraphInputs,
@@ -183,7 +183,7 @@ def transform(
183183
if self.fix_float64:
184184
(transformed_model, model_was_changed) = DoubleToSingleFloat().apply(transformed_model)
185185

186-
if apply_to_subgraphs and not use_preorder_traversal:
186+
if apply_to_subgraphs and (use_preorder_traversal is False):
187187
transformed_model.transform_subgraphs(
188188
transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal
189189
)
@@ -632,11 +632,11 @@ def get_nodes_by_op_type(self, op_type):
632632

633633
def get_finn_nodes(self):
634634
"""Returns a list of nodes where domain == 'qonnx.*'."""
635-
return list(filter(lambda x: util.is_finn_op(x.domain), self.graph.node))
635+
return list(filter(lambda x: is_custom_op(x.domain), self.graph.node))
636636

637637
def get_non_finn_nodes(self):
638638
"""Returns a list of nodes where domain != 'qonnx.*'."""
639-
return list(filter(lambda x: not util.is_finn_op(x.domain), self.graph.node))
639+
return list(filter(lambda x: not is_custom_op(x.domain), self.graph.node))
640640

641641
def get_node_index(self, node):
642642
"""Returns current index of given node, or None if not found."""

src/qonnx/core/onnx_exec.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,16 @@
3535

3636
import qonnx.analysis.topology as ta
3737
import qonnx.core.execute_custom_node as ex_cu_node
38-
from qonnx.util.basic import (
39-
get_preferred_qonnx_opset,
40-
get_sanitize_quant_tensors,
41-
is_finn_op,
42-
qonnx_make_model,
43-
sanitize_quant_values,
44-
)
38+
from qonnx.custom_op.registry import is_custom_op
39+
from qonnx.util.basic import get_preferred_qonnx_opset, get_sanitize_quant_tensors, qonnx_make_model, sanitize_quant_values
4540

4641

4742
def execute_node(node, context, graph, opset_version, return_full_exec_context=False):
4843
"""Executes a single node by using onnxruntime or with a custom function.
4944
5045
Input/output provided via context."""
5146

52-
if is_finn_op(node.domain):
47+
if is_custom_op(node.domain, node.op_type):
5348
ex_cu_node.execute_custom_node(node, context, graph, onnx_opset_version=opset_version)
5449
else:
5550
# onnxruntime unfortunately does not implement run_node as defined by ONNX,

src/qonnx/custom_op/base.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,27 @@
3636
class CustomOp(ABC):
3737
"""CustomOp class all custom op nodes are based on. Contains different functions
3838
every custom node should have. Some as abstract methods, these have to be
39-
filled when writing a new custom op node."""
39+
filled when writing a new custom op node.
40+
41+
Opset Version Support:
42+
CustomOp classes use "since version" semantics matching ONNX operators.
43+
Version is determined by the class name using _vN suffix convention:
44+
45+
- No suffix (e.g., IntQuant): Version 1 (default)
46+
- _vN suffix (e.g., IntQuant_v2): Version N
47+
48+
The registry automatically selects the highest version <= requested opset.
49+
50+
Example:
51+
class IntQuant(CustomOp):
52+
pass # Version 1 (no suffix)
53+
54+
class IntQuant_v2(CustomOp):
55+
pass # Version 2, covers opset v2-v3 (if no v3 exists)
56+
57+
class IntQuant_v4(CustomOp):
58+
pass # Version 4, covers opset v4+
59+
"""
4060

4161
def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
4262
super().__init__()
Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,17 @@
1-
from qonnx.custom_op.channels_last.batch_normalization import BatchNormalization
2-
from qonnx.custom_op.channels_last.conv import Conv
3-
from qonnx.custom_op.channels_last.max_pool import MaxPool
4-
5-
# channels-last ops are defined by the underlying ONNX standard op
6-
# thus, we can define them for any version of the original op
7-
# so we emulate a custom op dictionary that mimics the support for any
8-
# {ChannelsLastOp}_vX instead of hardcoding what versions are supported
9-
10-
11-
class ChannelsLastCustomOpDict(dict):
12-
def __init__(self):
13-
self._custom_ops = {"Conv": Conv, "MaxPool": MaxPool, "BatchNormalization": BatchNormalization}
14-
15-
def __getitem__(self, key):
16-
base_key = key.split("_v")[0] # Extract base key (e.g., Conv from Conv_v13)
17-
if base_key in self._custom_ops:
18-
return self._custom_ops[base_key]
19-
raise KeyError(f"Channels-last CustomOp '{key}' not found.")
20-
21-
def __contains__(self, key):
22-
base_key = key.split("_v")[0]
23-
return base_key in self._custom_ops
24-
25-
def keys(self):
26-
return self._custom_ops.keys()
27-
28-
29-
custom_op = ChannelsLastCustomOpDict()
1+
# Importing registers CustomOps in qonnx.custom_op.channels_last domain
2+
from qonnx.custom_op.channels_last.batch_normalization import (
3+
BatchNormalization_v1,
4+
BatchNormalization_v9,
5+
BatchNormalization_v14,
6+
)
7+
from qonnx.custom_op.channels_last.conv import Conv_v1
8+
from qonnx.custom_op.channels_last.max_pool import MaxPool_v1, MaxPool_v10
9+
10+
__all__ = [
11+
"Conv_v1",
12+
"MaxPool_v1",
13+
"MaxPool_v10",
14+
"BatchNormalization_v1",
15+
"BatchNormalization_v9",
16+
"BatchNormalization_v14",
17+
]

src/qonnx/custom_op/channels_last/batch_normalization.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp
3333

3434

35-
class BatchNormalization(ChannelsLastWrappedOp):
35+
class BatchNormalization_v1(ChannelsLastWrappedOp):
3636
def get_nodeattr_types(self):
3737
"""Returns a dict of permitted attributes for node, where:
3838
ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)
@@ -133,3 +133,13 @@ def verify_node(self):
133133
)
134134

135135
return info_messages
136+
137+
138+
class BatchNormalization_v9(BatchNormalization_v1):
139+
# no relevant changes for channels-last wrapper
140+
pass
141+
142+
143+
class BatchNormalization_v14(BatchNormalization_v9):
144+
# no relevant changes for channels-last wrapper
145+
pass

src/qonnx/custom_op/channels_last/conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from qonnx.custom_op.general.im2col import compute_conv_output_dim
3434

3535

36-
class Conv(ChannelsLastWrappedOp):
36+
class Conv_v1(ChannelsLastWrappedOp):
3737
def get_nodeattr_types(self):
3838
"""Returns a dict of permitted attributes for node, where:
3939
ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)

src/qonnx/custom_op/channels_last/max_pool.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim
3434

3535

36-
class MaxPool(ChannelsLastWrappedOp):
36+
class MaxPool_v1(ChannelsLastWrappedOp):
3737
def get_nodeattr_types(self):
3838
"""Returns a dict of permitted attributes for node, where:
3939
ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)
@@ -171,3 +171,8 @@ def verify_node(self):
171171
)
172172

173173
return info_messages
174+
175+
176+
class MaxPool_v10(MaxPool_v1):
177+
# no relevant changes for channels-last wrapper
178+
pass

src/qonnx/custom_op/general/__init__.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828

29+
# Importing registers CustomOps in qonnx.custom_op.general domain
2930
from qonnx.custom_op.general.bipolar_quant import BipolarQuant
3031
from qonnx.custom_op.general.debugmarker import DebugMarker
3132
from qonnx.custom_op.general.floatquant import FloatQuant
@@ -34,36 +35,23 @@
3435
from qonnx.custom_op.general.intquant import IntQuant
3536
from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC
3637
from qonnx.custom_op.general.multithreshold import MultiThreshold
38+
from qonnx.custom_op.general.quant import Quant
3739
from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d
3840
from qonnx.custom_op.general.trunc import Trunc_v1, Trunc_v2
3941
from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul
4042

41-
custom_op = dict()
42-
43-
custom_op["DebugMarker"] = DebugMarker
44-
custom_op["QuantAvgPool2d"] = QuantAvgPool2d
45-
custom_op["MaxPoolNHWC"] = MaxPoolNHWC
46-
custom_op["GenericPartition"] = GenericPartition
47-
custom_op["MultiThreshold"] = MultiThreshold
48-
custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul
49-
custom_op["Im2Col"] = Im2Col
50-
custom_op["IntQuant"] = IntQuant
51-
custom_op["Quant"] = IntQuant
52-
custom_op["Trunc"] = Trunc_v1
53-
custom_op["BipolarQuant"] = BipolarQuant
54-
custom_op["FloatQuant"] = FloatQuant
55-
56-
custom_op["DebugMarker_v1"] = DebugMarker
57-
custom_op["QuantAvgPool2d_v1"] = QuantAvgPool2d
58-
custom_op["MaxPoolNHWC_v1"] = MaxPoolNHWC
59-
custom_op["GenericPartition_v1"] = GenericPartition
60-
custom_op["MultiThreshold_v1"] = MultiThreshold
61-
custom_op["XnorPopcountMatMul_v1"] = XnorPopcountMatMul
62-
custom_op["Im2Col_v1"] = Im2Col
63-
custom_op["IntQuant_v1"] = IntQuant
64-
custom_op["Quant_v1"] = IntQuant
65-
custom_op["Trunc_v1"] = Trunc_v1
66-
custom_op["BipolarQuant_v1"] = BipolarQuant
67-
custom_op["FloatQuant_v1"] = FloatQuant
68-
69-
custom_op["Trunc_v2"] = Trunc_v2
43+
__all__ = [
44+
"BipolarQuant",
45+
"DebugMarker",
46+
"FloatQuant",
47+
"GenericPartition",
48+
"Im2Col",
49+
"IntQuant",
50+
"MaxPoolNHWC",
51+
"MultiThreshold",
52+
"Quant",
53+
"QuantAvgPool2d",
54+
"Trunc_v1",
55+
"Trunc_v2",
56+
"XnorPopcountMatMul",
57+
]

0 commit comments

Comments
 (0)