Skip to content

Commit 434020f

Browse files
committed
update torch.compile tutorial
1 parent 9054234 commit 434020f

File tree

3 files changed

+69
-45
lines changed

3 files changed

+69
-45
lines changed

index.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,13 @@ Welcome to PyTorch Tutorials
536536
:link: intermediate/torch_compile_tutorial.html
537537
:tags: Model-Optimization
538538

539+
.. customcarditem::
540+
:header: torch.compile End-to-End Tutorial
541+
:card_description: An example of applying torch.compile to a real model, demonstrating speedups.
542+
:image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
543+
:link: intermediate/torch_compile_full_example.html
544+
:tags: Model-Optimization
545+
539546
.. customcarditem::
540547
:header: Building a Convolution/Batch Norm fuser in torch.compile
541548
:card_description: Build a simple pattern matcher pass that fuses batch norm into convolution to improve performance during inference.

intermediate_source/torch_compile_full_example.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,40 @@
66
**Author:** William Wen
77
"""
88

9-
import warnings
10-
119
######################################################################
1210
# ``torch.compile`` is the new way to speed up your PyTorch code!
1311
# ``torch.compile`` makes PyTorch code run faster by
1412
# JIT-compiling PyTorch code into optimized kernels,
1513
# while requiring minimal code changes.
1614
#
1715
# This tutorial covers an end-to-end example of training and evaluating a
18-
# real model with ``torch.compile``. For a gentler introduction to ``torch.compile``,
19-
# please check out our ```torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
16+
# real model with ``torch.compile``. For a gentle introduction to ``torch.compile``,
17+
# please check out `the introduction to ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
2018
#
2119
# **Required pip Dependencies**
2220
#
2321
# - ``torch >= 2.0``
2422
# - ``torchvision``
23+
#
24+
# .. grid:: 2
25+
#
26+
# .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
27+
# :class-card: card-prerequisites
28+
#
29+
# * How to apply ``torch.compile`` to a real model
30+
# * ``torch.compile`` speedups on a real model
31+
# * ``torch.compile``'s first few iterations are expected to be slower due to compilation overhead
32+
#
33+
# .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
34+
# :class-card: card-prerequisites
35+
#
36+
# * `Introduction to ``torch.compile`` <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__
2537

2638
# NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in
2739
# order to reproduce the speedup numbers shown below and documented elsewhere.
2840

2941
import torch
42+
import warnings
3043

3144
gpu_ok = False
3245
if torch.cuda.is_available():
@@ -88,7 +101,10 @@ def init_model():
88101

89102
model = init_model()
90103

91-
model_opt = torch.compile(model, mode="reduce-overhead")
104+
# Note that we generally recommend directly compiling a torch.nn.Module by calling
105+
# its .compile() method.
106+
model_opt = init_model()
107+
model_opt.compile(mode="reduce-overhead")
92108

93109
inp = generate_data(16)[0]
94110
with torch.no_grad():
@@ -175,6 +191,9 @@ def train(mod, data):
175191

176192
model = init_model()
177193
opt = torch.optim.Adam(model.parameters())
194+
195+
# Note that because we are compiling a regular Python function, we do not
196+
# call any .compile() method.
178197
train_opt = torch.compile(train, mode="reduce-overhead")
179198

180199
compile_times = []
@@ -202,3 +221,20 @@ def train(mod, data):
202221
# We remark that the speedup numbers presented in this tutorial are for
203222
# demonstration purposes only. Official speedup values can be seen at the
204223
# `TorchInductor performance dashboard <https://hud.pytorch.org/benchmark/compilers>`__.
224+
225+
######################################################################
226+
# Conclusion
227+
# ------------
228+
#
229+
# In this tutorial, we applied ``torch.compile`` to training and inference on a real model,
230+
# demonstrating speedups.
231+
#
232+
# Importantly, we note that the first few iterations of a compiled model
233+
# are slower than eager mode due to compilation overhead, but subsequent iterations are expected to
234+
# have speedups.
235+
#
236+
# For a gentle introduction to ``torch.compile``, please check out `the introduction to ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
237+
#
238+
# To troubleshoot issues and to gain a deeper understanding of how to apply ``torch.compile`` to your code, check out `the ``torch.compile`` programming model <https://docs.pytorch.org/docs/main/compile/programming_model.html>`__.
239+
#
240+
# We hope that you will give ``torch.compile`` a try!

intermediate_source/torch_compile_tutorial.py

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,17 @@
66
**Author:** William Wen
77
"""
88

9+
# sphinx_gallery_start_ignore
10+
# to clear torch logs format
11+
import torch
12+
import os
13+
os.environ["TORCH_LOGS_FORMAT"] = ""
14+
torch._logging._internal.DEFAULT_FORMATTER = (
15+
torch._logging._internal._default_formatter()
16+
)
17+
torch._logging._internal._init_logs()
18+
# sphinx_gallery_end_ignore
19+
920
######################################################################
1021
# ``torch.compile`` is the new way to speed up your PyTorch code!
1122
# ``torch.compile`` makes PyTorch code run faster by
@@ -27,6 +38,8 @@
2738
#
2839
# For an end-to-end example on a real model, check out our `end-to-end ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html>`__.
2940
#
41+
# To troubleshoot issues and to gain a deeper understanding of how to apply ``torch.compile`` to your code, check out `the ``torch.compile`` programming model <https://docs.pytorch.org/docs/main/compile/programming_model.html>`__.
42+
#
3043
# **Contents**
3144
#
3245
# .. contents::
@@ -128,7 +141,7 @@ def forward(self, x):
128141
# -----------------------
129142
#
130143
# Now let's demonstrate how ``torch.compile`` speeds up a simple PyTorch example.
131-
# For a demonstration on a more complex model, see <TODO link>.
144+
# For a demonstration on a more complex model, see our `end-to-end ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html>`__.
132145

133146

134147
def foo3(x):
@@ -161,7 +174,7 @@ def timed(fn):
161174
######################################################################
162175
# Notice that ``torch.compile`` appears to take a lot longer to complete
163176
# compared to eager. This is because ``torch.compile`` takes extra time to compile
164-
# the model on the first execution.
177+
# the model on the first few executions.
165178
# ``torch.compile`` re-uses compiled code whever possible,
166179
# so if we run our optimized model several more times, we should
167180
# see a significant improvement compared to eager.
@@ -281,43 +294,6 @@ def f2(x, y):
281294
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
282295
print("~" * 10)
283296

284-
######################################################################
285-
# Another case that ``torch.compile`` handles well compared to
286-
# both TorchScript tracing and scripting is the usage of third-party library functions.
287-
288-
import scipy
289-
290-
291-
def f3(x):
292-
x = x * 2
293-
x = scipy.fft.dct(x.numpy())
294-
x = torch.from_numpy(x)
295-
x = x * 2
296-
return x
297-
298-
299-
######################################################################
300-
# TorchScript tracing treats results from non-PyTorch function calls
301-
# as constants, and so our results can be silently wrong.
302-
# TorchScript scripting disallows non-PyTorch function calls.
303-
# On the other hand, ``torch.compile`` is easily able to handle
304-
# the non-PyTorch function call.
305-
306-
307-
inp1 = torch.randn(5, 5)
308-
inp2 = torch.randn(5, 5)
309-
traced_f3 = torch.jit.trace(f3, (inp1,))
310-
print("traced 3:", test_fns(f3, traced_f3, (inp2,)))
311-
312-
try:
313-
torch.jit.script(f3)
314-
except:
315-
tb.print_exc()
316-
317-
compile_f3 = torch.compile(f3)
318-
print("compile 3:", test_fns(f3, compile_f3, (inp2,)))
319-
320-
321297
######################################################################
322298
# Graph Breaks
323299
# ------------------------------------
@@ -418,6 +394,9 @@ def false_branch(y):
418394
# One important restriction is that ``torch.export`` does not support graph breaks. Please check
419395
# `this tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__
420396
# for more details on ``torch.export``.
397+
#
398+
# Check out our `section on graph breaks in the ``torch.compile`` programming model <https://docs.pytorch.org/docs/main/compile/programming_model.graph_breaks_index.html>`__
399+
# for tips on how to work around graph breaks.
421400

422401
######################################################################
423402
# Troubleshooting
@@ -427,7 +406,7 @@ def false_branch(y):
427406
# Are you looking for tips on how to best use ``torch.compile``?
428407
# Or maybe you simply want to learn more about the inner workings of ``torch.compile``?
429408
#
430-
# Check out `the ``torch.compile`` troubleshooting guide <https://pytorch.org/docs/stable/torch.compiler_troubleshooting.html>`__!
409+
# Check out `the ``torch.compile`` programming model <https://docs.pytorch.org/docs/main/compile/programming_model.html>`__.
431410

432411
######################################################################
433412
# Conclusion
@@ -439,4 +418,6 @@ def false_branch(y):
439418
#
440419
# For an end-to-end example on a real model, check out our `end-to-end ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html>`__.
441420
#
421+
# To troubleshoot issues and to gain a deeper understanding of how to apply ``torch.compile`` to your code, check out `the ``torch.compile`` programming model <https://docs.pytorch.org/docs/main/compile/programming_model.html>`__.
422+
#
442423
# We hope that you will give ``torch.compile`` a try!

0 commit comments

Comments
 (0)