diff --git a/README.md b/README.md
index ecf6bd47..108b498c 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@
-QONNX (Quantized ONNX) introduces several custom operators -- [`IntQuant`](docs/qonnx-custom-ops/intquant_op.md), [`FloatQuant`](docs/qonnx-custom-ops/floatquant_op.md), [`BipolarQuant`](docs/qonnx-custom-ops/bipolar_quant_op.md), and [`Trunc`](docs/qonnx-custom-ops/trunc_op.md) -- in order to represent arbitrary-precision integer and minifloat quantization in ONNX. This enables:
+QONNX (Quantized ONNX) introduces several [custom operators](docs/qonnx-custom-ops/overview.md) -- `IntQuant`, `FloatQuant`, `BipolarQuant`, and `Trunc` -- in order to represent arbitrary-precision integer and minifloat quantization in ONNX. This enables:
* Representation of binary, ternary, 3-bit, 4-bit, 6-bit or any other integer/fixed-point quantization.
* Representation of minifloat quantization with configurable exponent and mantissa bits.
* Quantization is an operator itself, and can be applied to any parameter or layer input.
@@ -29,9 +29,7 @@ This repository contains a set of Python utilities to work with QONNX models, in
### Operator definitions
-* [Quant](docs/qonnx-custom-ops/quant_op.md) for 2-to-arbitrary-bit quantization, with scaling and zero-point
-* [BipolarQuant](docs/qonnx-custom-ops/bipolar_quant_op.md) for 1-bit (bipolar) quantization, with scaling and zero-point
-* [Trunc](docs/qonnx-custom-ops/trunc_op.md) for truncating to a specified number of bits, with scaling and zero-point
+Please see the [custom operator overview](docs/qonnx-custom-ops/overview.md) table for more details.
### Installation
diff --git a/docs/qonnx-custom-ops/bipolar_quant_op.md b/docs/qonnx-custom-ops/bipolarquant_v1.md
similarity index 94%
rename from docs/qonnx-custom-ops/bipolar_quant_op.md
rename to docs/qonnx-custom-ops/bipolarquant_v1.md
index 3a70458e..03c0c01e 100644
--- a/docs/qonnx-custom-ops/bipolar_quant_op.md
+++ b/docs/qonnx-custom-ops/bipolarquant_v1.md
@@ -5,7 +5,7 @@ Additionally, takes one float as input, which define the scaling.
#### Version
-This operator is not part of the ONNX standard and is not currently versioned.
+The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1.
#### Attributes
diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_v1.md
similarity index 98%
rename from docs/qonnx-custom-ops/floatquant_op.md
rename to docs/qonnx-custom-ops/floatquant_v1.md
index fc51b75f..4536194b 100644
--- a/docs/qonnx-custom-ops/floatquant_op.md
+++ b/docs/qonnx-custom-ops/floatquant_v1.md
@@ -16,7 +16,7 @@ special (symbolic) values. This makes it nontrivial to infer the maximum represe
#### Version
-This operator is not part of the ONNX standard and is not currently versioned.
+The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1.
#### Attributes
diff --git a/docs/qonnx-custom-ops/intquant_op.md b/docs/qonnx-custom-ops/intquant_v1.md
similarity index 97%
rename from docs/qonnx-custom-ops/intquant_op.md
rename to docs/qonnx-custom-ops/intquant_v1.md
index fb627efb..4d15c0ec 100644
--- a/docs/qonnx-custom-ops/intquant_op.md
+++ b/docs/qonnx-custom-ops/intquant_v1.md
@@ -9,11 +9,11 @@ rounding_mode defines how quantized values are rounded.
Notes:
* This operator was previously named `Quant` but is renamed to `IntQuant` to distinguish it from `FloatQuant`. For a transition period, qonnx will transparently handle `Quant` as `IntQuant` for backwards compatibility reasons, but only `IntQuant` should be used for new models.
-* This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists.
+* This operator does not work for binary or bipolar quantization, for this purpose the simpler `BipolarQuant` node exists.
#### Version
-This operator is not part of the ONNX standard and is not currently versioned.
+The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1.
#### Attributes
diff --git a/docs/qonnx-custom-ops/overview.md b/docs/qonnx-custom-ops/overview.md
new file mode 100644
index 00000000..dfb93c38
--- /dev/null
+++ b/docs/qonnx-custom-ops/overview.md
@@ -0,0 +1,13 @@
+## Operator Schemas
+
+This file lists the QONNX custom operators, similar to `Operators.md` for the ONNX standard.
+It is manually updated, since QONNX custom operators are relatively few in number.
+
+### qonnx.custom_op.general
+
+|**Operator**|**Since version**||
+|-|-|-|
+|BipolarQuant|1|
+|FloatQuant|1|
+|IntQuant|1|
+|Trunc|2, 1|
diff --git a/docs/qonnx-custom-ops/trunc_op.md b/docs/qonnx-custom-ops/trunc_v1.md
similarity index 96%
rename from docs/qonnx-custom-ops/trunc_op.md
rename to docs/qonnx-custom-ops/trunc_v1.md
index 1b5f0d04..04b88443 100644
--- a/docs/qonnx-custom-ops/trunc_op.md
+++ b/docs/qonnx-custom-ops/trunc_v1.md
@@ -6,7 +6,7 @@ The attribute rounding_mode defines how truncated values are rounded.
#### Version
-This operator is not part of the ONNX standard and is not currently versioned.
+The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1.
#### Attributes
diff --git a/docs/qonnx-custom-ops/trunc_v2.md b/docs/qonnx-custom-ops/trunc_v2.md
new file mode 100644
index 00000000..d716c6c2
--- /dev/null
+++ b/docs/qonnx-custom-ops/trunc_v2.md
@@ -0,0 +1,144 @@
+### **Trunc**
+
+Truncates the values of one input data (Tensor) at a specified bitwidth and produces one output data (Tensor).
+Additionally, takes four float tensors as input, which define the scale, zero-point, input bit-width and output bit-width of the quantization.
+The attribute rounding_mode defines how truncated values are rounded.
+
+#### Version
+
+This operator is not part of the ONNX standard.
+The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 2.
+
+#### Attributes
+
+
+- rounding_mode : string (default is "FLOOR")
+- Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+- signed : int (default is 1)
+- Defines if the quantization includes a signed bit. E.g. at 8b unsigned=[0, 255] vs signed=[-128, 127].
+- narrow : int (default is 0)
+- Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].
+
+
+#### Inputs
+
+
+- X (differentiable) : tensor(float32)
+- input tensor to truncate
+- scale : float32
+- The scale factor at the input of the truncation
+- zeropt : float32
+- The zero-point at the input of the truncation
+- in_bitwidth : int32
+- The number of bits used at the input of the truncation
+- out_scale : float32
+- The scale factor of the output of the truncation
+- out_bitwidth : int32
+- The number of bits used at the output of the truncation
+
+
+
+#### Outputs
+
+
+- Y (differentiable) : tensor(float32)
+- Output tensor
+
+
+
+#### Examples
+
+Trunc
+
+```python
+from onnx import helper
+import numpy as np
+
+# Define node settings and input
+x = np.random.randn(100).astype(np.float32)*10.
+scale = np.array(1.)
+zeropt = np.array(0.)
+in_bitwidth = np.array(10)
+out_bitwidth = np.array(4)
+rounding_mode = "ROUND"
+
+# Create node
+node = helper.make_node(
+ 'Trunc',
+ domain='finn.custom_op.general',
+ inputs=['x', 'scale', 'zeropt', 'in_bitwidth', 'out_bitwidth'],
+ outputs=['y'],
+ rounding_mode=rounding_mode,
+)
+
+# Execute the same settings with the reference implementation (trunc)
+# See the sample implementation for more details on trunc.
+output_ref = trunc(inp_tensor, scale, zeropt, in_bitwidth, out_bitwidth, rounding_mode)
+
+# Execute node and compare
+expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_trunc')
+
+```
+
+
+
+
+#### Sample Implementation
+
+
+Trunc
+
+```python
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+
+def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
+
+ # Scaling
+ y = inp_tensor / scale
+ y = y + zeropt
+ # Rounding
+ y = np.round(y)
+ # Rescale
+ trunc_scale = 2 ** np.round(
+ np.log2(output_scale / scale)
+ ) # Trunc scale should be a power-of-two - ensure that is the case
+ y = y / trunc_scale
+
+ # Clamping
+ min_int_val = min_int(signed, narrow, output_bit_width)
+ max_int_val = max_int(signed, narrow, output_bit_width)
+ y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
+ y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
+ # To int (truncate)
+ rounding_fx = resolve_rounding_mode(rounding_mode)
+ y = rounding_fx(y)
+
+ # Rescale
+ output_zeropt = zeropt / trunc_scale # Rescale zero-point
+ y = y - output_zeropt
+ y = y * output_scale
+
+ return y
+
+def resolve_rounding_mode(mode_string):
+ """Resolve the rounding mode string of Quant and Trunc ops
+ to the corresponding numpy functions."""
+ if mode_string == "ROUND":
+ return np.round
+ elif mode_string == "CEIL":
+ return np.ceil
+ elif mode_string == "FLOOR":
+ return np.floor
+ else:
+ raise ValueError(f"Could not resolve rounding mode called: {mode_string}")
+
+```
+
+
diff --git a/src/qonnx/core/execute_custom_node.py b/src/qonnx/core/execute_custom_node.py
index 7acf3792..cd6bb605 100644
--- a/src/qonnx/core/execute_custom_node.py
+++ b/src/qonnx/core/execute_custom_node.py
@@ -27,10 +27,9 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import qonnx.custom_op.registry as registry
-from qonnx.util.basic import get_preferred_onnx_opset
-def execute_custom_node(node, context, graph, onnx_opset_version=get_preferred_onnx_opset()):
+def execute_custom_node(node, context, graph, onnx_opset_version):
"""Call custom implementation to execute a single custom node.
Input/output provided via context."""
op_type = node.op_type
diff --git a/src/qonnx/core/onnx_exec.py b/src/qonnx/core/onnx_exec.py
index 3a686f7e..893504de 100644
--- a/src/qonnx/core/onnx_exec.py
+++ b/src/qonnx/core/onnx_exec.py
@@ -36,15 +36,10 @@
import qonnx.analysis.topology as ta
import qonnx.core.execute_custom_node as ex_cu_node
from qonnx.custom_op.registry import is_custom_op
-from qonnx.util.basic import (
- get_preferred_onnx_opset,
- get_sanitize_quant_tensors,
- qonnx_make_model,
- sanitize_quant_values,
-)
+from qonnx.util.basic import get_preferred_qonnx_opset, get_sanitize_quant_tensors, qonnx_make_model, sanitize_quant_values
-def execute_node(node, context, graph, return_full_exec_context=False, opset_version=get_preferred_onnx_opset()):
+def execute_node(node, context, graph, opset_version, return_full_exec_context=False):
"""Executes a single node by using onnxruntime or with a custom function.
Input/output provided via context."""
@@ -158,7 +153,7 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N
model_exec_mode = model.get_metadata_prop("exec_mode")
if (model_exec_mode is None) or (model_exec_mode == ""):
# extract opset version for node-by-node execution
- opset_version = model.model.opset_import[0].version
+ opset_imports = model.get_opset_imports()
# execute the model node by node
# we can simply walk down the list since the ONNX spec guarantees that it is
# topologically sorted
@@ -176,7 +171,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N
if get_sanitize_quant_tensors() != 0:
# round input values to match quantization annotation
execution_context = sanitize_quant_values(model, node.input, execution_context)
- execute_node(node, execution_context, graph, return_full_exec_context, opset_version)
+ if node.domain in opset_imports:
+ opset_version = opset_imports[node.domain]
+ else:
+ opset_version = get_preferred_qonnx_opset()
+ execute_node(node, execution_context, graph, opset_version, return_full_exec_context)
if get_sanitize_quant_tensors() != 0:
# round output values to quantization annotation
execution_context = sanitize_quant_values(model, node.output, execution_context)
diff --git a/src/qonnx/custom_op/base.py b/src/qonnx/custom_op/base.py
index 775d9f95..383e453d 100644
--- a/src/qonnx/custom_op/base.py
+++ b/src/qonnx/custom_op/base.py
@@ -30,15 +30,35 @@
import onnx.numpy_helper as np_helper
from abc import ABC, abstractmethod
-from qonnx.util.basic import get_by_name, get_preferred_onnx_opset
+from qonnx.util.basic import get_by_name, get_preferred_qonnx_opset
class CustomOp(ABC):
"""CustomOp class all custom op nodes are based on. Contains different functions
every custom node should have. Some as abstract methods, these have to be
- filled when writing a new custom op node."""
+ filled when writing a new custom op node.
- def __init__(self, onnx_node, onnx_opset_version=get_preferred_onnx_opset()):
+ Opset Version Support:
+ CustomOp classes use "since version" semantics matching ONNX operators.
+ Version is determined by the class name using _vN suffix convention:
+
+ - No suffix (e.g., IntQuant): Version 1 (default)
+ - _vN suffix (e.g., IntQuant_v2): Version N
+
+ The registry automatically selects the highest version <= requested opset.
+
+ Example:
+ class IntQuant(CustomOp):
+ pass # Version 1 (no suffix)
+
+ class IntQuant_v2(CustomOp):
+ pass # Version 2, covers opset v2-v3 (if no v3 exists)
+
+ class IntQuant_v4(CustomOp):
+ pass # Version 4, covers opset v4+
+ """
+
+ def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
super().__init__()
self.onnx_node = onnx_node
self.onnx_opset_version = onnx_opset_version
diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py
index 77a048e7..390b0030 100644
--- a/src/qonnx/custom_op/channels_last/__init__.py
+++ b/src/qonnx/custom_op/channels_last/__init__.py
@@ -1,11 +1,17 @@
# Importing registers CustomOps in qonnx.custom_op.channels_last domain
-from qonnx.custom_op.channels_last.batch_normalization import BatchNormalization
-from qonnx.custom_op.channels_last.conv import Conv
-from qonnx.custom_op.channels_last.max_pool import MaxPool
+from qonnx.custom_op.channels_last.batch_normalization import (
+ BatchNormalization_v1,
+ BatchNormalization_v9,
+ BatchNormalization_v14,
+)
+from qonnx.custom_op.channels_last.conv import Conv_v1
+from qonnx.custom_op.channels_last.max_pool import MaxPool_v1, MaxPool_v10
-# Legacy dictionary for backward compatibility
-custom_op = {
- "Conv": Conv,
- "MaxPool": MaxPool,
- "BatchNormalization": BatchNormalization,
-}
\ No newline at end of file
+__all__ = [
+ "Conv_v1",
+ "MaxPool_v1",
+ "MaxPool_v10",
+ "BatchNormalization_v1",
+ "BatchNormalization_v9",
+ "BatchNormalization_v14",
+]
diff --git a/src/qonnx/custom_op/channels_last/batch_normalization.py b/src/qonnx/custom_op/channels_last/batch_normalization.py
index f3b3f872..a49591f4 100644
--- a/src/qonnx/custom_op/channels_last/batch_normalization.py
+++ b/src/qonnx/custom_op/channels_last/batch_normalization.py
@@ -32,7 +32,7 @@
from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp
-class BatchNormalization(ChannelsLastWrappedOp):
+class BatchNormalization_v1(ChannelsLastWrappedOp):
def get_nodeattr_types(self):
"""Returns a dict of permitted attributes for node, where:
ret_dict[attribute_name] = (dtype, require, default_value, )
@@ -133,3 +133,13 @@ def verify_node(self):
)
return info_messages
+
+
+class BatchNormalization_v9(BatchNormalization_v1):
+ # no relevant changes for channels-last wrapper
+ pass
+
+
+class BatchNormalization_v14(BatchNormalization_v9):
+ # no relevant changes for channels-last wrapper
+ pass
diff --git a/src/qonnx/custom_op/channels_last/conv.py b/src/qonnx/custom_op/channels_last/conv.py
index b0ff237b..9d74dd59 100644
--- a/src/qonnx/custom_op/channels_last/conv.py
+++ b/src/qonnx/custom_op/channels_last/conv.py
@@ -33,7 +33,7 @@
from qonnx.custom_op.general.im2col import compute_conv_output_dim
-class Conv(ChannelsLastWrappedOp):
+class Conv_v1(ChannelsLastWrappedOp):
def get_nodeattr_types(self):
"""Returns a dict of permitted attributes for node, where:
ret_dict[attribute_name] = (dtype, require, default_value, )
diff --git a/src/qonnx/custom_op/channels_last/max_pool.py b/src/qonnx/custom_op/channels_last/max_pool.py
index 383f3008..21a39d1d 100644
--- a/src/qonnx/custom_op/channels_last/max_pool.py
+++ b/src/qonnx/custom_op/channels_last/max_pool.py
@@ -33,7 +33,7 @@
from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim
-class MaxPool(ChannelsLastWrappedOp):
+class MaxPool_v1(ChannelsLastWrappedOp):
def get_nodeattr_types(self):
"""Returns a dict of permitted attributes for node, where:
ret_dict[attribute_name] = (dtype, require, default_value, )
@@ -171,3 +171,8 @@ def verify_node(self):
)
return info_messages
+
+
+class MaxPool_v10(MaxPool_v1):
+ # no relevant changes for channels-last wrapper
+ pass
diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py
index 2f3896de..d76df696 100644
--- a/src/qonnx/custom_op/general/__init__.py
+++ b/src/qonnx/custom_op/general/__init__.py
@@ -35,23 +35,23 @@
from qonnx.custom_op.general.intquant import IntQuant
from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC
from qonnx.custom_op.general.multithreshold import MultiThreshold
-from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d
from qonnx.custom_op.general.quant import Quant
-from qonnx.custom_op.general.trunc import Trunc
+from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d
+from qonnx.custom_op.general.trunc import Trunc_v1, Trunc_v2
from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul
-# Legacy dictionary for backward compatibility
-custom_op = {
- "DebugMarker": DebugMarker,
- "QuantAvgPool2d": QuantAvgPool2d,
- "MaxPoolNHWC": MaxPoolNHWC,
- "GenericPartition": GenericPartition,
- "MultiThreshold": MultiThreshold,
- "XnorPopcountMatMul": XnorPopcountMatMul,
- "Im2Col": Im2Col,
- "IntQuant": IntQuant,
- "Quant": IntQuant, # Alias
- "Trunc": Trunc,
- "BipolarQuant": BipolarQuant,
- "FloatQuant": FloatQuant,
-}
\ No newline at end of file
+__all__ = [
+ "BipolarQuant",
+ "DebugMarker",
+ "FloatQuant",
+ "GenericPartition",
+ "Im2Col",
+ "IntQuant",
+ "MaxPoolNHWC",
+ "MultiThreshold",
+ "Quant",
+ "QuantAvgPool2d",
+ "Trunc_v1",
+ "Trunc_v2",
+ "XnorPopcountMatMul",
+]
diff --git a/src/qonnx/custom_op/general/maxpoolnhwc.py b/src/qonnx/custom_op/general/maxpoolnhwc.py
index eb964fc4..93c6012d 100644
--- a/src/qonnx/custom_op/general/maxpoolnhwc.py
+++ b/src/qonnx/custom_op/general/maxpoolnhwc.py
@@ -97,10 +97,7 @@ def execute_node(self, context, graph):
inp_vi = helper.make_tensor_value_info(inp_name, TensorProto.FLOAT, inp.shape)
out_vi = helper.make_tensor_value_info(out_name, TensorProto.FLOAT, dummy_out.shape)
tmp_graph = helper.make_graph(nodes=[node], name="tmp_graph", inputs=[inp_vi], outputs=[out_vi])
- opset_version = self.onnx_opset_version
- opset_imports = [helper.make_opsetid("", opset_version)]
- onnx_kwargs = {"opset_imports": opset_imports}
- tmp_model = qonnx_make_model(tmp_graph, producer_name="finn", **onnx_kwargs)
+ tmp_model = qonnx_make_model(tmp_graph, producer_name="finn")
tmp_model = ModelWrapper(tmp_model)
new_ctx = {inp_name: inp}
from qonnx.core.onnx_exec import execute_onnx
diff --git a/src/qonnx/custom_op/general/quantavgpool2d.py b/src/qonnx/custom_op/general/quantavgpool2d.py
index c0e24071..00617dcf 100644
--- a/src/qonnx/custom_op/general/quantavgpool2d.py
+++ b/src/qonnx/custom_op/general/quantavgpool2d.py
@@ -33,7 +33,7 @@
from qonnx.core.datatype import DataType
from qonnx.custom_op.base import CustomOp
from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim
-from qonnx.util.basic import qonnx_make_model
+from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model
class QuantAvgPool2d(CustomOp):
@@ -132,7 +132,7 @@ def execute_node(self, context, graph):
outputs=[outp],
)
- opset_version = self.onnx_opset_version
+ opset_version = get_preferred_onnx_opset()
opset_imports = [helper.make_opsetid("", opset_version)]
onnx_kwargs = {"opset_imports": opset_imports}
model_avgpool = qonnx_make_model(graph_avgpool, **onnx_kwargs)
diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py
index 8e2eaa19..10c7e992 100644
--- a/src/qonnx/custom_op/general/trunc.py
+++ b/src/qonnx/custom_op/general/trunc.py
@@ -31,10 +31,99 @@
from qonnx.core.datatype import DataType
from qonnx.custom_op.base import CustomOp
-from qonnx.custom_op.general.quant import resolve_rounding_mode
+from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode
+from qonnx.util.basic import get_preferred_qonnx_opset
-def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
+def trunc_v2(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
+ # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
+
+ # Scaling
+ y = inp_tensor / scale
+ y = y + zeropt
+ # Rounding
+ y = np.round(y)
+ # Rescale
+ trunc_scale = 2 ** np.round(
+ np.log2(output_scale / scale)
+ ) # Trunc scale should be a power-of-two - ensure that is the case
+ y = y / trunc_scale
+
+ # Clamping
+ min_int_val = min_int(signed, narrow, output_bit_width)
+ max_int_val = max_int(signed, narrow, output_bit_width)
+ y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
+ y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
+ # To int (truncate)
+ rounding_fx = resolve_rounding_mode(rounding_mode)
+ y = rounding_fx(y)
+
+ # Rescale
+ output_zeropt = zeropt / trunc_scale # Rescale zero-point
+ y = y - output_zeropt
+ y = y * output_scale
+
+ return y
+
+
+class Trunc_v2(CustomOp):
+ """Generic truncation operation for QONNX. Takes four inputs:
+ - input tensor to truncate
+ - the scale
+ - the zero-point
+ - the truncation scale
+ - the truncation bit-width
+
+ The output is a tensor of the same shape as the input tensor, with truncated
+ values.
+ """
+
+ def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
+ super().__init__(onnx_node, onnx_opset_version)
+ # override any specified opset version, this instance is v2
+ self.onnx_opset_version = 2
+
+ def get_nodeattr_types(self):
+ return {
+ # The rounding mode, which is used for the trunc function
+ "rounding_mode": ("s", True, "FLOOR"),
+ "narrow": ("i", False, 0, {0, 1}),
+ "signed": ("i", False, 1, {0, 1}),
+ }
+
+ def make_shape_compatible_op(self, model):
+ node = self.onnx_node
+ return helper.make_node("Identity", [node.input[0]], [node.output[0]])
+
+ def infer_node_datatype(self, model):
+ node = self.onnx_node
+ model.set_tensor_datatype(node.output[0], DataType["FLOAT32"])
+
+ def execute_node(self, context, graph):
+ node = self.onnx_node
+ # save inputs
+ inp_tensor = context[node.input[0]]
+ scale = context[node.input[1]]
+ zeropt = context[node.input[2]]
+ input_bit_width = context[node.input[3]]
+ output_scale = context[node.input[4]]
+ output_bit_width = context[node.input[5]]
+ # save attributes
+ rounding_mode = self.get_nodeattr("rounding_mode")
+ narrow = self.get_nodeattr("narrow")
+ signed = self.get_nodeattr("signed")
+ # calculate output
+ ret = trunc_v2(
+ inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode
+ )
+ # set context according to output name
+ context[node.output[0]] = ret
+
+ def verify_node(self):
+ pass
+
+
+def trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
# Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
# Scaling
@@ -58,7 +147,7 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding
return y
-class Trunc(CustomOp):
+class Trunc_v1(CustomOp):
"""Generic truncation operation for QONNX. Takes four inputs:
- input tensor to truncate
- the scale
@@ -69,6 +158,11 @@ class Trunc(CustomOp):
values.
"""
+ def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
+ super().__init__(onnx_node, onnx_opset_version)
+ # override any specified opset version, this instance is v1
+ self.onnx_opset_version = 1
+
def get_nodeattr_types(self):
return {
# The rounding mode, which is used for the trunc function
@@ -94,7 +188,7 @@ def execute_node(self, context, graph):
# save attributes
rounding_mode = self.get_nodeattr("rounding_mode")
# calculate output
- ret = trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode)
+ ret = trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode)
# set context according to output name
context[node.output[0]] = ret
diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py
index b116f9e1..e9f6f0e7 100644
--- a/src/qonnx/custom_op/registry.py
+++ b/src/qonnx/custom_op/registry.py
@@ -28,14 +28,15 @@
import importlib
import inspect
+import warnings
from threading import RLock
from typing import Dict, List, Optional, Tuple, Type
from qonnx.custom_op.base import CustomOp
-from qonnx.util.basic import get_preferred_onnx_opset
-# Registry keyed by original ONNX domain: (domain, op_type) -> CustomOp class
-_OP_REGISTRY: Dict[Tuple[str, str], Type[CustomOp]] = {}
+# Nested registry for O(1) lookups: domain -> op_type -> version -> CustomOp class
+# Uses "since version" semantics: version N covers opset N until a higher version exists
+_OP_REGISTRY: Dict[str, Dict[str, Dict[int, Type[CustomOp]]]] = {}
_REGISTRY_LOCK = RLock()
@@ -68,92 +69,335 @@ def resolve_domain(domain: str) -> str:
return _DOMAIN_ALIASES.get(domain, domain)
-def add_op_to_domain(domain: str, op_class: Type[CustomOp]) -> None:
- """Register a custom op directly to a domain at runtime.
+def _get_op_type_for_class(cls: Type[CustomOp]) -> str:
+ """Extract the op_type from a CustomOp class name, stripping _vN suffix if present.
- The op_type is automatically derived from the class name.
- Useful for testing and experimentation. For production, define CustomOps
- in the appropriate module file.
+ Args:
+ cls: CustomOp class
+
+ Returns:
+ op_type string (e.g., "IntQuant_v2" -> "IntQuant")
+ """
+ name = cls.__name__
+ # Strip _vN suffix if present
+ if "_v" in name:
+ parts = name.split("_v")
+ if len(parts) == 2 and parts[1].isdigit():
+ return parts[0] # IntQuant_v2 -> IntQuant
+ return name
+
+
+def _get_op_version_for_class(cls: Type[CustomOp]) -> int:
+ """Extract version from a CustomOp class name.
Args:
- domain: ONNX domain name (e.g., "qonnx.custom_op.general")
- op_class: CustomOp subclass
+ cls: CustomOp class
- Example:
- add_op_to_domain("qonnx.custom_op.general", MyTestOp)
+ Returns:
+ Opset version (defaults to 1 if no _vN suffix present)
"""
- if not issubclass(op_class, CustomOp):
- raise ValueError(f"{op_class} must be a subclass of CustomOp")
+ name = cls.__name__
+ if "_v" in name:
+ parts = name.rsplit("_v", 1)
+ if len(parts) == 2 and parts[1].isdigit():
+ return int(parts[1])
+ return 1
- op_type = op_class.__name__
- with _REGISTRY_LOCK:
- _OP_REGISTRY[(domain, op_type)] = op_class
+def _discover_from_custom_op_dict(module, op_type: str, domain: str) -> Dict[int, Type[CustomOp]]:
+ """Extract CustomOp versions from legacy custom_op dict (backward compatibility).
+ Supports the old registration pattern:
+ custom_op = dict()
+ custom_op["IntQuant"] = IntQuant
+ custom_op["IntQuant_v2"] = IntQuant_v2
-def _discover_custom_op(domain: str, op_type: str) -> bool:
- """Discover and register a single custom op.
+ Args:
+ module: The imported module to check
+ op_type: The specific op type to discover
+ domain: The domain name (for warnings)
+
+ Returns:
+ Dict mapping version -> CustomOp class
+ """
+ versions = {}
+
+ if not (hasattr(module, "custom_op") and isinstance(module.custom_op, dict)):
+ return versions
+
+ # Iterate all dict entries, filter by op_type
+ for key, obj in module.custom_op.items():
+ # Check if this dict key matches the requested op_type
+ base_name = key.split("_v")[0] if "_v" in key else key
+ if base_name != op_type:
+ continue
+
+ if not (inspect.isclass(obj) and issubclass(obj, CustomOp) and obj is not CustomOp):
+ continue
+
+ try:
+ version = _get_op_version_for_class(obj)
+ except ValueError as e:
+ warnings.warn(str(e))
+ continue
+
+ if version in versions:
+ warnings.warn(
+ f"Multiple classes found for {domain}.{op_type} version {version}: "
+ f"{versions[version].__name__} and {obj.__name__}. Using {obj.__name__}."
+ )
+ versions[version] = obj
+
+ return versions
+
+
+def _discover_custom_op_versions(domain: str, op_type: str) -> Dict[int, Type[CustomOp]]:
+ """Discover all versions of a SPECIFIC custom op without loading entire domain.
+
+ Uses __all__ when available for efficient filtering, otherwise falls back to
+ full module inspection. Only loads classes matching the requested op_type.
Args:
domain: The ONNX domain name
op_type: The specific op type to discover
Returns:
- True if op was found and registered, False otherwise
+ Dict mapping version -> CustomOp class
"""
module_path = resolve_domain(domain)
+ versions = {}
try:
module = importlib.import_module(module_path)
except ModuleNotFoundError:
- return False
+ return versions
+
+ # Fast path: use __all__ to find only matching classes
+ if hasattr(module, "__all__"):
+ # Filter __all__ to find all versions of THIS op_type
+ # e.g., op_type="IntQuant" matches ["IntQuant", "IntQuant_v2", "IntQuant_v4"]
+ candidates = []
+ for name in module.__all__:
+ # Strip _vN suffix to check if it matches
+ base_name = name.split("_v")[0] if "_v" in name else name
+ if base_name == op_type:
+ candidates.append(name)
+
+ # Import ONLY the matching classes (lazy loading)
+ for name in candidates:
+ try:
+ obj = getattr(module, name)
+ except AttributeError:
+ continue
+
+ if not (inspect.isclass(obj) and issubclass(obj, CustomOp) and obj is not CustomOp):
+ continue
+
+ try:
+ version = _get_op_version_for_class(obj)
+ except ValueError as e:
+ warnings.warn(str(e))
+ continue
+
+ if version in versions:
+ warnings.warn(
+ f"Multiple classes found for {domain}.{op_type} version {version}: "
+ f"{versions[version].__name__} and {obj.__name__}. Using {obj.__name__}."
+ )
+ versions[version] = obj
+
+ # Backward compatibility: if __all__ didn't have the op, try custom_op dict
+ if not versions:
+ versions = _discover_from_custom_op_dict(module, op_type, domain)
+
+ else:
+ # No __all__ - try legacy dict first (O(1) check, cheaper than full scan)
+ versions = _discover_from_custom_op_dict(module, op_type, domain)
+
+ # Still nothing? Fallback to full module scan (for external modules)
+ if not versions:
+ for name, obj in inspect.getmembers(module, inspect.isclass):
+ if not issubclass(obj, CustomOp) or obj is CustomOp:
+ continue
+
+ class_op_type = _get_op_type_for_class(obj)
+ if class_op_type != op_type:
+ continue
+
+ try:
+ version = _get_op_version_for_class(obj)
+ except ValueError as e:
+ warnings.warn(str(e))
+ continue
+
+ if version in versions:
+ warnings.warn(
+ f"Multiple classes found for {domain}.{op_type} version {version}: "
+ f"{versions[version].__name__} and {obj.__name__}. Using {obj.__name__}."
+ )
+ versions[version] = obj
+
+ return versions
+
+
+def _resolve_version(
+ available_versions: Dict[int, Type[CustomOp]], requested_version: Optional[int]
+) -> Tuple[int, Type[CustomOp]]:
+ """Resolve which version to use given available and requested versions.
+
+ Uses "since version" semantics: highest version <= requested is selected.
+
+ Resolution strategy:
+ 1. If requested is None, use highest available version
+ 2. Try exact match
+ 3. Use highest version <= requested
+ 4. Raise KeyError if no suitable version
- # Try namespace lookup
- op_class = getattr(module, op_type, None)
- if inspect.isclass(op_class) and issubclass(op_class, CustomOp):
- _OP_REGISTRY[(domain, op_type)] = op_class
- return True
+ Args:
+ available_versions: Dict of available versions -> CustomOp classes
+ requested_version: Requested opset version, or None for highest
- # Try legacy dict
- custom_op_dict = getattr(module, 'custom_op', None)
- if isinstance(custom_op_dict, dict):
- op_class = custom_op_dict.get(op_type)
- if inspect.isclass(op_class) and issubclass(op_class, CustomOp):
- _OP_REGISTRY[(domain, op_type)] = op_class
- return True
+ Returns:
+ Tuple of (resolved_version, CustomOp_class)
- return False
+ Raises:
+ KeyError: If no suitable version found
+ """
+ if not available_versions:
+ raise KeyError("No versions available")
+
+ # Strategy 1: If no specific version requested, use highest
+ if requested_version is None:
+ highest = max(available_versions.keys())
+ return highest, available_versions[highest]
+
+ # Strategy 2: Try exact match
+ if requested_version in available_versions:
+ return requested_version, available_versions[requested_version]
+
+ # Strategy 3: Use highest version <= requested (since version semantics)
+ suitable = [v for v in available_versions.keys() if v <= requested_version]
+ if suitable:
+ selected = max(suitable)
+ return selected, available_versions[selected]
+
+ # Strategy 4: No suitable version found
+ available_list = sorted(available_versions.keys())
+ raise KeyError(
+ f"No suitable version found. Requested: {requested_version}, "
+ f"Available: {available_list}. Lowest available version is {available_list[0]}."
+ )
-def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset()):
+def add_op_to_domain(domain: str, op_class: Type[CustomOp]) -> None:
+ """Register a custom op directly to a domain at runtime.
+
+ The op_type and version are automatically derived from the class name.
+ Useful for testing and experimentation. For production, define CustomOps
+ in the appropriate module file.
+
+ Args:
+ domain: ONNX domain name (e.g., "qonnx.custom_op.general")
+ op_class: CustomOp subclass (version inferred from name)
+
+ Example:
+ add_op_to_domain("qonnx.custom_op.general", MyTestOp) # v1
+ add_op_to_domain("qonnx.custom_op.general", MyTestOp_v2) # v2
+ """
+ if not issubclass(op_class, CustomOp):
+ raise ValueError(f"{op_class} must be a subclass of CustomOp")
+
+ op_type = _get_op_type_for_class(op_class)
+ op_version = _get_op_version_for_class(op_class)
+
+ with _REGISTRY_LOCK:
+ # Ensure nested dict structure exists
+ if domain not in _OP_REGISTRY:
+ _OP_REGISTRY[domain] = {}
+ if op_type not in _OP_REGISTRY[domain]:
+ _OP_REGISTRY[domain][op_type] = {}
+
+ _OP_REGISTRY[domain][op_type][op_version] = op_class
+
+
+def getCustomOp(node, onnx_opset_version=None):
"""Get a custom op instance for an ONNX node.
+ Uses "since version" semantics: selects highest version <= requested opset.
+ Lazy loads only the requested op_type using __all__ for efficiency.
+
Args:
node: ONNX node with domain and op_type attributes
- onnx_opset_version: ONNX opset version to use
+ onnx_opset_version: Opset version from model's opset_import, or None for highest
Returns:
CustomOp instance for the node
Raises:
- KeyError: If op_type not found in domain
+ KeyError: If op_type not found in domain or no suitable version available
"""
op_type = node.op_type
domain = node.domain
- key = (domain, op_type)
with _REGISTRY_LOCK:
- if key in _OP_REGISTRY:
- return _OP_REGISTRY[key](node, onnx_opset_version=onnx_opset_version)
+ # O(1) nested dict lookup to check cache
+ if domain in _OP_REGISTRY and op_type in _OP_REGISTRY[domain]:
+ cached_versions = _OP_REGISTRY[domain][op_type]
+ else:
+ # Cache miss: discover THIS op only (lazy, uses __all__ for speed)
+ cached_versions = _discover_custom_op_versions(domain, op_type)
+
+ if not cached_versions:
+ module_path = resolve_domain(domain)
+ raise KeyError(
+ f"Op '{op_type}' not found in domain '{domain}' (module: {module_path}). "
+ f"Ensure it's defined in the module with proper naming (OpName or OpName_vN)."
+ )
+
+ # Cache it in nested structure
+ if domain not in _OP_REGISTRY:
+ _OP_REGISTRY[domain] = {}
+ _OP_REGISTRY[domain][op_type] = cached_versions
+
+ # Resolve which version to use
+ resolved_version, op_class = _resolve_version(cached_versions, onnx_opset_version)
+
+ # Instantiate and return
+ return op_class(node, onnx_opset_version=resolved_version)
+
+
+def get_supported_versions(domain: str, op_type: str) -> List[int]:
+ """Get list of supported opset versions for a custom op.
+
+ Returns all "since versions" where the operator was introduced or changed.
+
+ Args:
+ domain: ONNX domain name
+ op_type: Operation type name
- if _discover_custom_op(domain, op_type):
- return _OP_REGISTRY[key](node, onnx_opset_version=onnx_opset_version)
+ Returns:
+ Sorted list of opset versions
+
+ Raises:
+ KeyError: If op not found
+ """
+ with _REGISTRY_LOCK:
+ # O(1) check if cached
+ if domain in _OP_REGISTRY and op_type in _OP_REGISTRY[domain]:
+ return sorted(_OP_REGISTRY[domain][op_type].keys())
- module_path = resolve_domain(domain)
- raise KeyError(
- f"Op '{op_type}' not found in domain '{domain}' (module: {module_path}). "
- f"Ensure it's exported in the module namespace or in the custom_op dict."
- )
+ # Not cached: discover this op
+ versions_dict = _discover_custom_op_versions(domain, op_type)
+
+ if not versions_dict:
+ raise KeyError(f"Op '{op_type}' not found in domain '{domain}'")
+
+ # Cache discovered versions
+ if domain not in _OP_REGISTRY:
+ _OP_REGISTRY[domain] = {}
+ _OP_REGISTRY[domain][op_type] = versions_dict
+
+ return sorted(versions_dict.keys())
def is_custom_op(domain: str, op_type: Optional[str] = None) -> bool:
@@ -173,14 +417,15 @@ def is_custom_op(domain: str, op_type: Optional[str] = None) -> bool:
with _REGISTRY_LOCK:
if op_type is not None:
- # Check for specific op
- key = (domain, op_type)
- if key in _OP_REGISTRY:
+ # Check for specific op - O(1) with nested dict
+ if domain in _OP_REGISTRY and op_type in _OP_REGISTRY[domain]:
return True
- return _discover_custom_op(domain, op_type)
+ # Try to discover
+ versions = _discover_custom_op_versions(domain, op_type)
+ return len(versions) > 0
else:
# Check if domain has any registered ops
- if any(d == domain for d, _ in _OP_REGISTRY.keys()):
+ if domain in _OP_REGISTRY and _OP_REGISTRY[domain]:
return True
# Try to import the domain module as fallback
module_path = resolve_domain(domain)
@@ -203,12 +448,10 @@ def hasCustomOp(domain: str, op_type: str) -> bool:
Returns:
True if the op exists, False otherwise
"""
- import warnings
warnings.warn(
- "hasCustomOp is deprecated and will be removed in QONNX v1.0. "
- "Use is_custom_op instead.",
+ "hasCustomOp is deprecated and will be removed in QONNX v1.0. " "Use is_custom_op instead.",
DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
)
return is_custom_op(domain, op_type)
@@ -216,6 +459,9 @@ def hasCustomOp(domain: str, op_type: str) -> bool:
def get_ops_in_domain(domain: str) -> List[Tuple[str, Type[CustomOp]]]:
"""Get all CustomOp classes available in a domain.
+ Note: Returns unique op_types. If multiple versions exist, returns the highest version.
+ This function eagerly loads all ops in the domain.
+
Args:
domain: ONNX domain name (e.g., "qonnx.custom_op.general")
@@ -227,34 +473,49 @@ def get_ops_in_domain(domain: str) -> List[Tuple[str, Type[CustomOp]]]:
for op_name, op_class in ops:
print(f"{op_name}: {op_class}")
"""
- ops = []
module_path = resolve_domain(domain)
+ ops_dict = {}
with _REGISTRY_LOCK:
- # Strategy 1: Get cached ops (fast path)
- for (d, op_type), op_class in _OP_REGISTRY.items():
- if d == domain:
- ops.append((op_type, op_class))
+ # Strategy 1: Get cached ops (fast path) - use highest version
+ if domain in _OP_REGISTRY:
+ for op_type, versions in _OP_REGISTRY[domain].items():
+ if versions:
+ highest_version = max(versions.keys())
+ ops_dict[op_type] = versions[highest_version]
# Strategy 2: Discover from module (for uncached ops)
+ # This uses full scan since we want ALL ops
try:
module = importlib.import_module(module_path)
- # Check namespace exports
- for name, obj in inspect.getmembers(module):
- if (inspect.isclass(obj) and
- issubclass(obj, CustomOp) and
- obj is not CustomOp and
- not name.startswith('_') and
- not any(op[0] == name for op in ops)):
- ops.append((name, obj))
-
- # Check legacy custom_op dict
- if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict):
- for name, cls in module.custom_op.items():
- if not any(op[0] == name for op in ops):
- ops.append((name, cls))
+ # Use __all__ if available for efficiency
+ if hasattr(module, "__all__"):
+ candidates = [(name, getattr(module, name, None)) for name in module.__all__]
+ candidates = [(n, obj) for n, obj in candidates if obj is not None]
+ else:
+ candidates = inspect.getmembers(module, inspect.isclass)
+
+ for name, obj in candidates:
+ if not (inspect.isclass(obj) and issubclass(obj, CustomOp) and obj is not CustomOp):
+ continue
+
+ op_type = _get_op_type_for_class(obj)
+ try:
+ version = _get_op_version_for_class(obj)
+ except ValueError:
+ continue
+
+ # Keep highest version only
+ if op_type not in ops_dict:
+ ops_dict[op_type] = obj
+ else:
+ # Check if this version is higher
+ existing_version = _get_op_version_for_class(ops_dict[op_type])
+ if version > existing_version:
+ ops_dict[op_type] = obj
+
except ModuleNotFoundError:
pass # Domain doesn't exist as module, return cached ops only
- return ops
+ return list(ops_dict.items())
diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py
index 175af058..a00f8a9c 100644
--- a/src/qonnx/transformation/channels_last.py
+++ b/src/qonnx/transformation/channels_last.py
@@ -32,8 +32,8 @@
from onnx import TensorProto, helper
from qonnx.core.modelwrapper import ModelWrapper
-from qonnx.custom_op import channels_last
from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_first_args, to_channels_last_args
+from qonnx.custom_op.registry import get_ops_in_domain
from qonnx.transformation.base import Transformation
from qonnx.transformation.fold_constants import FoldConstants
from qonnx.transformation.general import SortGraph
@@ -44,7 +44,7 @@
from qonnx.util.onnx import is_eltwise_optype
# Standard ONNX nodes which require a ChannelsLast data format to function properly
-_channelsLast_node_types = list(channels_last.custom_op.keys())
+_channelsLast_node_types = list([x[0] for x in get_ops_in_domain("qonnx.custom_op.channels_last")])
# Nodes, which do not modify the shape of the tensor
# And modify all values in the same way.
@@ -270,8 +270,15 @@ def apply(self, model):
# Attach to original node
n.output[i] = outp_trans_in
- # Modify domain
+ # Modify node domain
n.domain = "qonnx.custom_op.channels_last"
+ opset_imports = model.get_opset_imports()
+ # Ensure channels_last domain is imported in model
+ if "qonnx.custom_op.channels_last" not in opset_imports:
+ # use the same opset for channels last ops as the standard ONNX opset
+ # (since they are defined based on the standard ops under the hood)
+ onnx_opset = opset_imports[""] if "" in opset_imports.keys() else opset_imports["ai.onnx"]
+ model.model.opset_import.append(helper.make_opsetid("qonnx.custom_op.channels_last", onnx_opset))
# Set modified flag
graph_modified = True
diff --git a/src/qonnx/transformation/fixedpt_quantize.py b/src/qonnx/transformation/fixedpt_quantize.py
index 127fa4b1..3b3357ed 100644
--- a/src/qonnx/transformation/fixedpt_quantize.py
+++ b/src/qonnx/transformation/fixedpt_quantize.py
@@ -41,19 +41,15 @@ def default_op_filter(op):
class FixedPointQuantizeParamsFromDict(Transformation):
"""
- Quantize model parameters to a given fixed-point representation.
- The self.max_err dictionary stores the maximum error for each quantized input after calling.
- Parameters:
- fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point
- <<<<<<< HEAD
- data type or its canonical name
- =======
- data type or its canonical name
- >>>>>>> 7dfc4b8 ([Lint] rerun linter, fix errors)
- rounding_mode: Rounding mode used for conversion into fixed point.
- Default is "ROUND",
- possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN",
- "HALF_UP", "HALF_DOWN"]
+ Quantize model parameters to a given fixed-point representation.
+ The self.max_err dictionary stores the maximum error for each quantized input after calling.
+ Parameters:
+ fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point
+ data type or its canonical name
+ rounding_mode: Rounding mode used for conversion into fixed point.
+ Default is "ROUND",
+ possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN",
+ "HALF_UP", "HALF_DOWN"]
"""
def __init__(self, fixedpt_dict, rounding_mode="ROUND"):
diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py
index 17957d12..cef4f67b 100644
--- a/src/qonnx/util/basic.py
+++ b/src/qonnx/util/basic.py
@@ -78,13 +78,15 @@ def is_finn_op(op_type):
Use the registry-based is_custom_op for better accuracy and extensibility.
"""
import warnings
+
warnings.warn(
"is_finn_op is deprecated and will be removed in QONNX v1.0. "
"Use 'from qonnx.custom_op.registry import is_custom_op' instead.",
DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
)
from qonnx.custom_op.registry import is_custom_op
+
return is_custom_op(op_type)
diff --git a/tests/core/test_custom_onnx_exec.py b/tests/core/test_custom_onnx_exec.py
index 8eec7156..54b71754 100644
--- a/tests/core/test_custom_onnx_exec.py
+++ b/tests/core/test_custom_onnx_exec.py
@@ -32,6 +32,8 @@
import qonnx.core.execute_custom_node as ex_cu_node
from qonnx.custom_op.registry import getCustomOp
+mt_node_version = 1
+
def test_execute_custom_node_multithreshold():
inputs = np.ndarray(
@@ -155,7 +157,7 @@ def test_execute_custom_node_multithreshold():
execution_context["v"] = inputs
execution_context["thresholds"] = threshold_values
- ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
+ ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version)
outputs = np.ndarray(
shape=(6, 3, 2, 2),
@@ -250,7 +252,7 @@ def test_execute_custom_node_multithreshold():
)
graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out])
- ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
+ ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version)
outputs_scaled = 2.0 * outputs - 1.0
assert (execution_context["out"] == outputs_scaled).all()
@@ -270,7 +272,7 @@ def test_execute_custom_node_multithreshold():
execution_context["v"] = inputs_nhwc
graph_def = helper.make_graph([node_def], "test_model", [v_nhwc, thresholds], [out_nhwc])
- ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
+ ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version)
assert (execution_context["out"] == outputs_nhwc).all()
# check the set of allowed values
op_inst = getCustomOp(node_def)
diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py
index cde5a321..d1d32546 100644
--- a/tests/custom_op/test_attr.py
+++ b/tests/custom_op/test_attr.py
@@ -29,10 +29,9 @@
import numpy as np
import onnx.parser as oprs
-import qonnx.custom_op.general as general
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.base import CustomOp
-from qonnx.custom_op.registry import getCustomOp
+from qonnx.custom_op.registry import add_op_to_domain, getCustomOp
class AttrTestOp(CustomOp):
@@ -60,7 +59,7 @@ def verify_node(self):
def test_attr():
- general.custom_op["AttrTestOp"] = AttrTestOp
+ add_op_to_domain("qonnx.custom_op.general", AttrTestOp)
ishp = (1, 10)
wshp = (1, 3)
oshp = wshp
diff --git a/tests/custom_op/test_customop_version.py b/tests/custom_op/test_customop_version.py
new file mode 100644
index 00000000..e0d30c56
--- /dev/null
+++ b/tests/custom_op/test_customop_version.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2025 Advanced Micro Devices, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of qonnx nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import onnx.parser as oprs
+
+from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.custom_op.base import CustomOp
+from qonnx.custom_op.registry import add_op_to_domain, getCustomOp
+
+
+class VerTestOp_v1(CustomOp):
+ def get_nodeattr_types(self):
+ my_attrs = {"v1_attr": ("i", True, 0)}
+ return my_attrs
+
+ def make_shape_compatible_op(self, model):
+ ishape = model.get_tensor_shape(self.onnx_node.input[0])
+ return super().make_const_shape_op(ishape)
+
+ def infer_node_datatype(self, model):
+ node = self.onnx_node
+ # data type stays the same
+ dtype = model.get_tensor_datatype(node.input[0])
+ model.set_tensor_datatype(node.output[0], dtype)
+
+ def execute_node(self, context, graph):
+ node = self.onnx_node
+ context[node.output[0]] = context[node.input[0]]
+
+ def verify_node(self):
+ pass
+
+
+class VerTestOp_v2(VerTestOp_v1):
+ def get_nodeattr_types(self):
+ my_attrs = {"v2_attr": ("i", True, 0)}
+ return my_attrs
+
+
+class VerTestOp_v3(VerTestOp_v2):
+ def get_nodeattr_types(self):
+ my_attrs = {"v3_attr": ("i", True, 0)}
+ return my_attrs
+
+
+def make_vertest_model(vertest_ver, no_opset_import):
+ ishp = (1, 10)
+ oshp = ishp
+ ishp_str = str(list(ishp))
+ oshp_str = str(list(oshp))
+ if no_opset_import:
+ opset_import = ""
+ else:
+ opset_import = f', "qonnx.custom_op.general" : {vertest_ver}'
+ input = f"""
+ <
+ ir_version: 7,
+ opset_import: ["" : 9{opset_import}]
+ >
+ agraph (float{ishp_str} in0) => (float{oshp_str} out0)
+ {{
+ out0 = qonnx.custom_op.general.VerTestOp<
+ v{vertest_ver}_attr={vertest_ver}
+ >(in0)
+ }}
+ """
+ model = oprs.parse_model(input)
+ model = ModelWrapper(model)
+ return model
+
+
+def test_customop_version():
+ # Register test ops with the registry
+ # The _vN suffix will be automatically stripped to get op_type="VerTestOp"
+ add_op_to_domain("qonnx.custom_op.general", VerTestOp_v1)
+ add_op_to_domain("qonnx.custom_op.general", VerTestOp_v2)
+ add_op_to_domain("qonnx.custom_op.general", VerTestOp_v3)
+
+ # if onnx is lacking the opset import, getCustomOp with no version
+ # should return the highest available version
+ model = make_vertest_model(1, True)
+ inst = getCustomOp(model.graph.node[0])
+ # With no opset_import, getCustomOp(None) uses highest version -> v3
+ assert isinstance(inst, VerTestOp_v3)
+ # alternatively, when using ModelWrapper.get_customop_wrapper and onnx is
+ # lacking the opset import, should fall back to the specified version
+ inst = model.get_customop_wrapper(model.graph.node[0], fallback_customop_version=2)
+ assert isinstance(inst, VerTestOp_v2)
+
+ for ver in [1, 2, 3]:
+ model = make_vertest_model(ver, False)
+ # use ModelWrapper.get_customop_wrapper for implicit
+ # fetching of op version
+ inst = model.get_customop_wrapper(model.graph.node[0])
+ assert inst.get_nodeattr(f"v{ver}_attr") == ver
+ assert inst.onnx_opset_version == ver
+ # explicitly specify onnx_opset_version in getCustomOp
+ # note: new code should avoid calling getCustomOp directly like this
+ # and instead use ModelWrapper.get_customop_wrapper
+ inst = getCustomOp(model.graph.node[0], onnx_opset_version=ver)
+ assert inst.get_nodeattr(f"v{ver}_attr") == ver
+ assert inst.onnx_opset_version == ver
+ # getCustomOp with no version specified uses highest available
+ model = make_vertest_model(1, False)
+ inst = getCustomOp(model.graph.node[0])
+ assert isinstance(inst, VerTestOp_v3) # highest version
+ assert inst.onnx_opset_version == 3
+ # requesting v4 should return largest available version (v3 in this case)
+ model = make_vertest_model(3, False)
+ inst = getCustomOp(model.graph.node[0], onnx_opset_version=4)
+ assert isinstance(inst, VerTestOp_v3)
+ assert inst.onnx_opset_version == 3
diff --git a/tests/custom_op/test_floatquant.py b/tests/custom_op/test_floatquant.py
index c0f89cde..f792f793 100644
--- a/tests/custom_op/test_floatquant.py
+++ b/tests/custom_op/test_floatquant.py
@@ -168,7 +168,6 @@ def test_brevitas_vs_qonnx(data):
scale = 1.0
exponent_bias = compute_default_exponent_bias(exponent_bit_width)
max_val = compute_max_val(exponent_bit_width, mantissa_bit_width, exponent_bias)
- xq_t = brevitas_float_quant(x, bit_width, exponent_bit_width, mantissa_bit_width,
- exponent_bias, sign, max_val).numpy()
+ xq_t = brevitas_float_quant(x, bit_width, exponent_bit_width, mantissa_bit_width, exponent_bias, sign, max_val).numpy()
xq = qonnx_float_quant(x.numpy(), scale, exponent_bit_width, mantissa_bit_width, exponent_bias, sign, max_val)
np.testing.assert_array_equal(xq, xq_t)
diff --git a/tests/transformation/test_channelslast.py b/tests/transformation/test_channelslast.py
index 24e64b4f..92b4964e 100644
--- a/tests/transformation/test_channelslast.py
+++ b/tests/transformation/test_channelslast.py
@@ -32,9 +32,8 @@
import qonnx.core.onnx_exec as oxe
from qonnx.core.modelwrapper import ModelWrapper
-from qonnx.custom_op import channels_last
from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_last_args
-from qonnx.custom_op.registry import getCustomOp
+from qonnx.custom_op.registry import get_ops_in_domain, getCustomOp, is_custom_op
from qonnx.transformation.channels_last import (
AbsorbChanFirstIntoMatMul,
InsertChannelsLastDomainsAndTrafos,
@@ -47,7 +46,6 @@
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast
from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
-from qonnx.util.basic import is_finn_op
from qonnx.util.test import download_model, get_golden_in_and_output, test_model_details
from qonnx.util.to_channels_last import to_channels_last
@@ -92,7 +90,7 @@ def analysis_testing_for_chanlast_domain(model):
"BatchNormalization": 3,
}
# Check that all wrapped_ops in the registry have a definition here
- chanlast_op_types = list(channels_last.custom_op.keys())
+ chanlast_op_types = list([x[0] for x in get_ops_in_domain("qonnx.custom_op.channels_last")])
testable_op_types = list(ChanLast_node_types_and_min_dim_input.keys())
for op_name in chanlast_op_types:
assert (
@@ -126,7 +124,7 @@ def analysis_test_for_left_transposes(model, test_model, make_input_channels_las
def verify_all_nodes(model):
result = dict()
for n in model.graph.node:
- if is_finn_op(n.domain):
+ if is_custom_op(n.domain):
n_instance = getCustomOp(n)
verify_result = n_instance.verify_node()
result[n.name] = verify_result