Commit 99de3c7

Adds one example to the documentation (#313)

* add an example about gemm
* add an example
* add version
* fixes

1 parent d75d0aa commit 99de3c7

File tree: 2 files changed, +218 -0 lines changed

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
"""
.. _l-plot-gemm-or-matmul-add:

====================
Gemm or Matmul + Add
====================

Order of computation matters. ``1 + 1e-20 - 1 != 1 - 1 + 1e-20`` if the
precision of the computation is taken into account.
This example looks at the Gemm operator in :epkg:`onnxruntime`,
the simplest way to represent a linear neural layer.

A model with three choices
==========================
"""

import cpuinfo
import numpy as np
import pandas
import matplotlib.pyplot as plt
import onnx
import onnx.helper as oh
import torch
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
from onnx_diagnostic.reference import OnnxruntimeEvaluator
from onnxruntime import (
    InferenceSession,
    SessionOptions,
    __version__ as version_onnxruntime,
    GraphOptimizationLevel,
)

print(f"onnxruntime version = {version_onnxruntime}")
print(f"cpu name = {cpuinfo.get_cpu_info()['brand_raw']}")
if torch.cuda.is_available():
    print(f"gpu name = {torch.cuda.get_device_name(0)}")
    print(f"cuda version = {torch.version.cuda}")

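# %%
# A quick illustration of the claim made above (a small addition to the
# original example): in float64, ``1e-20`` vanishes when it is added to ``1``
# first, so the two orders of computation do not give the same result.

print("(1 + 1e-20) - 1 =", (1 + 1e-20) - 1)
print("(1 - 1) + 1e-20 =", (1 - 1) + 1e-20)
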
# %%
# The version is important. Numerical differences are observed
# with onnxruntime<=1.22. Let's see how to make them happen.


def make_model_gemm(itype: int) -> onnx.ModelProto:
    return oh.make_model(
        oh.make_graph(
            [
                # a single Gemm node computing A @ X + B
                oh.make_node("Gemm", ["A", "X", "B"], ["GemmOnly"]),
                # Gemm without bias followed by an explicit Add
                oh.make_node("Gemm", ["A", "X"], ["gmm"]),
                oh.make_node("Add", ["gmm", "B"], ["GemmAdd"]),
                # MatMul followed by an explicit Add
                oh.make_node("MatMul", ["A", "X"], ["mm"]),
                oh.make_node("Add", ["mm", "B"], ["MatMulAdd"]),
                # onnxruntime's contrib operator FusedMatMul followed by an Add
                oh.make_node("FusedMatMul", ["A", "X"], ["fmm"], domain="com.microsoft"),
                oh.make_node("Add", ["fmm", "B"], ["FusedMatMulAdd"]),
            ],
            "test",
            [
                oh.make_tensor_value_info("A", itype, ["a", "b"]),
                oh.make_tensor_value_info("X", itype, ["b", "c"]),
                oh.make_tensor_value_info("B", itype, ["c"]),
            ],
            [
                oh.make_tensor_value_info("GemmOnly", itype, ["a", "c"]),
                oh.make_tensor_value_info("GemmAdd", itype, ["a", "c"]),
                oh.make_tensor_value_info("FusedMatMulAdd", itype, ["a", "c"]),
                oh.make_tensor_value_info("MatMulAdd", itype, ["a", "c"]),
            ],
        ),
        opset_imports=[oh.make_opsetid("", 22)],
        ir_version=10,
    )


def matrix_diff(tensors):
    # pairwise maximum absolute difference between every pair of results
    mat = np.zeros((len(tensors), len(tensors)), dtype=np.float32)
    for i, t in enumerate(tensors):
        for j in range(i + 1, len(tensors)):
            mat[i, j] = max_diff(t, tensors[j])["abs"]
            mat[j, i] = mat[i, j]
    return mat


itype = onnx.TensorProto.FLOAT16
dtype = np.float16
model = make_model_gemm(itype)

A = np.random.randn(512, 256).astype(dtype)
X = np.random.randn(256, 256).astype(dtype)
B = np.random.randn(256).astype(dtype)
feeds = dict(A=A, X=X, B=B)

# %%
# We disable all the optimizations applied by onnxruntime so that
# the computation follows exactly the graph we want to verify.
opts = SessionOptions()
opts.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
opts.optimized_model_filepath = "plot_gemm_or_matmul.optimized.onnx"
sess = InferenceSession(model.SerializeToString(), opts, providers=["CPUExecutionProvider"])
results = [A @ X + B, *sess.run(None, feeds)]
diffs = matrix_diff(results)

print(diffs)

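# %%
# The raw matrix is hard to read. The same values with labels, mirroring what
# is done later in this example (``first_labels`` is a name added here, it is
# not part of the original script):

first_labels = ["a @ x + b", *[o.name for o in model.graph.output]]
print(pandas.DataFrame(diffs, columns=first_labels, index=first_labels))
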
# %%
onx = onnx.load(opts.optimized_model_filepath)
print(pretty_onnx(onx))

# %%
# It seems some Cast nodes were still inserted.

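# %%
# A quick check added to this example: count the Cast nodes onnxruntime left
# in the optimized graph it saved.

n_casts = sum(1 for node in onx.graph.node if node.op_type == "Cast")
print(f"number of Cast nodes in the optimized model: {n_casts}")
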
# %%
# Let's try float16 and float32 on CPU and, if it is available, on CUDA.

A = torch.randn((512, 512), dtype=torch.float32)
X = torch.randn((512, 512), dtype=torch.float32)
B = torch.randn((512), dtype=torch.float32)

for itype, dtype, device in [
    (onnx.TensorProto.FLOAT16, torch.float16, "cpu"),
    (onnx.TensorProto.FLOAT, torch.float32, "cpu"),
    (onnx.TensorProto.FLOAT16, torch.float16, "cuda"),
    (onnx.TensorProto.FLOAT, torch.float32, "cuda"),
]:
    if device == "cuda" and not torch.cuda.is_available():
        continue
    a = A.to(dtype).to(device)
    x = X.to(dtype).to(device)
    b = B.to(dtype).to(device)
    feeds = dict(A=a, X=x, B=b)
    model = make_model_gemm(itype)

    sess = OnnxruntimeEvaluator(model, whole=True)
    results = sess.run(None, feeds)
    diffs = matrix_diff(results)
    print(f"------ dtype={dtype}, device={device!r}")
    print(diffs)

# %%
# A weird bias
# ============
#
# In the previous example, the coefficients of the bias
# have the same magnitude as the other coefficients. What if we make them
# a lot higher?

B = (torch.arange(512, dtype=torch.float32) + 1) / 512 * 16384
labels = ["linear", *[o.name for o in model.graph.output], "a @ x + b"]
all_results = {}

for itype, dtype, device in [
    (onnx.TensorProto.FLOAT, torch.float32, "cpu"),
    (onnx.TensorProto.FLOAT16, torch.float16, "cpu"),
    # missing implementation in onnxruntime
    # (onnx.TensorProto.BFLOAT16, torch.bfloat16, "cpu"),
    (onnx.TensorProto.FLOAT, torch.float32, "cuda"),
    (onnx.TensorProto.FLOAT16, torch.float16, "cuda"),
    (onnx.TensorProto.BFLOAT16, torch.bfloat16, "cuda"),
]:
    if device == "cuda" and not torch.cuda.is_available():
        continue
    a = A.to(dtype).to(device)
    x = X.to(dtype).to(device)
    b = B.to(dtype).to(device)
    feeds = dict(A=a, X=x, B=b)
    model = make_model_gemm(itype)

    filename = f"plot_gemm_or_matmul.{itype}.{device}.onnx"
    sess = OnnxruntimeEvaluator(
        model,
        whole=True,
        graph_optimization_level=GraphOptimizationLevel.ORT_DISABLE_ALL,
        optimized_model_filepath=filename,
    )
    results = [torch.nn.functional.linear(a, x.T, b), *sess.run(None, feeds), a @ x + b]
    all_results[device, dtype] = results
    has_cast = "Cast" in [n.op_type for n in onnx.load(filename).graph.node]
    diffs = matrix_diff(results)
    df = pandas.DataFrame(diffs, columns=labels, index=labels)
    print(f"------ has_cast={has_cast}, dtype={dtype}, device={device!r}, max(b)={b.max()}")
    print(df)

# %%
# Cast nodes are inserted on CPU because some kernels are not available for
# float16. Even so, we can see huge discrepancies happening.
#
# bias value vs discrepancies
# ===========================
#
# Let's compare GemmOnly (so the bias is fused in the Gemm node) with the
# last result, ``a @ x + b`` computed with torch.

i, j = 1, -1
labs = labels[i], labels[j]

fig, ax = plt.subplots(len(all_results), 2, figsize=(8, 2.5 * len(all_results)))
for pos, ((device, dtype), results) in enumerate(all_results.items()):
    m1, m2 = results[i], results[j]
    diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
    print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
    expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
    ax[pos, 0].plot(B.tolist(), (diff.detach().cpu() + torch.rand(512) * expand).tolist(), ".")
    ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}")

    corr = matrix_diff(results)
    ax[pos, 1].imshow(corr, cmap="Blues", vmin=0, vmax=corr.max())
    # ax[pos, 1].colorbar(label=f'Discrepancies {device}/{dtype}')
    ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45)
    ax[pos, 1].set_yticks(range(len(labels)), labels)
    ax[pos, 1].set_title(f"max={diff.max()}")
fig.tight_layout()
fig.savefig("plot_gemm_or_matmul_add.png")

# %%
# Discrepancies do not happen every time but they are very likely to happen.
# Gemm with a non-null bias should only be used when torch fuses the bias the
# same way, and the outcome seems to depend on the element type as well.
# The difference is even higher for bfloat16.
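
# %%
# A final sketch added to this example showing one plausible source of the
# gap: rounding the product to float16 before or after adding the large bias
# does not give the same result. This is only an illustration, not a
# description of the exact kernels onnxruntime dispatches to.

p32 = A.to(torch.float32) @ X.to(torch.float32)
# bias added in float32, a single rounding to float16 at the end
one_rounding = (p32 + B).to(torch.float16)
# product rounded to float16 first, then the bias added in float16
two_roundings = p32.to(torch.float16) + B.to(torch.float16)
d = (one_rounding.to(torch.float32) - two_roundings.to(torch.float32)).abs()
print(f"max difference between the two orders: {d.max()}")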

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ packaging
 pandas
 Pillow
 psutil
+py-cpuinfo
 pytest
 pytest-coverage
 pytest-subtests
