Skip to content

Commit bfdc974

Browse files
authored
[compiler toolkit] Port manual bucketing from SimpleFSDP experiment (#2056)
This PR integrates the manual bucketing pass (transformer block bucketing), added in the SimpleFSDP experiment (#1881), into the compiler toolkit. Now the compiler toolkit can also run the manual bucketing pass by specifying the config ``` NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing ``` Also updated the README and integration tests to include the newly ported pass.
1 parent 3819737 commit bfdc974

File tree

7 files changed

+106
-13
lines changed

7 files changed

+106
-13
lines changed

torchtitan/experiments/compiler_toolkit/README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./r
3434
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering
3535
```
3636

37+
**SimpleFSDP + TP + transformer-block-bucketing**
38+
```shell
39+
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing
40+
```
41+
3742
**SimpleFSDP + TP + FlexAttention**
3843
```shell
3944
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn
@@ -44,3 +49,9 @@ NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./r
4449
```shell
4550
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor
4651
```
52+
53+
**SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor**
54+
55+
```shell
56+
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor
57+
```

torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def parallelize_deepseekv3(
8080
joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config)
8181

8282
# Get compiler passes from config
83-
compiler_passes = get_compiler_passes_from_config(job_config)
83+
compiler_passes = get_compiler_passes_from_config(model, job_config)
8484

8585
# Create compilers with specified passes (defaults to no passes)
8686
fw_compiler, bw_compiler = make_compiler_with_passes(

torchtitan/experiments/compiler_toolkit/graph_utils.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def joint_graph_builder(
112112
tracing_context,
113113
) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
114114

115-
# Optional validation
115+
# run custom passes on joint-graph before partitioner
116116
if joint_custom_passes is not None:
117117
for joint_custom_pass in joint_custom_passes:
118118
joint_with_descriptors.graph_module = joint_custom_pass(
@@ -240,7 +240,12 @@ def compiler(
240240
_dump_gm(dump_folder, gm, f"{name}_before_compiler")
241241

242242
for pass_fn in passes:
243-
logger.info(f"Applying pass: {pass_fn.__name__}")
243+
pass_name = (
244+
pass_fn.func.__name__
245+
if isinstance(pass_fn, functools.partial)
246+
else pass_fn.__name__
247+
)
248+
logger.info(f"Applying pass: {pass_name}")
244249
gm = pass_fn(gm, example_inputs)
245250

246251
logger.debug(f"{name} after compiler:")
@@ -277,7 +282,7 @@ def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
277282
return fw_compiler, bw_compiler
278283

279284

280-
def get_compiler_passes_from_config(job_config: JobConfig):
285+
def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig):
281286
"""
282287
Extract and validate compiler passes from job config.
283288
@@ -288,8 +293,18 @@ def get_compiler_passes_from_config(job_config: JobConfig):
288293
List of compiler pass functions
289294
"""
290295
from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_COMPILER_PASSES
296+
from torchtitan.experiments.simple_fsdp.llama3.parallelize import (
297+
get_transformer_block_buckets,
298+
)
291299

292300
pass_names = getattr(job_config.compile, "passes", [])
301+
if (
302+
"autobucketing_reordering" in pass_names
303+
and "transformer_block_bucketing" in pass_names
304+
):
305+
raise ValueError(
306+
"Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!"
307+
)
293308
compiler_passes = []
294309

295310
for pass_name in pass_names:
@@ -298,7 +313,15 @@ def get_compiler_passes_from_config(job_config: JobConfig):
298313
f"Unknown compiler pass: {pass_name}. "
299314
f"Available compiler passes: {list(AVAILABLE_COMPILER_PASSES.keys())}"
300315
)
301-
compiler_passes.append(AVAILABLE_COMPILER_PASSES[pass_name])
316+
if pass_name == "transformer_block_bucketing":
317+
compiler_passes.append(
318+
functools.partial(
319+
AVAILABLE_COMPILER_PASSES[pass_name],
320+
fsdp_manual_buckets=get_transformer_block_buckets(model),
321+
)
322+
)
323+
else:
324+
compiler_passes.append(AVAILABLE_COMPILER_PASSES[pass_name])
302325

303326
if pass_names:
304327
logger.info(f"Using compiler passes from config: {pass_names}")

torchtitan/experiments/compiler_toolkit/llama3/parallelize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def parallelize_llama(
6767
joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config)
6868

6969
# Get compiler passes from config
70-
compiler_passes = get_compiler_passes_from_config(job_config)
70+
compiler_passes = get_compiler_passes_from_config(model, job_config)
7171

7272
# Create compilers with specified passes (defaults to no passes)
7373
fw_compiler, bw_compiler = make_compiler_with_passes(

torchtitan/experiments/compiler_toolkit/passes.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"""
1313

1414
import torch
15+
from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing
1516
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
1617
from torch.fx.passes.regional_inductor import regional_inductor
1718
from torchtitan.experiments.simple_fsdp.reshard_after_forward import (
@@ -26,13 +27,26 @@ def autobucketing_reordering_pass(
2627
Apply autobucketing and reordering optimization.
2728
2829
This pass applies schedule_overlap_bucketing with collective_bucketing enabled
29-
to optimize communication patterns in distributed training.
30+
to optimize comm/compute overlap patterns in the graph.
3031
"""
3132
schedule_overlap_bucketing(gm, collective_bucketing=True)
3233
gm.recompile()
3334
return gm
3435

3536

37+
def transformer_block_bucketing_reordering_pass(
38+
gm: torch.fx.GraphModule, example_inputs, fsdp_manual_buckets
39+
) -> torch.fx.GraphModule:
40+
"""
41+
Apply aten-level manual bucketing and reordering optimization.
42+
"""
43+
manual_overlap_bucketing(
44+
gm, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=False
45+
)
46+
gm.recompile()
47+
return gm
48+
49+
3650
def regional_inductor_pass(
3751
gm: torch.fx.GraphModule, example_inputs
3852
) -> torch.fx.GraphModule:
@@ -72,5 +86,6 @@ def fsdp_reshard_after_fwd_pass(
7286
# Registry mapping pass names to pass functions
7387
AVAILABLE_COMPILER_PASSES = {
7488
"autobucketing_reordering": autobucketing_reordering_pass,
89+
"transformer_block_bucketing": transformer_block_bucketing_reordering_pass,
7590
"regional_inductor": regional_inductor_pass,
7691
}

torchtitan/experiments/compiler_toolkit/tests/integration_tests.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
2424
"--model.name compiler_toolkit.llama3",
2525
"--parallelism.data_parallel_shard_degree 2",
2626
"--parallelism.tensor_parallel_degree 2",
27-
"--activation_checkpoint.mode none",
2827
],
2928
],
3029
"llama3 FSDP+TP",
@@ -37,7 +36,6 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
3736
"--model.name compiler_toolkit.llama3",
3837
"--parallelism.data_parallel_shard_degree 2",
3938
"--parallelism.tensor_parallel_degree 2",
40-
"--activation_checkpoint.mode none",
4139
"--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
4240
"--compile.passes autobucketing_reordering",
4341
],
@@ -46,14 +44,27 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
4644
"llama3_fsdp_tp_autobucketing",
4745
ngpu=4,
4846
),
47+
OverrideDefinitions(
48+
[
49+
[
50+
"--model.name compiler_toolkit.llama3",
51+
"--parallelism.data_parallel_shard_degree 2",
52+
"--parallelism.tensor_parallel_degree 2",
53+
"--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
54+
"--compile.passes transformer_block_bucketing",
55+
],
56+
],
57+
"llama3 FSDP+TP manualbucketing",
58+
"llama3_fsdp_tp_manualbucketing",
59+
ngpu=4,
60+
),
4961
OverrideDefinitions(
5062
[
5163
[
5264
"--model.name compiler_toolkit.llama3",
5365
"--parallelism.data_parallel_shard_degree 2",
5466
"--parallelism.tensor_parallel_degree 2",
5567
"--model.flavor debugmodel_flex_attn",
56-
"--activation_checkpoint.mode none",
5768
],
5869
],
5970
"llama3 FSDP+TP+FlexAttn",
@@ -67,7 +78,6 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
6778
"--parallelism.data_parallel_shard_degree 2",
6879
"--parallelism.tensor_parallel_degree 2",
6980
"--model.flavor debugmodel_flex_attn",
70-
"--activation_checkpoint.mode none",
7181
"--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
7282
"--compile.passes autobucketing_reordering,regional_inductor",
7383
],
@@ -76,6 +86,21 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
7686
"llama3_fsdp_tp_flexattn_autobucketing_regional_inductor",
7787
ngpu=4,
7888
),
89+
OverrideDefinitions(
90+
[
91+
[
92+
"--model.name compiler_toolkit.llama3",
93+
"--parallelism.data_parallel_shard_degree 2",
94+
"--parallelism.tensor_parallel_degree 2",
95+
"--model.flavor debugmodel_flex_attn",
96+
"--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
97+
"--compile.passes transformer_block_bucketing,regional_inductor",
98+
],
99+
],
100+
"llama3 FSDP+TP+FlexAttn manualbucketing regional_inductor",
101+
"llama3_fsdp_tp_flexattn_manualbucketing_regional_inductor",
102+
ngpu=4,
103+
),
79104
# deepseek_v3 tests
80105
OverrideDefinitions(
81106
[

torchtitan/experiments/compiler_toolkit/tests/test_numerics.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,30 @@ def test_llama3_fsdp_tp_autobucketing(self):
4242
ac_mode="selective",
4343
steps=10,
4444
seed=42,
45-
eager_tb_folder="tb/test_llama3_fsdp_tp_eager",
46-
compiled_tb_folder="tb/test_llama3_fsdp_tp_compiled",
45+
eager_tb_folder="tb/test_llama3_fsdp_tp_autobucketing_eager",
46+
compiled_tb_folder="tb/test_llama3_fsdp_tp_autobucketing_compiled",
4747
metrics=["loss_metrics/global_avg_loss", "grad_norm"],
4848
passes="autobucketing_reordering",
4949
)
50+
self.assertTrue(result, "Llama3 FSDP+TP+autobucketing numerics test failed")
51+
52+
def test_llama3_fsdp_tp_manualbucketing(self):
53+
result = run_numerics_test(
54+
ngpu=4,
55+
config_file="./torchtitan/models/llama3/train_configs/debug_model.toml",
56+
dp_shard_degree=2,
57+
tp_degree=2,
58+
cp_degree=1,
59+
ep_degree=1,
60+
ac_mode="selective",
61+
steps=10,
62+
seed=42,
63+
eager_tb_folder="tb/test_llama3_fsdp_tp_manualbucketing_eager",
64+
compiled_tb_folder="tb/test_llama3_fsdp_tp_manualbucketing_compiled",
65+
metrics=["loss_metrics/global_avg_loss", "grad_norm"],
66+
passes="transformer_block_bucketing",
67+
)
68+
self.assertTrue(result, "Llama3 FSDP+TP+manualbucketing numerics test failed")
5069

5170
def test_deepseek_v3_fsdp_tp_ep(self):
5271
"""Test DeepSeek V3 with FSDP + TP + EP configuration."""

0 commit comments

Comments (0)