@@ -152,6 +152,9 @@ def forward(self, x):
152152 )
153153 layer .forward = compiled_layer
154154
155+ output_regional_compiled = model (input )
156+ print (f"{ output_regional_compiled .shape = } " )
157+
155158#####################################################
156159# Just like JiT regional compilation, compiling regions within a model ahead-of-time
157160# leads to significantly reduced cold start times. The actual number will vary from
@@ -171,56 +174,65 @@ def forward(self, x):
171174# compilation.
172175
173176
174- def measure_latency (fn , input ):
175- # Reset the compiler caches to ensure no reuse between different runs
176- torch .compiler .reset ()
177- with torch ._inductor .utils .fresh_inductor_cache ():
178- start = perf_counter ()
179- fn (input )
180- torch .cuda .synchronize ()
181- end = perf_counter ()
182- return end - start
177+ def measure_compile_time (input , regional = False ):
178+ start = perf_counter ()
179+ model = aot_compile_load_model (regional = regional )
180+ torch .cuda .synchronize ()
181+ end = perf_counter ()
182+ # make sure the model works.
183+ _ = model (input )
184+ return end - start
183185
184- def aot_compile_model (regional = False ):
186+ def aot_compile_load_model (regional = False ) -> torch . nn . Module :
185187 input = torch .randn (10 , 10 , device = "cuda" )
186188 model = Model ().cuda ()
187189
188190 inductor_configs = {}
189191 if regional :
190192 inductor_configs = {"aot_inductor.package_constants_in_so" : False }
191- path = torch ._inductor .aoti_compile_and_package (
192- torch .export .export (
193- model .layers [0 ] if regional else model ,
194- args = (input ,)
195- ),
196- inductor_configs = inductor_configs ,
197- )
198-
199- if regional :
200- for layer in model .layers :
201- compiled_layer = torch ._inductor .aoti_load_package (path )
202- compiled_layer .load_constants (
203- layer .state_dict (), check_full_update = True , user_managed = True
204- )
205- layer .forward = compiled_layer
206- else :
207- compiled_layer = torch ._inductor .aoti_load_package (path )
208193
194+ # Reset the compiler caches to ensure no reuse between different runs
195+ torch .compiler .reset ()
196+ with torch ._inductor .utils .fresh_inductor_cache ():
197+ path = torch ._inductor .aoti_compile_and_package (
198+ torch .export .export (
199+ model .layers [0 ] if regional else model ,
200+ args = (input ,)
201+ ),
202+ inductor_configs = inductor_configs ,
203+ )
204+
205+ if regional :
206+ for layer in model .layers :
207+ compiled_layer = torch ._inductor .aoti_load_package (path )
208+ compiled_layer .load_constants (
209+ layer .state_dict (), check_full_update = True , user_managed = True
210+ )
211+ layer .forward = compiled_layer
212+ else :
213+ model = torch ._inductor .aoti_load_package (path )
209214 return model
210215
211216input = torch .randn (10 , 10 , device = "cuda" )
212- full_model_compilation_latency = measure_latency ( aot_compile_model (), input )
217+ full_model_compilation_latency = measure_compile_time ( input , regional = False )
213218print (f"Full model compilation time = { full_model_compilation_latency :.2f} seconds" )
214219
215- regional_compilation_latency = measure_latency ( aot_compile_model ( regional = True ), input )
220+ regional_compilation_latency = measure_compile_time ( input , regional = True )
216221print (f"Regional compilation time = { regional_compilation_latency :.2f} seconds" )
217222
218223assert regional_compilation_latency < full_model_compilation_latency
219224
225+ ############################################################################
226+ # There may also be layers in a model that are incompatible with compilation.
227+ # In that case, full compilation produces a fragmented computation graph,
228+ # which can degrade latency. In these cases, regional compilation
229+ # can be beneficial.
230+ #
231+
220232############################################################################
221233# Conclusion
222234# -----------
223235#
224- # This recipe shows how to control the cold start time when compiling your model ahead-of-time.
225- # This becomes effective when your model has repeated blocks, like typically seen in large generative
226- # models.
236+ # This recipe shows how to control the cold start time when compiling your
237+ # model ahead-of-time. This becomes effective when your model has repeated
238+ # blocks, like typically seen in large generative models.
0 commit comments