Skip to content

Commit d72fb8c

Browse files
authored
Merge pull request #209 from fastmachinelearning/feature/opversion_and_trunc_v2
Introduce v1 and v2 for Trunc
2 parents 0bc01a4 + a4c3741 commit d72fb8c

File tree

9 files changed

+265
-15
lines changed

9 files changed

+265
-15
lines changed

README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
<img align="left" src="https://xilinx.github.io/finn/img/TFC_1W2A.onnx.png" alt="QONNX example" style="margin-right: 20px" width="200"/>
1212

1313

14-
QONNX (Quantized ONNX) introduces several custom operators -- [`IntQuant`](docs/qonnx-custom-ops/intquant_op.md), [`FloatQuant`](docs/qonnx-custom-ops/floatquant_op.md), [`BipolarQuant`](docs/qonnx-custom-ops/bipolar_quant_op.md), and [`Trunc`](docs/qonnx-custom-ops/trunc_op.md) -- in order to represent arbitrary-precision integer and minifloat quantization in ONNX. This enables:
14+
QONNX (Quantized ONNX) introduces several [custom operators](docs/qonnx-custom-ops/overview.md) -- `IntQuant`, `FloatQuant`, `BipolarQuant`, and `Trunc` -- in order to represent arbitrary-precision integer and minifloat quantization in ONNX. This enables:
1515
* Representation of binary, ternary, 3-bit, 4-bit, 6-bit or any other integer/fixed-point quantization.
1616
* Representation of minifloat quantization with configurable exponent and mantissa bits.
1717
* Quantization is an operator itself, and can be applied to any parameter or layer input.
@@ -29,9 +29,7 @@ This repository contains a set of Python utilities to work with QONNX models, in
2929

3030
### Operator definitions
3131

32-
* [Quant](docs/qonnx-custom-ops/quant_op.md) for 2-to-arbitrary-bit quantization, with scaling and zero-point
33-
* [BipolarQuant](docs/qonnx-custom-ops/bipolar_quant_op.md) for 1-bit (bipolar) quantization, with scaling and zero-point
34-
* [Trunc](docs/qonnx-custom-ops/trunc_op.md) for truncating to a specified number of bits, with scaling and zero-point
32+
Please see the [custom operator overview](docs/qonnx-custom-ops/overview.md) table for more details.
3533

3634
### Installation
3735

docs/qonnx-custom-ops/bipolar_quant_op.md renamed to docs/qonnx-custom-ops/bipolarquant_v1.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Additionally, takes one float as input, which define the scaling.
55

66
#### Version
77

8-
This operator is not part of the ONNX standard and is not currently versioned.
8+
The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1.
99

1010
#### Attributes
1111

docs/qonnx-custom-ops/floatquant_op.md renamed to docs/qonnx-custom-ops/floatquant_v1.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ special (symbolic) values. This makes it nontrivial to infer the maximum represe
1616

1717
#### Version
1818

19-
This operator is not part of the ONNX standard and is not currently versioned.
19+
The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1.
2020

2121
#### Attributes
2222

docs/qonnx-custom-ops/intquant_op.md renamed to docs/qonnx-custom-ops/intquant_v1.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ rounding_mode defines how quantized values are rounded.
99

1010
Notes:
1111
* This operator was previously named `Quant` but is renamed to `IntQuant` to distinguish it from `FloatQuant`. For a transition period, qonnx will transparently handle `Quant` as `IntQuant` for backwards compatibility reasons, but only `IntQuant` should be used for new models.
12-
* This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists.
12+
* This operator does not work for binary or bipolar quantization, for this purpose the simpler `BipolarQuant` node exists.
1313

1414
#### Version
1515

16-
This operator is not part of the ONNX standard and is not currently versioned.
16+
The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1.
1717

1818
#### Attributes
1919

docs/qonnx-custom-ops/overview.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
## Operator Schemas
2+
3+
This file lists the QONNX custom operators, similar to `Operators.md` for the ONNX standard.
4+
It is manually updated, since QONNX custom operators are relatively few in number.
5+
6+
### qonnx.custom_op.general
7+
8+
|**Operator**|**Since version**||
9+
|-|-|-|
10+
|<a href="bipolarquant_v1.md">BipolarQuant</a>|<a href="bipolarquant_v1.md">1</a>|
11+
|<a href="floatquant_v1.md">FloatQuant</a>|<a href="floatquant_v1.md">1</a>|
12+
|<a href="intquant_v1.md">IntQuant</a>|<a href="intquant_v1.md">1</a>|
13+
|<a href="trunc_v2.md">Trunc</a>|<a href="trunc_v2.md">2</a>, <a href="trunc_v1.md">1</a>|

docs/qonnx-custom-ops/trunc_op.md renamed to docs/qonnx-custom-ops/trunc_v1.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ The attribute rounding_mode defines how truncated values are rounded.
66

77
#### Version
88

9-
This operator is not part of the ONNX standard and is not currently versioned.
9+
The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1.
1010

1111
#### Attributes
1212

docs/qonnx-custom-ops/trunc_v2.md

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
### <a name="Trunc"></a><a name="abs">**Trunc**</a>
2+
3+
Truncates the values of one input data (Tensor<T>) at a specified bitwidth and produces one output data (Tensor<T>).
4+
Additionally, takes four float tensors as input, which define the scale, zero-point, input bit-width and output bit-width of the quantization.
5+
The attribute rounding_mode defines how truncated values are rounded.
6+
7+
#### Version
8+
9+
This operator is not part of the ONNX standard.
10+
The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 2.
11+
12+
#### Attributes
13+
14+
<dl>
15+
<dt><tt>rounding_mode</tt> : string (default is "FLOOR")</dt>
16+
<dd>Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".</dd>
17+
<dt><tt>signed</tt> : int (default is 1)</dt>
18+
<dd>Defines if the quantization includes a signed bit. E.g. at 8b unsigned=[0, 255] vs signed=[-128, 127].</dd>
19+
<dt><tt>narrow</tt> : int (default is 0)</dt>
20+
<dd>Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].</dd>
21+
</dl>
22+
23+
#### Inputs
24+
25+
<dl>
26+
<dt><tt>X</tt> (differentiable) : tensor(float32)</dt>
27+
<dd>input tensor to truncate</dd>
28+
<dt><tt>scale</tt> : float32</dt>
29+
<dd>The scale factor at the input of the truncation</dd>
30+
<dt><tt>zeropt</tt> : float32</dt>
31+
<dd>The zero-point at the input of the truncation</dd>
32+
<dt><tt>in_bitwidth</tt> : int32</dt>
33+
<dd>The number of bits used at the input of the truncation</dd>
34+
<dt><tt>out_scale</tt> : float32</dt>
35+
<dd>The scale factor of the output of the truncation</dd>
36+
<dt><tt>out_bitwidth</tt> : int32</dt>
37+
<dd>The number of bits used at the output of the truncation</dd>
38+
</dl>
39+
40+
41+
#### Outputs
42+
43+
<dl>
44+
<dt><tt>Y</tt> (differentiable) : tensor(float32)</dt>
45+
<dd>Output tensor</dd>
46+
</dl>
47+
48+
49+
#### Examples
50+
<details>
51+
<summary>Trunc</summary>
52+
53+
```python
54+
from onnx import helper
55+
import numpy as np
56+
57+
# Define node settings and input
58+
x = np.random.randn(100).astype(np.float32)*10.
59+
scale = np.array(1.)
60+
zeropt = np.array(0.)
61+
in_bitwidth = np.array(10)
62+
out_bitwidth = np.array(4)
63+
rounding_mode = "ROUND"
64+
65+
# Create node
66+
node = helper.make_node(
67+
'Trunc',
68+
domain='finn.custom_op.general',
69+
inputs=['x', 'scale', 'zeropt', 'in_bitwidth', 'out_bitwidth'],
70+
outputs=['y'],
71+
rounding_mode=rounding_mode,
72+
)
73+
74+
# Execute the same settings with the reference implementation (trunc)
75+
# See the sample implementation for more details on trunc.
76+
output_ref = trunc(inp_tensor, scale, zeropt, in_bitwidth, out_bitwidth, rounding_mode)
77+
78+
# Execute node and compare
79+
expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_trunc')
80+
81+
```
82+
83+
</details>
84+
85+
86+
#### Sample Implementation
87+
88+
<details>
89+
<summary>Trunc</summary>
90+
91+
```python
92+
# SPDX-License-Identifier: Apache-2.0
93+
94+
from __future__ import absolute_import
95+
from __future__ import division
96+
from __future__ import print_function
97+
from __future__ import unicode_literals
98+
99+
import numpy as np
100+
101+
def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
102+
103+
# Scaling
104+
y = inp_tensor / scale
105+
y = y + zeropt
106+
# Rounding
107+
y = np.round(y)
108+
# Rescale
109+
trunc_scale = 2 ** np.round(
110+
np.log2(output_scale / scale)
111+
) # Trunc scale should be a power-of-two - ensure that is the case
112+
y = y / trunc_scale
113+
114+
# Clamping
115+
min_int_val = min_int(signed, narrow, output_bit_width)
116+
max_int_val = max_int(signed, narrow, output_bit_width)
117+
y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
118+
y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
119+
# To int (truncate)
120+
rounding_fx = resolve_rounding_mode(rounding_mode)
121+
y = rounding_fx(y)
122+
123+
# Rescale
124+
output_zeropt = zeropt / trunc_scale # Rescale zero-point
125+
y = y - output_zeropt
126+
y = y * output_scale
127+
128+
return y
129+
130+
def resolve_rounding_mode(mode_string):
131+
"""Resolve the rounding mode string of Quant and Trunc ops
132+
to the corresponding numpy functions."""
133+
if mode_string == "ROUND":
134+
return np.round
135+
elif mode_string == "CEIL":
136+
return np.ceil
137+
elif mode_string == "FLOOR":
138+
return np.floor
139+
else:
140+
raise ValueError(f"Could not resolve rounding mode called: {mode_string}")
141+
142+
```
143+
144+
</details>

src/qonnx/custom_op/general/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from qonnx.custom_op.general.multithreshold import MultiThreshold
3838
from qonnx.custom_op.general.quant import Quant
3939
from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d
40-
from qonnx.custom_op.general.trunc import Trunc
40+
from qonnx.custom_op.general.trunc import Trunc_v1, Trunc_v2
4141
from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul
4242

4343
__all__ = [
@@ -51,6 +51,7 @@
5151
"MultiThreshold",
5252
"Quant",
5353
"QuantAvgPool2d",
54-
"Trunc",
54+
"Trunc_v1",
55+
"Trunc_v2",
5556
"XnorPopcountMatMul",
5657
]

src/qonnx/custom_op/general/trunc.py

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,99 @@
3131

3232
from qonnx.core.datatype import DataType
3333
from qonnx.custom_op.base import CustomOp
34-
from qonnx.custom_op.general.quant import resolve_rounding_mode
34+
from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode
35+
from qonnx.util.basic import get_preferred_qonnx_opset
3536

3637

37-
def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
38+
def trunc_v2(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
39+
# Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
40+
41+
# Scaling
42+
y = inp_tensor / scale
43+
y = y + zeropt
44+
# Rounding
45+
y = np.round(y)
46+
# Rescale
47+
trunc_scale = 2 ** np.round(
48+
np.log2(output_scale / scale)
49+
) # Trunc scale should be a power-of-two - ensure that is the case
50+
y = y / trunc_scale
51+
52+
# Clamping
53+
min_int_val = min_int(signed, narrow, output_bit_width)
54+
max_int_val = max_int(signed, narrow, output_bit_width)
55+
y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
56+
y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
57+
# To int (truncate)
58+
rounding_fx = resolve_rounding_mode(rounding_mode)
59+
y = rounding_fx(y)
60+
61+
# Rescale
62+
output_zeropt = zeropt / trunc_scale # Rescale zero-point
63+
y = y - output_zeropt
64+
y = y * output_scale
65+
66+
return y
67+
68+
69+
class Trunc_v2(CustomOp):
70+
"""Generic truncation operation for QONNX. Takes four inputs:
71+
- input tensor to truncate
72+
- the scale
73+
- the zero-point
74+
- the truncation scale
75+
- the truncation bit-width
76+
77+
The output is a tensor of the same shape as the input tensor, with truncated
78+
values.
79+
"""
80+
81+
def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
82+
super().__init__(onnx_node, onnx_opset_version)
83+
# override any specified opset version, this instance is v2
84+
self.onnx_opset_version = 2
85+
86+
def get_nodeattr_types(self):
87+
return {
88+
# The rounding mode, which is used for the trunc function
89+
"rounding_mode": ("s", True, "FLOOR"),
90+
"narrow": ("i", False, 0, {0, 1}),
91+
"signed": ("i", False, 1, {0, 1}),
92+
}
93+
94+
def make_shape_compatible_op(self, model):
95+
node = self.onnx_node
96+
return helper.make_node("Identity", [node.input[0]], [node.output[0]])
97+
98+
def infer_node_datatype(self, model):
99+
node = self.onnx_node
100+
model.set_tensor_datatype(node.output[0], DataType["FLOAT32"])
101+
102+
def execute_node(self, context, graph):
103+
node = self.onnx_node
104+
# save inputs
105+
inp_tensor = context[node.input[0]]
106+
scale = context[node.input[1]]
107+
zeropt = context[node.input[2]]
108+
input_bit_width = context[node.input[3]]
109+
output_scale = context[node.input[4]]
110+
output_bit_width = context[node.input[5]]
111+
# save attributes
112+
rounding_mode = self.get_nodeattr("rounding_mode")
113+
narrow = self.get_nodeattr("narrow")
114+
signed = self.get_nodeattr("signed")
115+
# calculate output
116+
ret = trunc_v2(
117+
inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode
118+
)
119+
# set context according to output name
120+
context[node.output[0]] = ret
121+
122+
def verify_node(self):
123+
pass
124+
125+
126+
def trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
38127
# Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
39128

40129
# Scaling
@@ -58,7 +147,7 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding
58147
return y
59148

60149

61-
class Trunc(CustomOp):
150+
class Trunc_v1(CustomOp):
62151
"""Generic truncation operation for QONNX. Takes four inputs:
63152
- input tensor to truncate
64153
- the scale
@@ -69,6 +158,11 @@ class Trunc(CustomOp):
69158
values.
70159
"""
71160

161+
def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
162+
super().__init__(onnx_node, onnx_opset_version)
163+
# override any specified opset version, this instance is v1
164+
self.onnx_opset_version = 1
165+
72166
def get_nodeattr_types(self):
73167
return {
74168
# The rounding mode, which is used for the trunc function
@@ -94,7 +188,7 @@ def execute_node(self, context, graph):
94188
# save attributes
95189
rounding_mode = self.get_nodeattr("rounding_mode")
96190
# calculate output
97-
ret = trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode)
191+
ret = trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode)
98192
# set context according to output name
99193
context[node.output[0]] = ret
100194

0 commit comments

Comments
 (0)