Autoparallel as an experiment in main #2054
Merged
Commits (54):
- 3ccd12c [WIP] Integrate autoparallel into torchtitan (wconstab)
- e6d2caf Autoparallel support for DP-only, DP+TP, or TP-only (wconstab)
- 68476b3 Update CLI inductor configs for bucketing/reordering (wconstab)
- 9ee9f75 add back llama3_autoparallel_init_fn (wconstab)
- f6e4099 Track API change from new AOTAutograd interface (ezyang)
- 4d7ee8a Support forcing the model into bf16 for perf debugging (wconstab)
- b801d0b Integrate MixedPrecision with AutoParallel and fix example_inputs (wconstab)
- b099cf9 Use in-place compile API (ezyang)
- b3587d9 Fix bucketing pass configs (wconstab)
- 42c2c07 Support both eager and autoparallel init based on model.name (wconstab)
- d93845e Remove llama3 init weights hack (wconstab)
- 60f5f11 Print profiling manifold url (wconstab)
- 6c782eb Support new compile API from autoparallel PR #77 (wconstab)
- 4712163 Fix bucket sizes for AutoParallel 1D (#1545) (fmassa)
- 3f04d22 Add support for loss parallel (#1546) (fmassa)
- 8e50870 Add config for running simple-fsdp bucketing/reordering passes (wconstab)
- 91c5639 Hook up deepseekv3_auto_parallel (wconstab)
- 1233902 [dsv3] patch graph break fix, works up until sharding rules (xmfan)
- 4f8677b update simplefsdp pass config (ruisizhang123)
- 714cc5b [dsv3] disable MoE while we fix local_map, works up until optimizer (xmfan)
- 45647b3 Merge branch 'main' into whc/merge_autoparallel (wconstab)
- bfa9f7f tweak ds3 model.py to reflect main branch for DS3 baseline can run (#…) (bdhirsh)
- 75fb2eb add simplefsdp's autobucketing pass entry (#1658) (ruisizhang123)
- 8769396 [dsv3] 1D AP w/ local_map (xmfan)
- db22479 [dsv3] Turn off Flex for AP (xmfan)
- 87ef4e0 Merge branch 'main' into autoparallel (xmfan)
- 9dc0bd8 Update to new model registration API (xmfan)
- c6e25bd Whc/knobs (#1994) (wconstab)
- 26410e8 Merge remote-tracking branch 'origin/main' into autoparallel (xmfan)
- e6ea814 lint (xmfan)
- 7abede8 undo moe patching (xmfan)
- d2e76b7 move inductor config into experiment folders (xmfan)
- 472b4ad fix local_map moe patch (xmfan)
- ac0def9 move flex disables into experiment folder (xmfan)
- a24ef07 fix newline (xmfan)
- da611e4 no longer necessary train.py changes (xmfan)
- 6cc8caa restore comment (xmfan)
- d54a6d4 temporarily extend hacky optimizer stuff to make dsv3 ap 1d run again (xmfan)
- acd9588 Merge remote-tracking branch 'origin/main' into autoparallel (xmfan)
- 2b1fb92 fix moduledict with AP https://github.com/meta-pytorch/autoparallel/p… (xmfan)
- 68245d6 fix moe_enabled (xmfan)
- e592e22 lint (xmfan)
- 737ad2c job config (xmfan)
- 64e6050 remove MAST specific profiling logs (xmfan)
- de6dca6 update readme (xmfan)
- fe0b6cc format readme (xmfan)
- 2b37f30 comments (xmfan)
- bc18d87 manual redistribute (xmfan)
- 5fdf737 imports (xmfan)
- c1a307f mesh (xmfan)
- 58d5349 Merge remote-tracking branch 'origin/main' into autoparallel (xmfan)
- c480cd1 no flex (xmfan)
- aa739f6 update with new moe (xmfan)
- f03fe9e remove transformers_backend (xmfan)
Changed file (experiment registry):

```diff
@@ -12,6 +12,9 @@
         "vlm",
         "compiler_toolkit.deepseek_v3",
         "compiler_toolkit.llama3",
+        "transformers_backend",
         "transformers_modeling_backend",
+        "auto_parallel.llama3",
+        "auto_parallel.deepseek_v3",
     ]
 )
```
New file (19 additions): auto_parallel experiment README

## Auto Parallel

### Overview

The Auto Parallel experiment integrates PyTorch's AutoParallel framework with TorchTitan to automatically optimize distributed training parallelism strategies given a device mesh. Instead of manually configuring parallelism layouts, AutoParallel uses cost-based analysis to determine optimal sharding placements for model parameters, activations, and gradients.

### Requirements

Requires installing [meta-pytorch/autoparallel](https://github.com/meta-pytorch/autoparallel).

### Single Node

**Llama3**

`CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config`

**DeepSeekv3**

`CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config`
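Both launch commands pass dotted module paths (`auto_parallel.deepseek_v3`, `torchtitan.experiments.auto_parallel.job_config`) on the command line. As a hedged illustration of the general mechanism (not torchtitan's actual loader), dotted names like these can be resolved at runtime with `importlib`:

```python
import importlib


def resolve_dotted(path: str):
    """Illustrative helper: import an object given a dotted path.

    Tries the whole path as a module first; if that fails, imports
    the parent module and looks the last segment up as an attribute.
    """
    try:
        return importlib.import_module(path)
    except ImportError:
        module_path, _, attr = path.rpartition(".")
        module = importlib.import_module(module_path)
        return getattr(module, attr)


# Resolve a stdlib module and a module attribute the same way a
# dotted config name from the CLI would be resolved.
mod = resolve_dotted("json.decoder")
print(mod.__name__)  # json.decoder
fn = resolve_dotted("os.getcwd")
print(callable(fn))  # True
```

The point of the pattern is that a plain string from a config file or CLI flag is enough to locate the code to load, which is what makes pluggable experiment modules like `--job.custom_config_module` possible.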
New file (50 additions, 0 deletions): torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import copy

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.distributed.pipeline_parallel import pipeline_llm
from torchtitan.hf_datasets.text_datasets import build_text_dataloader

from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3Model
from torchtitan.models.deepseek_v3.model.args import DeepSeekV3ModelArgs
from torchtitan.models.deepseek_v3.model.state_dict_adapter import (
    DeepSeekV3StateDictAdapter,
)
from torchtitan.protocols.train_spec import TrainSpec

from .parallelize_deepseekv3 import parallelize_deepseekv3


def get_train_spec() -> TrainSpec:
    model_args = copy.deepcopy(deepseekv3_args)

    # Skip flex-attention configs; reset all other configs to the
    # default attention settings.
    default_args = DeepSeekV3ModelArgs()
    for config, args in model_args.items():
        if "flex_attn" in config:
            continue

        args.attn_type = default_args.attn_type
        args.attn_mask_type = default_args.attn_mask_type

    return TrainSpec(
        model_cls=DeepSeekV3Model,
        model_args=model_args,
        parallelize_fn=parallelize_deepseekv3,
        pipelining_fn=pipeline_llm,
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_text_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        state_dict_adapter=DeepSeekV3StateDictAdapter,
    )
```
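The loop in `get_train_spec` deep-copies the shared arg presets before patching them, so the global `deepseekv3_args` table is never mutated. A minimal self-contained sketch of that deepcopy-and-patch pattern, using a hypothetical stand-in dataclass rather than the real `DeepSeekV3ModelArgs`:

```python
import copy
from dataclasses import dataclass


@dataclass
class Args:  # hypothetical stand-in for DeepSeekV3ModelArgs
    attn_type: str = "sdpa"
    attn_mask_type: str = "causal"


# Presets keyed by config name, mirroring the shape of deepseekv3_args.
presets = {
    "debug_model": Args(attn_type="flex", attn_mask_type="block_causal"),
    "debug_model_flex_attn": Args(attn_type="flex", attn_mask_type="block_causal"),
}

# Deep-copy first so patching never leaks into the shared presets.
model_args = copy.deepcopy(presets)
defaults = Args()
for name, args in model_args.items():
    if "flex_attn" in name:
        continue  # configs that opt into flex attention keep their settings
    args.attn_type = defaults.attn_type
    args.attn_mask_type = defaults.attn_mask_type

print(model_args["debug_model"].attn_type)            # sdpa
print(model_args["debug_model_flex_attn"].attn_type)  # flex
print(presets["debug_model"].attn_type)               # flex (originals untouched)
```

The `copy.deepcopy` is the load-bearing step: without it, patching one experiment's args would silently change the defaults seen by every other train spec that reads the same preset table.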
Review comment: this has been renamed to the item below, may need to rebase / remove