
Commit 7e10d60

Authored by @xmfan, @wconstab, @ezyang, @fmassa, and @ruisizhang123
Autoparallel as an experiment in main (#2054)
Experiments like SimpleFSDP, Compiler Toolkit, and Autoparallel are all being developed at the same time, and SimpleFSDP and Compiler Toolkit both run into issues with PP that require the PP utilities from Autoparallel. We want to land the Autoparallel experiment into main to facilitate that sharing.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Will Constable <whc@meta.com>
Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
Co-authored-by: ruisizhang123 <ruisizhang123@gmail.com>
Co-authored-by: Brian Hirsh <briandhirsh@gmail.com>
Co-authored-by: Will Constable <willconstable@gmail.com>
1 parent 7e1edb6 commit 7e10d60

File tree: 9 files changed (+740, -5 lines)


torchtitan/components/optimizer.py (14 additions, 5 deletions)

```diff
@@ -16,6 +16,7 @@
     StateDictOptions,
 )
 from torch.distributed.checkpoint.stateful import Stateful
+from torch.distributed.tensor import Replicate
 from torch.optim import Optimizer
 
 from torchtitan.components.ft import FTManager, has_torchft
@@ -380,11 +381,19 @@ def _update_expert_bias(
     tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)
 
     if dp_cp_mesh is not None:
-        # Perform single all-reduce to get global statistics across all processes
-        pg = dp_cp_mesh.get_group()
-        torch.distributed.all_reduce(
-            tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
-        )
+        if isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor):
+            tokens_per_expert_by_layer = tokens_per_expert_by_layer.redistribute(
+                placements=[Replicate()]
+                * tokens_per_expert_by_layer.device_mesh.ndim
+            )
+        else:
+            # Perform single all-reduce to get global statistics across all processes
+            pg = dp_cp_mesh.get_group()
+            torch.distributed.all_reduce(
+                tokens_per_expert_by_layer,
+                group=pg,
+                op=torch.distributed.ReduceOp.SUM,
+            )
 
     moe_layer_idx = 0
     with torch.no_grad():
```
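Why the new branch can replace the explicit all-reduce: under AutoParallel the expert-count statistics arrive as a `DTensor`, and redistributing every mesh dimension to `Replicate()` runs whatever collectives are needed (typically an all-reduce sum when a placement is `Partial()`) so that every rank ends up with the global counts. A minimal runnable sketch of that behavior, assuming PyTorch 2.5+ where `torch.distributed.tensor` is public; everything below is illustrative, not from this commit:

```python
# Run with: torchrun --nproc-per-node 2 dtensor_sketch.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Each rank holds a partial sum, e.g. per-rank token counts for
# 2 MoE layers x 4 experts (values are made up for the sketch).
local_counts = torch.full((2, 4), float(dist.get_rank() + 1))

# Partial() declares the logical tensor to be the sum of the local shards.
partial = DTensor.from_local(local_counts, mesh, [Partial()])

# Redistributing every mesh dim to Replicate() performs the reduction,
# matching the global sum the non-DTensor branch gets from all_reduce.
replicated = partial.redistribute(placements=[Replicate()] * mesh.ndim)
print(replicated.to_local())  # identical summed counts on every rank

dist.destroy_process_group()
```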

torchtitan/experiments/README.md (1 addition, 0 deletions)

```diff
@@ -32,3 +32,4 @@ We provide this `experiments/` folder to host experiments that add significant v
 | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) |
 | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) |
 | [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) |
+| [auto_parallel](./auto_parallel/) | TBA | [@wconstab](https://github.com/wconstab) [@xmfan](https://github.com/xmfan) |
```

torchtitan/experiments/__init__.py (2 additions, 0 deletions)

```diff
@@ -13,5 +13,7 @@
         "compiler_toolkit.deepseek_v3",
         "compiler_toolkit.llama3",
         "transformers_modeling_backend",
+        "auto_parallel.llama3",
+        "auto_parallel.deepseek_v3",
     ]
 )
```
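These registered names are what `--model.name` selects; presumably torchtitan resolves them to packages under `torchtitan.experiments` via a dynamic import. A minimal sketch of that pattern (the loader function is an assumption for illustration, not part of the commit; only the two added entries come from the diff):

```python
# Hypothetical loader sketch: map an experiment name such as
# "auto_parallel.llama3" to its package under torchtitan.experiments.
import importlib

_supported_experiments = frozenset(
    [
        "auto_parallel.llama3",
        "auto_parallel.deepseek_v3",
    ]
)


def load_experiment(name: str):
    """Import torchtitan.experiments.<name> so its train spec is available."""
    if name not in _supported_experiments:
        raise ValueError(f"unsupported experiment: {name}")
    return importlib.import_module(f"torchtitan.experiments.{name}")
```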
torchtitan/experiments/auto_parallel/README.md (19 additions, new file)

## Auto Parallel

### Overview

The Auto Parallel experiment integrates PyTorch's AutoParallel framework with TorchTitan to automatically optimize distributed training parallelism strategies given a device mesh. Instead of manually configuring parallelism layouts, AutoParallel uses cost-based analysis to determine optimal sharding placements for model parameters, activations, and gradients.

### Requirements

Requires installing [meta-pytorch/autoparallel](https://github.com/meta-pytorch/autoparallel).

### Single Node

**Llama3**

`CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config`

**DeepSeekV3**

`CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config`
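For intuition about what the Llama3 command sets up: with `--parallelism.tensor_parallel_degree 4` on a single 8-GPU node, the device mesh AutoParallel optimizes over would be 2D. A short illustrative sketch (the dim names, sizes, and the 8-GPU assumption are mine, not from the commit):

```python
# Illustrative only: the kind of 2D mesh implied by tensor_parallel_degree=4
# on an assumed 8-GPU node (leaving a data-parallel degree of 2).
# Run under torchrun with 8 processes.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
print(mesh["tp"].size())  # 4 ranks per tensor-parallel group
```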
torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py (50 additions, new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import copy

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.distributed.pipeline_parallel import pipeline_llm
from torchtitan.hf_datasets.text_datasets import build_text_dataloader

from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3Model
from torchtitan.models.deepseek_v3.model.args import DeepSeekV3ModelArgs
from torchtitan.models.deepseek_v3.model.state_dict_adapter import (
    DeepSeekV3StateDictAdapter,
)
from torchtitan.protocols.train_spec import TrainSpec

from .parallelize_deepseekv3 import parallelize_deepseekv3


def get_train_spec() -> TrainSpec:
    model_args = copy.deepcopy(deepseekv3_args)

    # Skip flex-attention configs; reset every other config to the
    # default attention settings.
    default_args = DeepSeekV3ModelArgs()
    for config, args in model_args.items():
        if "flex_attn" in config:
            continue

        args.attn_type = default_args.attn_type
        args.attn_mask_type = default_args.attn_mask_type

    return TrainSpec(
        model_cls=DeepSeekV3Model,
        model_args=model_args,
        parallelize_fn=parallelize_deepseekv3,
        pipelining_fn=pipeline_llm,
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_text_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        state_dict_adapter=DeepSeekV3StateDictAdapter,
    )
```
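A quick way to sanity-check the spec this module builds (the import path assumes the file lands at `torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py`, as inferred above; run inside a torchtitan checkout with autoparallel installed):

```python
# Sketch: call get_train_spec() and inspect a few fields defined in the file above.
from torchtitan.experiments.auto_parallel.deepseek_v3 import get_train_spec

spec = get_train_spec()
print(spec.model_cls.__name__)         # DeepSeekV3Model
print(sorted(spec.model_args))         # the deepseekv3_args config names
print(spec.parallelize_fn.__name__)    # parallelize_deepseekv3
```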
