7 changes: 6 additions & 1 deletion run_train.sh
@@ -19,13 +19,18 @@ DRY_RUN=${DRY_RUN:-0}

TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}

# need to turn off expandable segments when using cudagraph, since
# expandable segments do not work with CUDA graphs and NCCL yet.
# https://github.com/pytorch/pytorch/issues/158029

Reviewer: Can we turn this off only when using cudagraph?

Contributor Author: Currently it's on by default. When using cudagraph, we need to explicitly turn it off with USE_EXPANDABLE_SEGMENTS=False [other commands].
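
For example, a cudagraph run might look like this (sketch; config path taken from the README example in this PR):

```shell
# Sketch: explicitly disable expandable segments when enabling the cudagraph pass.
USE_EXPANDABLE_SEGMENTS=False NGPU=8 \
CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml \
./run_train.sh --model.name compiler_toolkit.llama3 --compile.passes cudagraph
```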

Reviewer: Instead of turning off expandable segments, you can turn off NCCL memory registration, as the issue suggests.
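
That alternative would keep expandable segments on and instead disable the NCCL allocator-registration hook; a sketch, assuming the TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK environment variable controls this in your PyTorch build (verify before relying on it):

```shell
# Assumption: this env var disables NCCL memory registration of allocator segments.
TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=0 ./run_train.sh --compile.passes cudagraph
```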

USE_EXPANDABLE_SEGMENTS=${USE_EXPANDABLE_SEGMENTS:-True}

if [ "$DRY_RUN" = "1" ]; then
# Dry run mode: validate configuration without GPU/distributed setup
echo "Running in DRY RUN mode - configuration validation only"
python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@"
else
# Normal training with torchrun
PYTORCH_ALLOC_CONF="expandable_segments:True" \
PYTORCH_ALLOC_CONF="expandable_segments:${USE_EXPANDABLE_SEGMENTS}" \
TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
6 changes: 6 additions & 0 deletions torchtitan/experiments/compiler_toolkit/README.md
@@ -55,3 +55,9 @@
```shell
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor
```

**SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor + cudagraph**

```shell
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor,cudagraph
```
9 changes: 9 additions & 0 deletions torchtitan/experiments/compiler_toolkit/common_utils.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from typing import Callable

import torch
from torch.distributed.tensor import DTensor, Replicate
@@ -53,3 +54,11 @@ def register_blockmask_pytree_node():
flatten_with_keys_fn=BlockMask._flatten_with_keys,
serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
)


def end_with_pass(passes: list[Callable], names: list[str]) -> bool:
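    """Return True if the last pass in passes has a __name__ contained in names."""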
    return bool(
        len(passes) > 0
        and (last_pass_name := getattr(passes[-1], "__name__", None))
        and last_pass_name in names
    )
145 changes: 145 additions & 0 deletions torchtitan/experiments/compiler_toolkit/cudagraph.py
@@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
CUDAGraph pass for the compiler toolkit.
This module provides a cudagraph pass that can be applied to graph modules
during compilation.
"""

import warnings
from typing import Any, Callable, Optional, Sequence

import torch
from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager
from torch.utils._ordered_set import OrderedSet


def init_global_graph_pool() -> tuple[
torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream
]:
dummy_graph = torch.cuda.CUDAGraph()

# create a global cudagraph memory pool to allow memory reuse across cudagraphs.
graph_pool = torch.cuda.graph_pool_handle()

    # create a global cuda stream for graph capture. we need to use a single stream
    # for all allocations to the memory pool; otherwise, allocations made on
    # separate streams cannot be reused.
graph_capture_stream = torch.cuda.Stream()

# use a dummy graph to keep the global graph pool alive
with (
# suppress an empty cudagraph warning, since we intentionally create
# an empty cudagraph here
warnings.catch_warnings(record=True),
torch.cuda.graph(
dummy_graph,
pool=graph_pool,
stream=graph_capture_stream,
capture_error_mode="thread_local",
),
):
pass

return dummy_graph, graph_pool, graph_capture_stream


(
_global_dummy_graph,
_global_graph_pool,
_global_graph_capture_stream,
) = init_global_graph_pool()
Comment on lines +52 to +56

Reviewer: Does this work when backward is on a separate stream, or is that not an issue?

Contributor Author (@BoyuanFeng, Nov 19, 2025): IIUC, this is not an issue currently, since fwd and bwd are on the same CUDA stream by default.

cudagraph trees already uses the same graph-capture stream for both fwd and bwd:
https://github.com/pytorch/pytorch/blob/7a928397cda89b71c24b0efe9db6df7fb04a46cb/torch/_inductor/cudagraph_trees.py#L1945



class CUDAGraphWrapper:
def __init__(
self,
runnable: Callable,
example_inputs: Sequence[Any],
static_input_indices: Optional[tuple[int]] = None,
):
self.runnable = runnable
self.graph_pool = _global_graph_pool
self.stream = _global_graph_capture_stream
self.static_input_indices = OrderedSet(
static_input_indices if static_input_indices is not None else []
)
self.input_indices_to_copy = [
i
for i, inp in enumerate(example_inputs)
if isinstance(inp, torch.Tensor) and i not in self.static_input_indices
]
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
self.has_warmup = False

self.args = None
self.output = None

def copy_static_inputs(self, *args):

Reviewer: If any of the static inputs changes, you'll get silent incorrectness. You might consider at least a config to check this.

Contributor Author: Yes, added.
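
A sketch of what that check could look like (flag name hypothetical; the PR may implement it differently):

```python
# Hypothetical debug check: before replaying, verify that static inputs still
# live at the addresses recorded at capture time; a moved static input would
# otherwise make replay silently read stale memory.
if self.debug_check_static_inputs:  # hypothetical config flag
    for i in self.static_input_indices:
        assert args[i].data_ptr() == self.input_addresses[i], (
            f"static input {i} moved since cudagraph capture"
        )
```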

for i in self.input_indices_to_copy:
self.args[i].copy_(args[i])
Contributor Author: We could replace this for loop with a foreach copy. However, I empirically observed there is only 1 tensor to copy for fwd and 1 for bwd, so there is no need to add code complexity here.
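
For reference, the foreach variant would be a small change (sketch; torch._foreach_copy_ is a private API):

```python
# Sketch: batch all non-static input copies into one foreach call
# instead of a Python loop over individual copy_ calls.
dsts = [self.args[i] for i in self.input_indices_to_copy]
srcs = [args[i] for i in self.input_indices_to_copy]
torch._foreach_copy_(dsts, srcs)
```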


def __call__(self, *args):
if not self.has_warmup:
self.has_warmup = True
device = torch.cuda.current_device()

# warmup in cudagraph memory pool to avoid fragmentation
# across eager memory pool and cudagraph memory pool.
with _use_cuda_memory_pool_manager(device, self.graph_pool, self.stream):
out = self.runnable(*args)
return out

if self.cudagraph is None:
self.args = args
input_addresses = [
x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args
]

Reviewer: I guess we're assuming that the non-tensor inputs are the same every time? Should we just assert they're all tensors if we're not handling the other cases?

Contributor Author: IIUC, there would only be tensors and symints (for the MoE layer). Let me add an assertion.

Contributor Author: There are also rng_state (torch._C.Generator) inputs, used by:

graphsafe_run_with_rng_state_2 = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten._scaled_dot_product_flash_attention.default, transpose_20, transpose_21, transpose_22, 0.0, True, scale = 0.25, rng_state = fwd_rng_state_2);

See the last 3 args in P2047035404.
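
A minimal version of that assertion might be (sketch; the accepted types are the ones named in this thread, with symints arriving as Python ints):

```python
# Sketch: restrict cudagraph inputs to tensors, symints, and RNG-state generators.
for inp in args:
    assert isinstance(
        inp, (torch.Tensor, int, torch._C.Generator)
    ), f"unsupported cudagraph input type: {type(inp)}"
```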

self.input_addresses = input_addresses

self.cudagraph = torch.cuda.CUDAGraph()

with torch.cuda.graph(
self.cudagraph, pool=self.graph_pool, stream=self.stream
):
# `output` is managed by pytorch's cudagraph pool
self.output = self.runnable(*args)
Contributor Author: We could potentially use a weakref for the output tensor to reduce memory. Will do in a follow-up PR.


self.copy_static_inputs(*args)
self.cudagraph.replay()
return self.output

Reviewer: The persistent input and output is not good for memory, as you've commented.

Contributor Author: Yes, will add in the next PR.



def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]:
"""
    Get indices of gm inputs that are static input tensors, i.e., tensors whose
    addresses do not change across runs. Examples of static input tensors include
    weights, buffers, and outputs of previously cudagraph-wrapped functions.
"""
from torch._inductor.utils import count_tangents

static_input_indices = []
if (
is_forward
and (tracing_context := torch._guards.TracingContext.try_get())
and hasattr(tracing_context, "fw_metadata")
):
# for forward, we rely on graph capture (i.e., dynamo or export) to provide
# the correct static input indices stored in tracing context. Typical examples
# include weights and buffers.
static_input_indices = tracing_context.fw_metadata.static_input_indices

elif not is_forward:
# for backward, we identify saved tensors as static inputs, since saved tensors
# are outputs of cudagraph-wrapped forward run. In PT2-generated backward gm,
# saved tensors are always the leading args. So we can get the number of saved
# tensors and generate static input indices.
fixed = count_tangents(gm)
static_input_indices = list(range(fixed))

return static_input_indices
59 changes: 50 additions & 9 deletions torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -20,6 +20,7 @@
from torch.distributed.tensor import DTensor
from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.experiments.compiler_toolkit.common_utils import end_with_pass
from torchtitan.tools.logging import logger


@@ -217,6 +218,7 @@ def compiler(
example_inputs,
passes: List[Callable] = None,
dump_folder: str | None = None,
is_forward: bool = True,
):
"""
Compile a graph module by applying a sequence of compiler passes.
@@ -239,6 +241,17 @@
)
_dump_gm(dump_folder, gm, f"{name}_before_compiler")

if end_with_pass(passes, ["cudagraph_pass"]):
# cudagraph pass is always the last pass if it is applied
cg_pass = passes[-1]

        # to identify static input indices, the cudagraph pass behaves differently
        # for the forward and backward passes, so we explicitly pass that info.
_cg_pass = functools.partial(cg_pass, is_forward=is_forward)

# keep the function name for debug log
passes[-1] = functools.wraps(cg_pass)(_cg_pass)

for pass_fn in passes:
pass_name = (
pass_fn.func.__name__
@@ -271,17 +284,42 @@ def make_compiler_with_passes(

def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
return compiler(
"fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
"fwd_gm",
gm,
example_inputs,
passes=passes,
dump_folder=dump_folder,
is_forward=True,
)

def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
return compiler(
"bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
"bwd_gm",
gm,
example_inputs,
passes=passes,
dump_folder=dump_folder,
is_forward=False,
)

return fw_compiler, bw_compiler


def validate_pass_names(pass_names: list[str]) -> None:
if "cudagraph" in pass_names:
assert (
pass_names[-1] == "cudagraph"
), "cudagraph has to be the last pass to apply"

if (
"autobucketing_reordering" in pass_names
and "transformer_block_bucketing" in pass_names
):
raise ValueError(
"Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!"
)
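
For instance (illustrative calls):

```python
# Valid: cudagraph is the last pass.
validate_pass_names(["transformer_block_bucketing", "regional_inductor", "cudagraph"])

# AssertionError: cudagraph has to be the last pass to apply.
validate_pass_names(["cudagraph", "regional_inductor"])
```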


def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig):
"""
Extract and validate compiler passes from job config.
@@ -298,16 +336,15 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig):
)

pass_names = getattr(job_config.compile, "passes", [])
-    if (
-        "autobucketing_reordering" in pass_names
-        and "transformer_block_bucketing" in pass_names
-    ):
-        raise ValueError(
-            "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!"
-        )
validate_pass_names(pass_names)
compiler_passes = []

use_cudagraph = "cudagraph" in pass_names

for pass_name in pass_names:
if pass_name == "cudagraph":
continue

if pass_name not in AVAILABLE_COMPILER_PASSES:
raise ValueError(
f"Unknown compiler pass: {pass_name}. "
Expand All @@ -323,6 +360,10 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi
else:
compiler_passes.append(AVAILABLE_COMPILER_PASSES[pass_name])

if use_cudagraph:
# cudagraph should always be the last fx pass to apply
compiler_passes.append(AVAILABLE_COMPILER_PASSES["cudagraph"])

if pass_names:
logger.info(f"Using compiler passes from config: {pass_names}")

24 changes: 24 additions & 0 deletions torchtitan/experiments/compiler_toolkit/passes.py
@@ -11,10 +11,16 @@
during compilation. Passes can be selected and configured via job config.
"""

from typing import Any, Sequence

import torch
from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
from torch.fx.passes.regional_inductor import regional_inductor
from torchtitan.experiments.compiler_toolkit.cudagraph import (
CUDAGraphWrapper,
get_static_input_indices,
)
from torchtitan.experiments.simple_fsdp.reshard_after_forward import (
annotate_fsdp_all_gather,
)
@@ -56,6 +62,23 @@ def regional_inductor_pass(
return regional_inductor(gm, example_inputs)


def cudagraph_pass(
gm: torch.fx.GraphModule, example_inputs: Sequence[Any], is_forward: bool
) -> torch.fx.GraphModule:
"""
Apply cudagraph.

    This pass wraps the forward function with cudagraph during compilation and does
    not record the cudagraph until runtime.
    - On the first run, it warms up operators such as NCCL collectives.
    - On the second run, it records the cudagraph and then replays it.
    - On subsequent runs, it replays the cudagraph.
"""
static_input_indices = get_static_input_indices(gm, is_forward)
gm.forward = CUDAGraphWrapper(gm.forward, example_inputs, static_input_indices)
return gm
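
The lifecycle described in the docstring can be illustrated with a small driver (sketch; fn and example_inputs are placeholders):

```python
# Illustrative only: three calls exercise the warmup/capture/replay phases.
wrapped = CUDAGraphWrapper(fn, example_inputs, static_input_indices=(0,))
wrapped(*example_inputs)  # run 1: eager warmup inside the cudagraph memory pool
wrapped(*example_inputs)  # run 2: capture the cudagraph, then replay it
wrapped(*example_inputs)  # run 3+: copy non-static inputs in place and replay
```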


def validate_flex_attn_annotation_pass(
gm: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
@@ -88,4 +111,5 @@
"autobucketing_reordering": autobucketing_reordering_pass,
"transformer_block_bucketing": transformer_block_bucketing_reordering_pass,
"regional_inductor": regional_inductor_pass,
"cudagraph": cudagraph_pass,
}
29 changes: 29 additions & 0 deletions torchtitan/experiments/compiler_toolkit/tests/integration_tests.py
@@ -58,6 +58,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
"llama3_fsdp_tp_manualbucketing",
ngpu=4,
),
OverrideDefinitions(
[
[
"--model.name compiler_toolkit.llama3",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
"--compile.passes cudagraph",
],
],
"llama3 FSDP+TP+cudagraph",
"llama3_fsdp_tp_cudagraph",
ngpu=4,
),
OverrideDefinitions(
[
[
@@ -86,6 +100,21 @@
"llama3_fsdp_tp_flexattn_autobucketing_regional_inductor",
ngpu=4,
),
OverrideDefinitions(
[
[
"--model.name compiler_toolkit.llama3",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--model.flavor debugmodel_flex_attn",
"--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
"--compile.passes autobucketing_reordering,regional_inductor,cudagraph",
],
],
"llama3 FSDP+TP+FlexAttn autobucketing regional_inductor+cudagraph",
"llama3_fsdp_tp_flexattn_autobucketing_regional_inductor_cudagraph",
ngpu=4,
),
OverrideDefinitions(
[
[