============================================================================

**Author:** `Animesh Jain <https://github.com/anijain2305>`_

As deep learning models get larger, the compilation time of these models also
increases. This extended compilation time can result in a large startup time in
inference services or wasted resources in large-scale training. This recipe
shows how to reduce the cold start compilation time by compiling a repeated
region of the model instead of the entire model.

Before starting, make sure that ``torch`` is installed:

.. code-block:: sh

   pip install torch

.. note::

   This feature is available starting with the 2.5 release. If you are using version 2.4,
   you can enable the configuration flag ``torch._dynamo.config.inline_inbuilt_nn_modules=True``
   to prevent recompilations during regional compilation. In version 2.5, this flag is enabled by default.
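
For example, on a 2.4 installation the flag can be set at the top of the script,
before any call to ``torch.compile`` (a minimal sketch of the flag mentioned above):

.. code-block:: python

   import torch._dynamo

   # Prevent recompilations across the repeated regions (the default in 2.5+).
   torch._dynamo.config.inline_inbuilt_nn_modules = True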
3031"""
3132
32-
33+ from time import perf_counter
3334
3435######################################################################
3536# Steps
3637# -----
37- #
38+ #
3839# In this recipe, we will follow these steps:
3940#
4041# 1. Import all necessary libraries.
4142# 2. Define and initialize a neural network with repeated regions.
4243# 3. Understand the difference between the full model and the regional compilation.
4344# 4. Measure the compilation time of the full model and the regional compilation.
#
# First, let's import the necessary libraries:
#

import torch
import torch.nn as nn

##########################################################
# Next, let's define and initialize a neural network with repeated regions.
#
# Typically, neural networks are composed of repeated layers. For example, a
# large language model is composed of many Transformer blocks. In this recipe,
# we will create a ``Layer`` using the ``nn.Module`` class as a proxy for a repeated region.
# We will then create a ``Model`` which is composed of 64 instances of this
# ``Layer`` class.
#
class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self, apply_regional_compilation):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        # Apply compile only to the repeated layers.
        if apply_regional_compilation:
            self.layers = torch.nn.ModuleList(
                [torch.compile(Layer()) for _ in range(64)]
            )
        else:
            self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, self.linear stays outside the scope of torch.compile.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x


####################################################
# Next, let's review the difference between the full model and the regional compilation.
#
# In full model compilation, the entire model is compiled as a whole. This is the common approach
# most users take with ``torch.compile``. In this example, we apply ``torch.compile`` to
# the ``Model`` object. This will effectively inline the 64 layers, producing a
# large graph to compile. You can look at the full graph by running this recipe
# with ``TORCH_LOGS=graph_code``.
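#
# As a quick illustration, the same ``graph_code`` logs can also be requested from Python
# instead of the environment variable. This is a sketch and assumes that
# ``torch._logging.set_logs`` accepts the ``graph_code`` artifact flag:
#
# .. code-block:: python
#
#    import torch._logging
#
#    # Equivalent to launching the script with TORCH_LOGS=graph_code.
#    torch._logging.set_logs(graph_code=True)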
#
#

model = Model(apply_regional_compilation=False).cuda()
full_compiled_model = torch.compile(model)


#####################################################
# Regional compilation, on the other hand, compiles only a region of the model.
# By strategically choosing to compile a repeated region of the model, we can compile a
# much smaller graph and then reuse the compiled graph for all the regions.
# In the example, ``torch.compile`` is applied only to the ``layers`` and not the full model.
#

regional_compiled_model = Model(apply_regional_compilation=True).cuda()

#####################################################
# Applying compilation to a repeated region, instead of the full model, leads to
# large savings in compile time. Here, we will just compile a layer instance and
# then reuse it 64 times in the ``Model`` object.
#
# Note that with repeated regions, some parts of the model might not be compiled.
# For example, the ``self.linear`` in the ``Model`` is outside of the scope of
# regional compilation.
#
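# As an optional sanity check, you can confirm which submodules ended up wrapped by
# ``torch.compile``. ``OptimizedModule`` is the wrapper type that ``torch.compile``
# returns for modules; its import path is an internal detail and may change:
#
# .. code-block:: python
#
#    from torch._dynamo.eval_frame import OptimizedModule
#
#    # The repeated layers are compiled wrappers; the outer linear layer stays eager.
#    print(isinstance(regional_compiled_model.layers[0], OptimizedModule))  # True
#    print(isinstance(regional_compiled_model.linear, OptimizedModule))     # False
#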
# Also, note that there is a tradeoff between performance speedup and compile
# time. Full model compilation involves a larger graph and,
# theoretically, offers more scope for optimizations. However, for practical
# purposes and depending on the model, the difference in speedup between full model
# and regional compilation is often minimal.

###################################################
# Next, let's measure the compilation time of the full model and the regional compilation.
#
# ``torch.compile`` is a JIT compiler, which means that it compiles on the first invocation.
# In the code below, we measure the total time spent in the first invocation. While this method is not
# precise, it provides a good estimate since the majority of the time is spent in
# compilation.
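#
# If you want a finer-grained view than a single wall-clock number, a per-function
# compile-time summary can also be printed. This sketch assumes the internal helper
# ``torch._dynamo.utils.compile_times()`` is available in your build:
#
# .. code-block:: python
#
#    import torch._dynamo.utils
#
#    # Print a summary of where compilation time has been spent so far.
#    print(torch._dynamo.utils.compile_times())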


def measure_latency(fn, input):
    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        start = perf_counter()
        fn(input)
        torch.cuda.synchronize()
        end = perf_counter()
        return end - start


input = torch.randn(10, 10, device="cuda")
full_model_compilation_latency = measure_latency(full_compiled_model, input)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_latency(regional_compiled_model, input)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

assert regional_compilation_latency < full_model_compilation_latency

############################################################################
# Conclusion
# -----------
#
# This recipe shows how to reduce the cold start compilation time if your model
# has repeated regions. This approach requires user modifications to apply ``torch.compile`` to
# the repeated regions instead of the more commonly used full model compilation. We
# are continually working on reducing cold start compilation time.
#
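# As a closing sketch, the same pattern carries over to your own models: wrap only the
# repeated blocks with ``torch.compile``. ``MyTransformer`` and its ``blocks`` attribute
# below are hypothetical placeholders for whatever repeated region your model has:
#
# .. code-block:: python
#
#    import torch
#
#    # MyTransformer and .blocks are placeholders for your own repeated region.
#    model = MyTransformer().cuda()
#    for idx, block in enumerate(model.blocks):
#        # Compile each repeated block; the compiled code is reused across blocks
#        # because they share the same structure.
#        model.blocks[idx] = torch.compile(block)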