Skip to content

Commit 3819737

Browse files
authored
[compiler toolkit] Port joint_ac_pass from simplefsdp (#2051)
This PR integrates the changes from #1970 into the compiler toolkit (applying `joint_ac_pass` on the joint graph to tag nodes based on the `reshard_after_forward` flag). It also refactors how graph passes are applied in the compiler toolkit experiments. There are now two kinds of passes: 1. joint_custom_passes: passes applied to the captured joint graph before the partitioner. By default we apply `validate_flex_attn_annotation_pass` and `fsdp_reshard_after_fwd_pass`. 2. compiler_passes: passes applied to the partitioned fwd and bwd graphs as backend optimizations. By default there are none; `autobucketing_reordering_pass` and `regional_inductor_pass` can be enabled through configs.
1 parent 22e959a commit 3819737

File tree

5 files changed

+98
-24
lines changed

5 files changed

+98
-24
lines changed

torchtitan/experiments/compiler_toolkit/common_utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,3 @@ def register_blockmask_pytree_node():
5353
flatten_with_keys_fn=BlockMask._flatten_with_keys,
5454
serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
5555
)
56-
57-
58-
def validate_flex_attention_annotation(joint_with_descriptors):
59-
"""Assert every flex_attention / flex_attention_backward node carries the "compile_with_inductor" custom annotation."""
60-
for node in joint_with_descriptors.graph_module.graph.nodes:
61-
if node.target in {
62-
torch.ops.higher_order.flex_attention,
63-
torch.ops.higher_order.flex_attention_backward,
64-
}:
65-
assert "compile_with_inductor" in node.meta.get("custom", {})

torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717
disable_compile,
1818
parallelize_inputs,
1919
register_blockmask_pytree_node,
20-
validate_flex_attention_annotation,
2120
)
2221

2322
from torchtitan.experiments.compiler_toolkit.graph_utils import (
2423
CompiledModule,
2524
get_compiler_passes_from_config,
25+
get_joint_custom_passes_from_config,
2626
joint_graph_builder,
2727
make_compiler_with_passes,
2828
)
@@ -76,6 +76,9 @@ def parallelize_deepseekv3(
7676
with disable_compile(job_config):
7777
model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config)
7878

79+
# Get joint custom passes from config
80+
joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config)
81+
7982
# Get compiler passes from config
8083
compiler_passes = get_compiler_passes_from_config(job_config)
8184

@@ -89,7 +92,7 @@ def parallelize_deepseekv3(
8992
joint_graph_builder,
9093
fw_compiler=fw_compiler,
9194
bw_compiler=bw_compiler,
92-
joint_custom_pass=validate_flex_attention_annotation,
95+
joint_custom_passes=joint_custom_passes,
9396
dump_folder=job_config.job.dump_folder,
9497
)
9598

torchtitan/experiments/compiler_toolkit/graph_utils.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import contextlib
8+
import functools
89
from pathlib import Path
910
from typing import Any, Callable, List, Optional
1011

@@ -86,7 +87,7 @@ def joint_graph_builder(
8687
model_kwargs: dict,
8788
fw_compiler: Optional[Callable] = None,
8889
bw_compiler: Optional[Callable] = None,
89-
joint_custom_pass: Optional[Callable] = None,
90+
joint_custom_passes: Optional[List[Callable]] = None,
9091
dump_folder: str | None = None,
9192
):
9293
"""
@@ -98,7 +99,7 @@ def joint_graph_builder(
9899
model_kwargs: Dict of model input keyword arguments
99100
fw_compiler: Optional custom forward compiler function
100101
bw_compiler: Optional custom backward compiler function
101-
joint_custom_pass: Optional custom pass to run on the joint graph
102+
joint_custom_passes: list of custom passes to run on the joint graph
102103
dump_folder: Optional folder to dump the graph to
103104
"""
104105
assert isinstance(model_args, tuple)
@@ -112,8 +113,11 @@ def joint_graph_builder(
112113
) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
113114

114115
# Optional validation
115-
if joint_custom_pass is not None:
116-
joint_custom_pass(joint_with_descriptors)
116+
if joint_custom_passes is not None:
117+
for joint_custom_pass in joint_custom_passes:
118+
joint_with_descriptors.graph_module = joint_custom_pass(
119+
joint_with_descriptors.graph_module
120+
)
117121

118122
with tracing(tracing_context):
119123
fn = aot_compile_joint_with_descriptors(
@@ -283,20 +287,64 @@ def get_compiler_passes_from_config(job_config: JobConfig):
283287
Returns:
284288
List of compiler pass functions
285289
"""
286-
from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_PASSES
290+
from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_COMPILER_PASSES
287291

288292
pass_names = getattr(job_config.compile, "passes", [])
289293
compiler_passes = []
290294

291295
for pass_name in pass_names:
292-
if pass_name not in AVAILABLE_PASSES:
296+
if pass_name not in AVAILABLE_COMPILER_PASSES:
293297
raise ValueError(
294298
f"Unknown compiler pass: {pass_name}. "
295-
f"Available passes: {list(AVAILABLE_PASSES.keys())}"
299+
f"Available compiler passes: {list(AVAILABLE_COMPILER_PASSES.keys())}"
296300
)
297-
compiler_passes.append(AVAILABLE_PASSES[pass_name])
301+
compiler_passes.append(AVAILABLE_COMPILER_PASSES[pass_name])
298302

299303
if pass_names:
300304
logger.info(f"Using compiler passes from config: {pass_names}")
301305

302306
return compiler_passes
307+
308+
309+
def get_joint_custom_passes_from_config(
    parallel_dims: ParallelDims, job_config: JobConfig
):
    """
    Build the list of passes to run on the captured joint graph.

    The returned list always holds ``validate_flex_attn_annotation_pass``
    followed by ``fsdp_reshard_after_fwd_pass`` bound to the reshard flag
    derived from ``job_config.parallelism.fsdp_reshard_after_forward``.

    Args:
        parallel_dims: Parallelism layout; only ``pp_enabled`` is read, and
            only when the policy is "default"
        job_config: Job configuration containing parallelism.fsdp_reshard_after_forward

    Returns:
        List of joint custom pass functions

    Raises:
        ValueError: If the policy is not one of "always", "never" or "default"
    """
    from torchtitan.experiments.compiler_toolkit.passes import (
        fsdp_reshard_after_fwd_pass,
        validate_flex_attn_annotation_pass,
    )

    policy = job_config.parallelism.fsdp_reshard_after_forward
    if policy == "always":
        reshard = True
    elif policy == "never":
        reshard = False
    elif policy == "default":
        # For PP, by default do not reshard after forward to avoid per-microbatch
        # all-gathers, which can be expensive and non-overlapped
        reshard = not parallel_dims.pp_enabled
    else:
        raise ValueError(
            f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
        )

    return [
        validate_flex_attn_annotation_pass,
        functools.partial(
            fsdp_reshard_after_fwd_pass,
            reshard_after_forward=reshard,
        ),
    ]

torchtitan/experiments/compiler_toolkit/llama3/parallelize.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
disable_compile,
1717
parallelize_inputs,
1818
register_blockmask_pytree_node,
19-
validate_flex_attention_annotation,
2019
)
2120

2221
from torchtitan.experiments.compiler_toolkit.graph_utils import (
2322
CompiledModule,
2423
get_compiler_passes_from_config,
24+
get_joint_custom_passes_from_config,
2525
joint_graph_builder,
2626
make_compiler_with_passes,
2727
)
@@ -63,6 +63,9 @@ def parallelize_llama(
6363
with disable_compile(job_config):
6464
model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config)
6565

66+
# Get joint custom passes from config
67+
joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config)
68+
6669
# Get compiler passes from config
6770
compiler_passes = get_compiler_passes_from_config(job_config)
6871

@@ -71,12 +74,12 @@ def parallelize_llama(
7174
compiler_passes, dump_folder=job_config.job.dump_folder
7275
)
7376

74-
# Create custom joint_graph_builder with llama-specific compilers and validation
77+
# Create custom joint_graph_builder with llama-specific compilers
7578
llama_joint_graph_builder = functools.partial(
7679
joint_graph_builder,
7780
fw_compiler=fw_compiler,
7881
bw_compiler=bw_compiler,
79-
joint_custom_pass=validate_flex_attention_annotation,
82+
joint_custom_passes=joint_custom_passes,
8083
dump_folder=job_config.job.dump_folder,
8184
)
8285

torchtitan/experiments/compiler_toolkit/passes.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
import torch
1515
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
1616
from torch.fx.passes.regional_inductor import regional_inductor
17+
from torchtitan.experiments.simple_fsdp.reshard_after_forward import (
18+
annotate_fsdp_all_gather,
19+
)
1720

1821

1922
def autobucketing_reordering_pass(
@@ -39,8 +42,35 @@ def regional_inductor_pass(
3942
return regional_inductor(gm, example_inputs)
4043

4144

45+
def validate_flex_attn_annotation_pass(
46+
gm: torch.fx.GraphModule,
47+
) -> torch.fx.GraphModule:
48+
"""Verify user annotations show up in the graph."""
49+
for node in gm.graph.nodes:
50+
if node.target in {
51+
torch.ops.higher_order.flex_attention,
52+
torch.ops.higher_order.flex_attention_backward,
53+
}:
54+
assert "compile_with_inductor" in node.meta.get("custom", {})
55+
return gm
56+
57+
58+
# Apply activation checkpointing on joint graph before partitioner
59+
def fsdp_reshard_after_fwd_pass(
    gm: torch.fx.GraphModule, reshard_after_forward: bool
) -> torch.fx.GraphModule:
    """Implement simplefsdp's fsdp_reshard_after_forward behavior on the joint graph.

    When ``reshard_after_forward`` is True, the simple_fsdp all-gathers are
    annotated with CheckpointPolicy.MUST_RECOMPUTE; when it is False, they are
    annotated with CheckpointPolicy.MUST_SAVE. The tagging itself is delegated
    to ``annotate_fsdp_all_gather``.

    Args:
        gm: Joint graph module to annotate.
        reshard_after_forward: Selects which checkpoint policy is applied.

    Returns:
        The graph module returned by ``annotate_fsdp_all_gather``, recompiled.
    """
    annotated_gm = annotate_fsdp_all_gather(gm, reshard_after_forward)
    annotated_gm.recompile()
    return annotated_gm
70+
71+
4272
# Registry mapping pass names to pass functions
43-
AVAILABLE_PASSES = {
73+
AVAILABLE_COMPILER_PASSES = {
4474
"autobucketing_reordering": autobucketing_reordering_pass,
4575
"regional_inductor": regional_inductor_pass,
4676
}

0 commit comments

Comments
 (0)