Commit 985341e

sayakpaul authored (co-authored by angelayi and svekars)

feat: add a recipe on regional aot. (#3543)

I don't see an issue for this tutorial. I discussed this with @angelayi and we agreed that having a recipe like this is beneficial to the AoT workflow. I ran the Python script introduced in this PR on a single RTX 4090, and I got:

```
Full model compilation time = 5.91 seconds
Regional compilation time = 2.54 seconds
```

cc @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Co-authored-by: Angela Yi <angelayi@meta.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent 4696f66 commit 985341e

File tree

2 files changed: +247 -0 lines changed

* recipes_index.rst
* recipes_source/regional_aot.py

recipes_index.rst

Lines changed: 8 additions & 0 deletions

@@ -333,6 +333,13 @@ from our full-length tutorials.
    :link: recipes/distributed_comm_debug_mode.html
    :tags: Distributed-Training
 
+.. customcarditem::
+   :header: Reducing AoT cold start compilation time with regional compilation
+   :card_description: Learn how to use regional compilation to control AoT cold start compile time
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: recipes/regional_aot.html
+   :tags: Model-Optimization
+
 .. End of tutorial card section
 
 .. -----------------------------------------
@@ -378,6 +385,7 @@ from our full-length tutorials.
    recipes/torch_compile_caching_tutorial
    recipes/torch_compile_caching_configuration_tutorial
    recipes/regional_compilation
+   recipes/regional_aot
    recipes/intel_extension_for_pytorch.html
    recipes/intel_neural_compressor_for_pytorch
    recipes/distributed_device_mesh

recipes_source/regional_aot.py

Lines changed: 239 additions & 0 deletions

@@ -0,0 +1,239 @@

"""
Reducing AoT cold start compilation time with regional compilation
============================================================================

**Author:** `Sayak Paul <https://github.com/sayakpaul>`_, `Charles Bensimon <https://github.com/cbensimon>`_, `Angela Yi <https://github.com/angelayi>`_

In the `regional compilation recipe <https://docs.pytorch.org/tutorials/recipes/regional_compilation.html>`__, we showed
how to reduce cold start compilation times while retaining (almost) full compilation benefits. This was demonstrated for
just-in-time (JIT) compilation.
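
As a quick refresher, JIT regional compilation boils down to compiling each repeated block
instead of the full model, roughly as in the sketch below (refer to that recipe for the exact
code; ``model.layers`` stands for the repeated blocks we define later in this recipe):

.. code-block:: python

    # JIT regional compilation (sketch): compile only the repeated blocks, in place.
    for layer in model.layers:
        layer.compile()
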
This recipe shows how to apply similar principles when compiling a model ahead-of-time (AoT). If you
are not familiar with AOTInductor and ``torch.export``, we recommend you check out `this tutorial <https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html>`__.

Prerequisites
----------------

* PyTorch 2.6 or later
* Familiarity with regional compilation
* Familiarity with AOTInductor and ``torch.export``

Setup
-----
Before we begin, we need to install ``torch`` if it is not already
available.

.. code-block:: sh

   pip install torch
"""

######################################################################
# Steps
# -----
#
# In this recipe, we will follow the same steps as the regional compilation recipe mentioned above:
#
# 1. Import all necessary libraries.
# 2. Define and initialize a neural network with repeated regions.
# 3. Measure the compilation time of the full model and the regional compilation with AoT.
#
# First, let's import the necessary libraries:
#

import torch
torch.set_grad_enabled(False)

from time import perf_counter

###################################################################################
# Defining the Neural Network
# ---------------------------
#
# We will use the same neural network structure as in the regional compilation recipe.
#
# We will use a network composed of repeated layers. This mimics a
# large language model, which is typically composed of many Transformer blocks. In this recipe,
# we will create a ``Layer`` using the ``nn.Module`` class as a proxy for a repeated region.
# We will then create a ``Model`` which is composed of 64 instances of this
# ``Layer`` class.
#
class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, ``self.linear`` is outside the scope of ``torch.compile``.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x


##################################################################################
# Compiling the model ahead-of-time
# ---------------------------------
#
# Since we're compiling the model ahead-of-time, we need to prepare representative
# input examples that we expect the model to see during actual deployment.
#
# Let's create an instance of ``Model`` and pass it some sample input data.
#

model = Model().cuda()
input = torch.randn(10, 10, device="cuda")
output = model(input)
print(f"{output.shape=}")

###############################################################################################
# Now, let's compile our model ahead-of-time. We will pass the ``input`` created above
# to ``torch.export``. This will yield a ``torch.export.ExportedProgram``, which we can compile.

path = torch._inductor.aoti_compile_and_package(
    torch.export.export(model, args=(input,))
)

#################################################################
# We can load from this ``path`` and use it to perform inference.

compiled_binary = torch._inductor.aoti_load_package(path)
output_compiled = compiled_binary(input)
print(f"{output_compiled.shape=}")

######################################################################################
# Compiling *regions* of the model ahead-of-time
# ----------------------------------------------
#
# Compiling model regions ahead-of-time, on the other hand, requires a few key changes.
#
# Since the compute pattern is shared by all the blocks that
# are repeated in a model (``Layer`` instances in this case), we can just
# compile a single block and let inductor reuse it.

model = Model().cuda()
path = torch._inductor.aoti_compile_and_package(
    torch.export.export(model.layers[0], args=(input,)),
    inductor_configs={
        # Produce the compiled artifact without saving the parameters in it.
        "aot_inductor.package_constants_in_so": False,
    },
)

###################################################
# An exported program (``torch.export.ExportedProgram``) contains the Tensor computation,
# a ``state_dict`` containing the tensor values of all lifted parameters and buffers, and
# other metadata. Setting ``aot_inductor.package_constants_in_so`` to ``False`` tells the
# compiler not to serialize the model parameters in the generated artifact.
#
# Now, when loading the compiled binary, we can reuse the existing parameters of
# each block. This lets us take advantage of the compiled binary obtained above.
#

# Load the compiled artifact for each block and bind that block's own parameters to it.
for layer in model.layers:
    compiled_layer = torch._inductor.aoti_load_package(path)
    compiled_layer.load_constants(
        layer.state_dict(), check_full_update=True, user_managed=True
    )
    layer.forward = compiled_layer

output_regional_compiled = model(input)
print(f"{output_regional_compiled.shape=}")

#####################################################
# Just like JIT regional compilation, compiling regions within a model ahead-of-time
# leads to significantly reduced cold start times. The actual numbers will vary from
# model to model.
#
# Even though full model compilation offers the broadest scope of optimizations,
# in practice, and depending on the model, we have seen regional
# compilation (both JIT and AoT) provide similar speed benefits while drastically
# reducing the cold start times.

###################################################
# Measuring compilation time
# --------------------------
# Next, let's measure the compilation time of the full model and the regional compilation.
#

def measure_compile_time(input, regional=False):
    start = perf_counter()
    model = aot_compile_load_model(regional=regional)
    torch.cuda.synchronize()
    end = perf_counter()
    # make sure the model works.
    _ = model(input)
    return end - start


def aot_compile_load_model(regional=False) -> torch.nn.Module:
    input = torch.randn(10, 10, device="cuda")
    model = Model().cuda()

    inductor_configs = {}
    if regional:
        inductor_configs = {"aot_inductor.package_constants_in_so": False}

    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        path = torch._inductor.aoti_compile_and_package(
            torch.export.export(
                model.layers[0] if regional else model,
                args=(input,)
            ),
            inductor_configs=inductor_configs,
        )

        if regional:
            for layer in model.layers:
                compiled_layer = torch._inductor.aoti_load_package(path)
                compiled_layer.load_constants(
                    layer.state_dict(), check_full_update=True, user_managed=True
                )
                layer.forward = compiled_layer
        else:
            model = torch._inductor.aoti_load_package(path)
        return model


input = torch.randn(10, 10, device="cuda")
full_model_compilation_latency = measure_compile_time(input, regional=False)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_compile_time(input, regional=True)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

assert regional_compilation_latency < full_model_compilation_latency

############################################################################
# A model may also contain layers that are incompatible with compilation. In that case,
# full model compilation results in a fragmented computation graph, which can degrade
# latency. In such cases, regional compilation can be beneficial: the repeated,
# compilation-friendly blocks are compiled ahead-of-time, while the incompatible parts
# simply stay in eager mode, as sketched below.
#
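# The following is a minimal, hypothetical sketch (``CompileUnfriendlyHead`` and
# ``ModelWithHead`` are made up for illustration and are not part of this recipe's
# measured workflow): only the repeated ``Layer`` blocks are swapped for the AoT-compiled
# binary obtained earlier, while the incompatible head stays in eager mode.
#
# .. code-block:: python
#
#     class CompileUnfriendlyHead(torch.nn.Module):
#         def forward(self, x):
#             # Data-dependent Python control flow like this is typically hard to
#             # export or compile without extra work.
#             if x.abs().mean().item() > 1.0:
#                 return x * 0.5
#             return x
#
#     class ModelWithHead(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])
#             self.head = CompileUnfriendlyHead()
#
#         def forward(self, x):
#             for layer in self.layers:
#                 x = layer(x)
#             # The head runs in eager mode; only the repeated blocks use the compiled binary.
#             return self.head(x)
#
#     # Reuse the regional AoT artifact compiled above (``path``) for each repeated block.
#     model_with_head = ModelWithHead().cuda()
#     for layer in model_with_head.layers:
#         compiled_layer = torch._inductor.aoti_load_package(path)
#         compiled_layer.load_constants(
#             layer.state_dict(), check_full_update=True, user_managed=True
#         )
#         layer.forward = compiled_layer
#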

############################################################################
# Conclusion
# -----------
#
# This recipe shows how to control the cold start time when compiling your
# model ahead-of-time. This is most effective when your model has repeated
# blocks, as is typically the case in large generative models.
