Skip to content

Commit 0c45ef9

Browse files
committed
reviewer feedback and fixes.
1 parent 4a870c8 commit 0c45ef9

File tree

1 file changed

+44
-32
lines changed

1 file changed

+44
-32
lines changed

recipes_source/regional_aot.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ def forward(self, x):
152152
)
153153
layer.forward = compiled_layer
154154

155+
output_regional_compiled = model(input)
156+
print(f"{output_regional_compiled.shape=}")
157+
155158
#####################################################
156159
# Just like JiT regional compilation, compiling regions within a model ahead-of-time
157160
# leads to significantly reduced cold start times. The actual number will vary from
@@ -171,56 +174,65 @@ def forward(self, x):
171174
# compilation.
172175

173176

174-
def measure_latency(fn, input):
175-
# Reset the compiler caches to ensure no reuse between different runs
176-
torch.compiler.reset()
177-
with torch._inductor.utils.fresh_inductor_cache():
178-
start = perf_counter()
179-
fn(input)
180-
torch.cuda.synchronize()
181-
end = perf_counter()
182-
return end - start
177+
def measure_compile_time(input, regional=False):
178+
start = perf_counter()
179+
model = aot_compile_load_model(regional=regional)
180+
torch.cuda.synchronize()
181+
end = perf_counter()
182+
# make sure the model works.
183+
_ = model(input)
184+
return end - start
183185

184-
def aot_compile_model(regional=False):
186+
def aot_compile_load_model(regional=False) -> torch.nn.Module:
185187
input = torch.randn(10, 10, device="cuda")
186188
model = Model().cuda()
187189

188190
inductor_configs = {}
189191
if regional:
190192
inductor_configs = {"aot_inductor.package_constants_in_so": False}
191-
path = torch._inductor.aoti_compile_and_package(
192-
torch.export.export(
193-
model.layers[0] if regional else model,
194-
args=(input,)
195-
),
196-
inductor_configs=inductor_configs,
197-
)
198-
199-
if regional:
200-
for layer in model.layers:
201-
compiled_layer = torch._inductor.aoti_load_package(path)
202-
compiled_layer.load_constants(
203-
layer.state_dict(), check_full_update=True, user_managed=True
204-
)
205-
layer.forward = compiled_layer
206-
else:
207-
compiled_layer = torch._inductor.aoti_load_package(path)
208193

194+
# Reset the compiler caches to ensure no reuse between different runs
195+
torch.compiler.reset()
196+
with torch._inductor.utils.fresh_inductor_cache():
197+
path = torch._inductor.aoti_compile_and_package(
198+
torch.export.export(
199+
model.layers[0] if regional else model,
200+
args=(input,)
201+
),
202+
inductor_configs=inductor_configs,
203+
)
204+
205+
if regional:
206+
for layer in model.layers:
207+
compiled_layer = torch._inductor.aoti_load_package(path)
208+
compiled_layer.load_constants(
209+
layer.state_dict(), check_full_update=True, user_managed=True
210+
)
211+
layer.forward = compiled_layer
212+
else:
213+
model = torch._inductor.aoti_load_package(path)
209214
return model
210215

211216
input = torch.randn(10, 10, device="cuda")
212-
full_model_compilation_latency = measure_latency(aot_compile_model(), input)
217+
full_model_compilation_latency = measure_compile_time(input, regional=False)
213218
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")
214219

215-
regional_compilation_latency = measure_latency(aot_compile_model(regional=True), input)
220+
regional_compilation_latency = measure_compile_time(input, regional=True)
216221
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")
217222

218223
assert regional_compilation_latency < full_model_compilation_latency
219224

225+
############################################################################
226+
# There may also be layers in a model incompatible with compilation. So,
227+
# full compilation can produce a fragmented computation graph, resulting
228+
# in potential latency degradation. In these cases, regional compilation
229+
# can be beneficial.
230+
#
231+
220232
############################################################################
221233
# Conclusion
222234
# -----------
223235
#
224-
# This recipe shows how to control the cold start time when compiling your model ahead-of-time.
225-
# This becomes effective when your model has repeated blocks, like typically seen in large generative
226-
# models.
236+
# This recipe shows how to control the cold start time when compiling your
237+
# model ahead-of-time. This becomes effective when your model has repeated
238+
# blocks, like typically seen in large generative models.

0 commit comments

Comments
 (0)