# Doesn't torch.compile already capture the backward graph?
# ------------
# Partially. AOTAutograd captures the backward graph ahead-of-time, but with certain limitations:
- # - Graph breaks in the forward lead to graph breaks in the backward
- # - `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
+ # - Graph breaks in the forward lead to graph breaks in the backward (see the sketch below)
+ # - `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
#
# Compiled Autograd addresses these limitations by directly integrating with the autograd engine, allowing
# it to capture the full backward graph at runtime. Models with these two characteristics should try
# Compiled Autograd, and potentially observe better performance.
#
# However, Compiled Autograd has its own limitations:
- # - Dynamic autograd structure leads to recompiles
+ # - Dynamic autograd structure leads to recompiles
#

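######################################################################
# As a rough illustration of the first limitation above, a forward containing a
# graph break also splits the captured backward under plain torch.compile.
# The function below is an illustrative sketch, not code from this tutorial:
# the print() forces a graph break in the forward.
#

import torch

@torch.compile
def forward_with_graph_break(x):
    y = x.sin()
    print("this print forces a graph break")  # the backward is split at the same point
    return y.cos().sum()

forward_with_graph_break(torch.randn(4, requires_grad=True)).backward()
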
+ ######################################################################
+ # Tutorial output cells setup
+ # ------------
+ #
+
+ import os
+
+ class ScopedLogging:
+     def __init__(self):
+         assert "TORCH_LOGS" not in os.environ
+         assert "TORCH_LOGS_FORMAT" not in os.environ
+         os.environ["TORCH_LOGS"] = "compiled_autograd_verbose"
+         os.environ["TORCH_LOGS_FORMAT"] = "short"
+
+     def __del__(self):
+         del os.environ["TORCH_LOGS"]
+         del os.environ["TORCH_LOGS_FORMAT"]
+
+
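# A minimal usage sketch for the helper above (assuming TORCH_LOGS is not
# already set in your environment, otherwise the asserts in __init__ fire):
# constructing ScopedLogging turns on verbose compiled autograd logging, and
# deleting the object removes the environment variables again.

logs = ScopedLogging()   # sets TORCH_LOGS="compiled_autograd_verbose", TORCH_LOGS_FORMAT="short"
# ... run a cell whose compiled autograd logs you want to see ...
del logs                 # both variables are removed again
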
######################################################################
# Basic Usage
# ------------
#

+ import torch
+
# NOTE: Must be enabled before using the decorator
torch._dynamo.config.compiled_autograd = True

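# Putting the pieces together, a minimal end-to-end sketch of the basic usage
# pattern (the model, input, and train() body below are illustrative
# assumptions): with the config flag set, the backward() call inside the
# compiled region is captured and compiled by Compiled Autograd.

model = torch.nn.Linear(10, 10)
x = torch.randn(10, 10)

@torch.compile
def train(model, x):
    loss = model(x).sum()
    loss.backward()

train(model, x)
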
@@ -57,21 +78,12 @@ def train(model, x):
# ------------
# Run the script with either of the TORCH_LOGS environment variable settings below
#
60- """
61- Prints graph:
62- TORCH_LOGS="compiled_autograd" python example.py
63-
64- Performance degrading, prints verbose graph and recompile reasons:
65- TORCH_LOGS="compiled_autograd_verbose" python example.py
66- """
67-
68- ######################################################################
69- # Or with the set_logs private API:
+ # - To print only the compiled autograd graph, use `TORCH_LOGS="compiled_autograd" python example.py`
+ # - To print the graph with additional tensor metadata and recompile reasons, at the cost of some performance, use `TORCH_LOGS="compiled_autograd_verbose" python example.py`
+ #
+ # Logs can also be enabled through the private API torch._logging._internal.set_logs.
#

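# A minimal sketch of that private API route: either call below mirrors one of
# the TORCH_LOGS settings above. Note that set_logs is private, so these keyword
# arguments may change between releases, and the call must happen before the
# compilation you want logged.
#
# torch._logging._internal.set_logs(compiled_autograd=True)           # graph only
# torch._logging._internal.set_logs(compiled_autograd_verbose=True)   # graph plus recompile reasons
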
- # flag must be enabled before wrapping using torch.compile
- torch._logging._internal.set_logs(compiled_autograd=True)
-
@torch.compile
def train(model, x):
    loss = model(x).sum()
@@ -80,14 +92,15 @@ def train(model, x):
train(model, x)

######################################################################
- # The compiled autograd graph should now be logged to stdout. Certain graph nodes will have names that are prefixed by "aot0_",
+ # The compiled autograd graph should now be logged to stdout. Certain graph nodes will have names that are prefixed by `aot0_`,
# these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0.
#
# NOTE: This is the graph that we will call torch.compile on, NOT the optimized graph. Compiled Autograd essentially
# generates some Python code to represent the entire C++ autograd execution.
#
89101"""
90- INFO:torch._dynamo.compiled_autograd.__compiled_autograd:TRACED GRAPH
102+ DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
103+ DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
91104 ===== Compiled autograd graph =====
92105 <eval_with_key>.4 class CompiledAutograd(torch.nn.Module):
93106 def forward(self, inputs, sizes, scalars, hooks):
@@ -178,6 +191,7 @@ def fn(x):
    return temp.sum()

x = torch.randn(10, 10, requires_grad=True)
+ torch._dynamo.utils.counters.clear()
loss = fn(x)

# 1. base torch.compile
@@ -205,7 +219,6 @@ def fn(x):
x.register_hook(lambda grad: grad + 10)
loss = fn(x)

- torch._logging._internal.set_logs(compiled_autograd=True)
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
    loss.backward()

@@ -214,22 +227,22 @@ def fn(x):
#

"""
- INFO:torch._dynamo.compiled_autograd.__compiled_autograd:TRACED GRAPH
+ DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
+ DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
  ===== Compiled autograd graph =====
  <eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
     def forward(self, inputs, sizes, scalars, hooks):
-         ...
-         getitem_2 = hooks[0]; hooks = None
-         call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
-         ...
+             ...
+             getitem_2 = hooks[0]; hooks = None
+             call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
+             ...
"""

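######################################################################
# Tensor pre-hooks are not the only hook flavor. As a rough sketch (the module,
# hook, and scaling factor below are illustrative assumptions, and support for
# each hook flavor can vary across PyTorch versions), a module full backward
# hook can be driven through the same enable() context manager:
#

hooked_model = torch.nn.Linear(10, 10)

def scale_grad_inputs(module, grad_input, grad_output):
    # Double the gradients flowing back into the module's inputs.
    return tuple(g * 2 if g is not None else None for g in grad_input)

hooked_model.register_full_backward_hook(scale_grad_inputs)

inp = torch.randn(10, 10, requires_grad=True)
hook_loss = hooked_model(inp).sum()
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
    hook_loss.backward()
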
######################################################################
- # Understanding recompilation reasons for Compiled Autograd
+ # Common recompilation reasons for Compiled Autograd
# ------------
# 1. Due to change in autograd structure

- torch._logging._internal.set_logs(compiled_autograd_verbose=True)
torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
@@ -238,14 +251,18 @@ def forward(self, inputs, sizes, scalars, hooks):

######################################################################
# You should see some cache miss logs (recompiles):
- # Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
- # ...
- # Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
- # ...
- # Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
- # ...
- # Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
- # ...
+ #
+
+ """
+ Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
+ ...
+ Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
+ ...
+ Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
+ ...
+ Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
+ ...
+ """

######################################################################
# 2. Due to dynamic shapes
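######################################################################
# For instance, calling the compiled backward with inputs whose sizes change
# between iterations triggers this kind of recompile. The cell below is an
# illustrative sketch (the function and the shapes are assumptions, not this
# tutorial's exact code):
#

@torch.compile
def shape_demo(x):
    loss = x.sum()
    loss.backward()

for i in [10, 100, 1000]:
    shape_demo(torch.randn(i, i, requires_grad=True))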
@@ -260,12 +277,16 @@ def forward(self, inputs, sizes, scalars, hooks):

######################################################################
# You should see some cache miss logs (recompiles):
- # ...
- # Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
- # Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
- # Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
- # Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
- # ...
+ #
+
+ """
+ ...
+ Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
+ Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
+ Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
+ Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
+ ...
+ """

######################################################################
# Compatibility and rough edges