# Autoparallel as an experiment in main #2054
Changes from 46 commits
Registering the experiment in the experiments overview table:

```diff
@@ -32,3 +32,4 @@ We provide this `experiments/` folder to host experiments that add significant v
 | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) |
 | [compiler_toolkit](./compiler_toolkit/) | [CI](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) |
 | [transformers_backend](./transformers_backend/) | [CI](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) |
+| [auto_parallel](./auto_parallel/) | TBA | [@wconstab](https://github.com/wconstab) [@xmfan](https://github.com/xmfan) |
```
Extending the list of recognized experiment modules:

```diff
@@ -13,5 +13,7 @@
         "compiler_toolkit.deepseek_v3",
         "compiler_toolkit.llama3",
         "transformers_backend",
+        "auto_parallel.llama3",
+        "auto_parallel.deepseek_v3",
     ]
 )
```
New README for the experiment:

```diff
@@ -0,0 +1,19 @@
+## Auto Parallel
+
+### Overview
+
+The Auto Parallel experiment integrates PyTorch's AutoParallel framework with TorchTitan to automatically optimize distributed training parallelism strategies given a device mesh. Instead of manually configuring parallelism layouts, AutoParallel uses cost-based analysis to determine optimal sharding placements for model parameters, activations, and gradients.
+
+### Requirements
+
+Requires installing [AutoParallel](https://github.com/meta-pytorch/autoparallel) (`git@github.com:meta-pytorch/autoparallel.git`).
+
+### Single Node
+
+**Llama3**
+
+`CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config`
+
+**DeepSeekV3**
+
+`CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config`
```
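For context on what the overview means by "sharding placements": below is a minimal, standalone DTensor sketch (not part of this PR; the script name is illustrative) of the manual placement decisions that AutoParallel's cost-based search is meant to make automatically.

```python
# Standalone illustration (not from this PR): hand-picking DTensor
# placements on a device mesh, which is the choice AutoParallel automates.
# Run with: torchrun --nproc_per_node 4 sharding_sketch.py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

mesh = init_device_mesh("cuda", (4,))  # 1-D mesh over 4 GPUs

weight = torch.randn(1024, 1024)
bias = torch.zeros(1024)

# Hand-picked placements: shard the weight rows across the mesh,
# replicate the small bias on every rank.
w_dt = distribute_tensor(weight, mesh, [Shard(0)])
b_dt = distribute_tensor(bias, mesh, [Replicate()])

# Each rank holds a 256 x 1024 shard of the weight.
print(w_dt.placements, w_dt.to_local().shape)
```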
New train spec module for the DeepSeek-V3 variant of the experiment:

```diff
@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+import copy
+
+from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.lr_scheduler import build_lr_schedulers
+from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
+from torchtitan.components.tokenizer import build_hf_tokenizer
+from torchtitan.distributed.pipeline_parallel import pipeline_llm
+from torchtitan.hf_datasets.text_datasets import build_text_dataloader
+
+from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3Model
+from torchtitan.models.deepseek_v3.model.args import DeepSeekV3ModelArgs
+from torchtitan.models.deepseek_v3.model.state_dict_adapter import (
+    DeepSeekV3StateDictAdapter,
+)
+from torchtitan.protocols.train_spec import TrainSpec
+
+from .parallelize_deepseekv3 import parallelize_deepseekv3
+
+
+def get_train_spec() -> TrainSpec:
+    model_args = copy.deepcopy(deepseekv3_args)
+
+    # Configs not named as flex-attn variants fall back to the defaults.
+    default_args = DeepSeekV3ModelArgs()
+    for config, args in model_args.items():
+        if "flex_attn" in config:
+            continue
+        args.use_flex_attn = default_args.use_flex_attn
+        args.attn_mask_type = default_args.attn_mask_type
+
+    return TrainSpec(
+        model_cls=DeepSeekV3Model,
+        model_args=model_args,
+        parallelize_fn=parallelize_deepseekv3,
+        pipelining_fn=pipeline_llm,
+        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
+        build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_text_dataloader,
+        build_tokenizer_fn=build_hf_tokenizer,
+        build_loss_fn=build_cross_entropy_loss,
+        state_dict_adapter=DeepSeekV3StateDictAdapter,
+    )
```
**Review comments:**
> what is this trying to do? also do you really need a function for it?
> `tokens_per_expert_by_layer` is usually a plain tensor, so you need to call `dist.all_reduce` on it for stats; but if it's a DTensor, we just need to redistribute.
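A minimal sketch of the reviewer's suggestion (the helper name is an assumption, not code from this PR):

```python
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Replicate

def globalize_expert_counts(t: torch.Tensor) -> torch.Tensor:
    """Return globally summed tokens-per-expert counts.

    Hypothetical helper covering both cases from the review comment.
    """
    if isinstance(t, DTensor):
        # A DTensor carries its device mesh, so redistributing to a
        # replicated placement performs the communication for us.
        return t.redistribute(placements=[Replicate()]).to_local()
    # Plain tensor: explicit all-reduce over the default process group.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t
```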