
Commit f609a73

aditvenk and Aditya Venkataraman authored
Fix inference mode compilation in update_joint_with_descriptors (#265)
Fixed a bug where updated_flat_args was incorrectly formatted as a tuple for inference mode, causing AOT compilation to fail. The format should be a list for inference mode and a (primals, tangents) tuple for autograd mode, per PyTorch's AOTGraphCapture schema.

ref: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/schemas.py#L1197

Testing: new unit test test_inference_mode_compilation()

Co-authored-by: Aditya Venkataraman <avenkataraman@fb.com>
1 parent 449e35e commit f609a73
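
As context for the fix, a minimal sketch of the two formats described in the commit message (the helper name pack_updated_flat_args is hypothetical, not part of the codebase); it mirrors the branch the commit introduces in update_joint_with_descriptors:

    # Hypothetical helper illustrating the two formats expected by
    # AOTGraphCapture.updated_flat_args (see the linked schemas.py):
    def pack_updated_flat_args(primals, tangents):
        if tangents:
            # autograd mode: a (primals, tangents) tuple
            return (primals, tangents)
        # inference mode: the flat list of primals, not a 1-tuple
        return primals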

2 files changed: +64 −4 lines

autoparallel/graph_utils.py

Lines changed: 11 additions & 4 deletions
@@ -65,10 +65,17 @@ def update_joint_with_descriptors(
 
     tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
     new_local_tangents = new_local_args[tangent_idx:]
-    joint_with_descriptors._aot_graph_capture.updated_flat_args = (
-        new_flat_args,
-        new_local_tangents,
-    )
+
+    # For inference mode (no tangents), updated_flat_args should be a list.
+    # For autograd mode (with tangents), it should be a tuple of (primals, tangents).
+    if new_local_tangents:
+        joint_with_descriptors._aot_graph_capture.updated_flat_args = (
+            new_flat_args,
+            new_local_tangents,
+        )
+    else:
+        joint_with_descriptors._aot_graph_capture.updated_flat_args = new_flat_args
+
     joint_with_descriptors._aot_state.flat_args = new_flat_args
     joint_with_descriptors._aot_state.fw_metadata.traced_tangents = new_local_tangents
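
Why the else branch fires in inference mode: when no input requires grad, no tangents are appended after the flat args, so the slice past tangent_idx is empty and falsy. A minimal sketch with hypothetical stand-in values:

    flat_args = ["weight", "input"]          # primals only (stand-ins for tensors)
    tangent_idx = len(flat_args)             # == 2
    new_local_args = ["weight2", "input2"]   # inference mode: no tangents appended
    new_local_tangents = new_local_args[tangent_idx:]  # == [], empty and falsy
    # so updated_flat_args is set to the plain list new_flat_args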

tests/test_api.py

Lines changed: 53 additions & 0 deletions
@@ -305,3 +305,56 @@ def input_fn():
     # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_1, 10), kwargs = {})
     # %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 30), kwargs = {})
     # return ((add, add_2), (tangents_1, None))
+
+
+def test_inference_mode_compilation(device_mesh_1d):
+    """Test that inference mode (no gradients) works with compile=True.
+
+    This test verifies the fix for the bug where updated_flat_args was incorrectly
+    formatted as a tuple for inference mode, causing compilation to fail.
+
+    Regression test for: updated_flat_args should be a list for inference mode,
+    not a tuple of (primals, tangents).
+    """
+    dim = 128
+
+    class SimpleLinear(nn.Module):
+        def __init__(self, in_features, out_features):
+            super().__init__()
+            self.linear = nn.Linear(in_features, out_features, bias=False)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    def input_fn():
+        batch_size = 256
+        return torch.rand(batch_size, dim, device="cuda")
+
+    with torch.device("meta"):
+        model = SimpleLinear(dim, dim * 2)
+
+    # Set model to inference mode (no gradients)
+    for param in model.parameters():
+        param.requires_grad = False
+
+    # Test with compile=True - this should succeed with the fix
+    with AutoParallel(model, input_fn, device_mesh_1d, None, compile=True) as autop:
+        autop.add_parameter_memory_constraint(low=None, high=device_mesh_1d.ndim)
+
+        # R -> S(0)
+        autop.add_input_constraints([(Replicate(),)])
+        autop.add_output_constraints([(Shard(0),)])
+
+        sharding_placement = autop.optimize_placement()
+        parallel_mod = autop.apply_placement(sharding_placement)
+
+    # Verify the model was created
+    assert parallel_mod is not None
+    assert hasattr(autop, "parallel_gm")
+
+    # Verify graph has expected structure (forward-only, no backward pass)
+    placeholders = [
+        n for n in autop.parallel_gm.graph.nodes if n.op == "placeholder"
+    ]
+    # Should only have 2 placeholders: weight and input (no tangents for inference)
+    assert len(placeholders) == 2
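
The new regression test can be run on its own with pytest, e.g. pytest tests/test_api.py -k test_inference_mode_compilation (assuming the repo's usual test setup provides the device_mesh_1d fixture; a CUDA device is needed, since input_fn allocates on "cuda").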
