
Commit fa51dea

DualPipeV Fw-Bw Overlapping pass with User Annotations
stack-info: PR: #261, branch: sanketpurandare/stack/2
1 parent 10d8208 commit fa51dea

4 files changed: +173 -89 lines changed

autoparallel/_passes/graph_multiplex.py

Lines changed: 118 additions & 38 deletions
@@ -4,13 +4,44 @@
 # LICENSE file in the root directory of this source tree.

 import copy
+from itertools import dropwhile

 import torch
 import torch.fx as fx
+from torch._inductor.fx_passes.bucketing import is_wait_tensor
+from torch._logging import trace_structured
+
+
+def _add_compute_annotations(gm: fx.GraphModule, tag: str):
+    """Add compute_region annotations to nodes without custom metadata."""
+    for n in gm.graph.nodes:
+        if n.op == "placeholder":
+            continue
+        if n.meta.get("custom", None) is None:
+            n.meta["custom"] = {"compute_region": tag}
+        else:
+            assert "comm_region" in n.meta["custom"]
+            val = n.meta["custom"]["comm_region"]
+            n.meta["custom"]["comm_region"] = tag + " " + val
+
+
+def _move_wait_tensors_to_compute_region(gm: fx.GraphModule, tag: str):
+    """Move wait_tensor nodes from comm_region to compute_region of their users."""
+    for n in gm.graph.nodes:
+        if n.op == "placeholder":
+            continue
+        if "comm_region" in n.meta["custom"] and is_wait_tensor(n):
+            assert len(n.users) >= 1, "wait tensor must have at least one user"
+            user: fx.Node = next(iter(n.users))
+            if "compute_region" in user.meta["custom"]:
+                n.meta["custom"].pop("comm_region")
+                n.meta["custom"].update({"compute_region": tag + " " + "wait"})
+                if n.next is not user:
+                    user.prepend(n)


 def multiplex_fw_bw_graph(
-    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule
+    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule, overlap_with_annotations: bool = True
 ) -> fx.GraphModule:
     """
     Multiplexes forward and backward graphs into a single unified graph module.
@@ -32,67 +63,116 @@ def multiplex_fw_bw_graph(
     Note:
         The function preserves node metadata during the merging process.
     """
-    # Mapping to track correspondence between backward graph nodes and new nodes
+    if overlap_with_annotations:
+        _add_compute_annotations(fw_gm, "forward")
+        _add_compute_annotations(bw_gm, "backward")
+        _move_wait_tensors_to_compute_region(fw_gm, "forward")
+        _move_wait_tensors_to_compute_region(bw_gm, "backward")
+
+    # Mapping to track correspondence between forward graph nodes and new nodes
     old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {}

-    # Start with a deep copy of the forward graph as the base
-    multiplexed_gm = copy.deepcopy(fw_gm)
+    # Start with a deep copy of the backward graph as the base
+    multiplexed_gm = copy.deepcopy(bw_gm)

-    # Collect all placeholder nodes from the backward graph
+    # Collect all placeholder nodes from all the graphs
     bw_placeholders = bw_gm.graph.find_nodes(op="placeholder")
     fw_placeholders = fw_gm.graph.find_nodes(op="placeholder")
+    insert_point = multiplexed_gm.graph.find_nodes(op="placeholder")[-1]

-    # Insert backward placeholders at the beginning of the multiplexed graph
-    # Reversed order ensures correct execution sequence
-    with multiplexed_gm.graph.inserting_before():
-        for n in reversed(bw_placeholders):
+    # Insert forward placeholders after the backward placeholders of the multiplexed graph
+    for n in fw_placeholders:
+        with multiplexed_gm.graph.inserting_after(insert_point):
             new_placeholder = multiplexed_gm.graph.placeholder(n.name)
-            new_placeholder.meta = n.meta
+            new_placeholder.meta = copy.copy(n.meta)
             new_placeholder.target = new_placeholder.name
             old_node_to_new_node[n] = new_placeholder
+            insert_point = new_placeholder

-    # Find the last placeholder and the output node in the multiplexed graph
-    multiplxed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
-    assert len(multiplxed_gm_placeholders) == (
-        len(fw_placeholders) + len(bw_placeholders)
+    multiplexed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
+    assert len(multiplexed_gm_placeholders) == len(fw_placeholders) + len(
+        bw_placeholders
     )
-    insert_point = multiplxed_gm_placeholders[-1]
-
-    # Copy all computation nodes from backward graph into multiplexed graph
-    fw_outputs = fw_gm.graph.find_nodes(op="output")
-    bw_outputs = bw_gm.graph.find_nodes(op="output")
-    assert len(bw_outputs) == 1 and len(fw_outputs) == 1
-    bw_graph_op_node = bw_outputs[0]
-    for n in bw_gm.graph.nodes:
-        if n.op == "placeholder":
-            continue
-        if n.op == "output":
-            continue
-        with multiplexed_gm.graph.inserting_after(insert_point):
+    fw_nodes_iter = iter(fw_gm.graph.nodes)
+    fw_nodes_iter = dropwhile(lambda n: n.op == "placeholder", fw_nodes_iter)
+    # Initialize the forward node to be the first non-placeholder node
+    fn = next(fw_nodes_iter)
+    if overlap_with_annotations:
+        # Interleave forward and backward nodes to create overlap pattern:
+        # bw_compute (if any) -> bw_comm -> fw_compute (if any) -> fw_comm -> [repeat]
+        # This allows bw_comm to overlap with fw_compute, and fw_comm to overlap with bw_compute
+        bw_in_comm = False
+        for bn in multiplexed_gm.graph.nodes:
+            if bn.op == "placeholder" or bn.op == "output":
+                continue
+            # Track when we enter a backward comm region
+            if "comm_region" in bn.meta["custom"] and not bw_in_comm:
+                bw_in_comm = True
+            # When we transition from bw_comm to bw_compute, insert forward nodes
+            elif "compute_region" in bn.meta["custom"] and bw_in_comm:
+                bw_in_comm = False
+                fw_in_comm = False
+                insert_point = bn
+                # Insert forward nodes before this bw_compute node
+                # Note: We cannot reorder nodes within a graph, only their relative order between graphs
+                while fn.op != "output":
+                    if "comm_region" in fn.meta["custom"] and not fw_in_comm:
+                        fw_in_comm = True
+                    elif "compute_region" in fn.meta["custom"] and fw_in_comm:
+                        # Stop when we reach the next fw_compute after fw_comm
+                        # This ensures we insert one fw_compute + fw_comm cycle per bw_comm -> bw_compute transition
+                        # If fw starts with comm (no compute before it), we still insert it to overlap with future bw_compute
+                        fw_in_comm = False
+                        break
+                    with multiplexed_gm.graph.inserting_before(insert_point):
+                        # Copy node and remap its arguments using the node mapping
+                        new_node = multiplexed_gm.graph.node_copy(
+                            fn, lambda x: old_node_to_new_node[x]
+                        )
+                        new_node.meta = copy.copy(fn.meta)
+                        old_node_to_new_node[fn] = new_node
+                    fn = next(fw_nodes_iter)
+    # Insert any remaining forward nodes at the end
+    # If overlap_with_annotations is False, this concatenates all fw nodes after bw nodes
+    insert_point = multiplexed_gm.graph.find_nodes(op="output")[-1]
+    while fn.op != "output":
+        with multiplexed_gm.graph.inserting_before(insert_point):
             # Copy node and remap its arguments using the node mapping
             new_node = multiplexed_gm.graph.node_copy(
-                n, lambda x: old_node_to_new_node[x]
+                fn, lambda x: old_node_to_new_node[x]
             )
-            new_node.meta = n.meta
-            old_node_to_new_node[n] = new_node
-            insert_point = new_node
+            new_node.meta = copy.copy(fn.meta)
+            old_node_to_new_node[fn] = new_node
+        fn = next(fw_nodes_iter)

-    # Collect output arguments from backward graph, remapping to new nodes
-    bw_op_node_args = [
+    # Collect output arguments from forward graph, remapping to new nodes
+    fw_outputs = fw_gm.graph.find_nodes(op="output")
+    multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
+    assert len(multiplexed_graph_outputs) == 1 and len(fw_outputs) == 1
+    fw_graph_op_node = fw_outputs[0]
+    fw_op_node_args = [
         old_node_to_new_node[n] if n is not None else None
-        for n in bw_graph_op_node.args[0]
+        for n in fw_graph_op_node.args[0]
     ]

-    # Collect output arguments from multiplexed graph (will contain only fwd_outs)
-    multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
-    assert len(multiplexed_graph_outputs) == 1
+    # Collect output arguments from multiplexed graph (will contain only bwd_outs)
     multiplexed_graph_op_node = multiplexed_graph_outputs[0]
-    fw_op_node_args = list(multiplexed_graph_op_node.args[0])
+    bw_op_node_args = list(multiplexed_graph_op_node.args[0])

     # Update output node args to prepend backward outputs before forward outputs
     multiplexed_graph_op_node.args = (tuple(bw_op_node_args + fw_op_node_args),)

     multiplexed_gm.graph.eliminate_dead_code()
     multiplexed_gm.graph.lint()
     multiplexed_gm.recompile()
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "autoparallel_multiplexed_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: multiplexed_gm.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        ),
+    )
     return multiplexed_gm
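
For orientation, here is a minimal sketch (not part of this commit) of what the new annotation helpers do to node metadata. It assumes _add_compute_annotations is importable from autoparallel._passes.graph_multiplex as introduced above; the toy module and the hand-applied comm_region tag are illustrative stand-ins for a real traced collective.

import operator

import torch
import torch.fx as fx

from autoparallel._passes.graph_multiplex import _add_compute_annotations


class Toy(torch.nn.Module):
    def forward(self, x):
        y = x + 1  # stand-in for forward compute
        z = y * 2  # stand-in for a collective; tagged by hand below
        return z - 3


gm = fx.symbolic_trace(Toy())

# Pretend the multiply was captured inside
# fx_traceback.annotate({"comm_region": "token_dispatch"}).
for n in gm.graph.nodes:
    if n.op == "call_function" and n.target is operator.mul:
        n.meta["custom"] = {"comm_region": "token_dispatch"}

# Untagged non-placeholder nodes become {"compute_region": "forward"};
# the tagged node's comm_region value is prefixed, giving "forward token_dispatch".
_add_compute_annotations(gm, "forward")
for n in gm.graph.nodes:
    if n.op != "placeholder":
        print(n.name, n.meta.get("custom"))

Once both graphs carry these tags, the interleaving loop above alternates backward comm regions with forward compute regions (and vice versa), which is what lets each direction's communication hide behind the other direction's compute.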

autoparallel/_testing/models/dsv3.py

Lines changed: 51 additions & 48 deletions
@@ -8,6 +8,7 @@
 from typing import Callable, ClassVar, Literal, Optional, Tuple, Union

 import torch
+import torch.fx.traceback as fx_traceback
 import torch.nn.functional as F
 import triton
 import triton.language as tl
@@ -631,61 +632,63 @@ def forward(


 def _token_dispatch(routed_input, num_tokens_per_expert, axis_name):
-    # annotate module input placements/sharding with input_layouts
-    # ep_size = device_mesh.shape[0]
-    ep_size = axis_size(axis_name)
-
-    # generate the input splits and output splits for all-to-all
-    with torch.no_grad():
-        num_tokens_per_expert_group = all_to_all(
-            num_tokens_per_expert,
-            None,
-            None,
+    with fx_traceback.annotate({"comm_region": "token_dispatch"}):
+        # annotate module input placements/sharding with input_layouts
+        # ep_size = device_mesh.shape[0]
+        ep_size = axis_size(axis_name)
+
+        # generate the input splits and output splits for all-to-all
+        with torch.no_grad():
+            num_tokens_per_expert_group = all_to_all(
+                num_tokens_per_expert,
+                None,
+                None,
+                axis_name,
+            )
+            input_splits = (
+                num_tokens_per_expert.view(ep_size, -1)
+                .sum(dim=1)
+                .to(torch.device("cpu"), non_blocking=True)
+            )
+            # NOTE: this would incur a device-to-host sync
+            output_splits = (
+                num_tokens_per_expert_group.view(ep_size, -1)
+                .sum(dim=1)
+                .to(torch.device("cpu"), non_blocking=False)
+            )
+            input_splits = input_splits.tolist()
+            output_splits = output_splits.tolist()
+
+        # perform all-to-all
+        routed_input = all_to_all(
+            routed_input,
+            output_splits,
+            input_splits,
             axis_name,
         )
-        input_splits = (
-            num_tokens_per_expert.view(ep_size, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
-        )
-        # NOTE: this would incur a device-to-host sync
-        output_splits = (
-            num_tokens_per_expert_group.view(ep_size, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=False)
-        )
-        input_splits = input_splits.tolist()
-        output_splits = output_splits.tolist()
-
-    # perform all-to-all
-    routed_input = all_to_all(
-        routed_input,
-        output_splits,
-        input_splits,
-        axis_name,
-    )

-    # NOTE: After this all-to-all, the routed input is put on proper EP rank.
-    # However, the num_tokens_per_expert_group is not of the final target format
-    # [#tokens for local expert 0, #tokens for local expert 1, ...]
-    # Rather, it is of the format
-    # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
-    # #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
-    # We need to perform another shuffle to get the correct format -- this is done via the function
-    # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
-    # each expert gets locally is a multiple of ALIGN_SIZE_M.
+        # NOTE: After this all-to-all, the routed input is put on proper EP rank.
+        # However, the num_tokens_per_expert_group is not of the final target format
+        # [#tokens for local expert 0, #tokens for local expert 1, ...]
+        # Rather, it is of the format
+        # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
+        # #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
+        # We need to perform another shuffle to get the correct format -- this is done via the function
+        # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
+        # each expert gets locally is a multiple of ALIGN_SIZE_M.

-    return routed_input, num_tokens_per_expert_group, input_splits, output_splits
+        return routed_input, num_tokens_per_expert_group, input_splits, output_splits


 def _token_combine(routed_output, input_splits, output_splits, axis_name):
-    routed_output = all_to_all(
-        routed_output,
-        input_splits,
-        output_splits,
-        axis_name,
-    )
-    return routed_output
+    with fx_traceback.annotate({"comm_region": "token_combine"}):
+        routed_output = all_to_all(
+            routed_output,
+            input_splits,
+            output_splits,
+            axis_name,
+        )
+        return routed_output


 # @torch.library.custom_op("autoparallel::local_mapped_region", mutates_args=())
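
The token_dispatch/token_combine annotations above are what the multiplex pass keys on. As a hedged sketch (not part of this commit), a user could mark any other collective the same way; the all_gather shown and the region name are illustrative, and the tag only reaches node.meta["custom"] if the graph-capture pipeline preserves fx_traceback annotations, as autoparallel's tracing is assumed to do here.

import torch
import torch.distributed.nn.functional as dist_nn
import torch.fx.traceback as fx_traceback


def _gather_activations(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Wrap the collective so the fw-bw overlap pass can treat it as a comm_region
    # and interleave opposite-direction compute around it.
    with fx_traceback.annotate({"comm_region": "activation_all_gather"}):
        shards = dist_nn.all_gather(x)  # one shard per rank in the default group
        return torch.cat(shards, dim=dim)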

autoparallel/graph_pp_runner.py

Lines changed: 3 additions & 1 deletion
@@ -70,7 +70,9 @@ def get_multiplexed_graph_callables(
         for fw_stage_idx, fw_stage_graph_callables in stage_graphs.items():
             if bw_stage_idx != fw_stage_idx:
                 fw_bw_module = multiplex_fw_bw_graph(
-                    fw_stage_graph_callables.fw, bw_stage_graph_callables.full_bw
+                    fw_stage_graph_callables.fw,
+                    bw_stage_graph_callables.full_bw,
+                    overlap_with_annotations=True,
                 )
                 multiplexed_graph_callables[(fw_stage_idx, bw_stage_idx)] = fw_bw_module
     return multiplexed_graph_callables

examples/example_pp_graph_passes.py

Lines changed: 1 addition & 2 deletions
@@ -183,9 +183,8 @@ def _run_graph_test(
183183
"""Execute forward and backward passes with specified graph options."""
184184
if use_multiplexed_graph:
185185
multiplexed_fw_bw_module = multiplex_fw_bw_graph(
186-
graph_modules.fw, graph_modules.full_bw
186+
graph_modules.fw, graph_modules.full_bw, overlap_with_annotations=True
187187
)
188-
189188
with (
190189
FakeTensorMode(
191190
allow_non_fake_inputs=True,
