 # LICENSE file in the root directory of this source tree.

 import copy
+from itertools import dropwhile

 import torch
 import torch.fx as fx
+from torch._inductor.fx_passes.bucketing import is_wait_tensor
+from torch._logging import trace_structured
+
+
+def _add_compute_annotations(gm: fx.GraphModule, tag: str):
+    """Add compute_region annotations to nodes without custom metadata."""
+    for n in gm.graph.nodes:
+        if n.op == "placeholder":
+            continue
+        if n.meta.get("custom", None) is None:
+            n.meta["custom"] = {"compute_region": tag}
+        else:
+            assert "comm_region" in n.meta["custom"]
+            val = n.meta["custom"]["comm_region"]
+            n.meta["custom"]["comm_region"] = tag + " " + val
+
+
+def _move_wait_tensors_to_compute_region(gm: fx.GraphModule, tag: str):
+    """Move wait_tensor nodes from the comm_region into the compute_region of their first user."""
+    for n in gm.graph.nodes:
+        if n.op == "placeholder":
+            continue
+        if "comm_region" in n.meta["custom"] and is_wait_tensor(n):
+            assert len(n.users) >= 1, "wait tensor must have at least one user"
+            user: fx.Node = next(iter(n.users))
+            if "compute_region" in user.meta["custom"]:
+                n.meta["custom"].pop("comm_region")
+                n.meta["custom"].update({"compute_region": tag + " " + "wait"})
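+                # Hoist the wait so it sits immediately before its first user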
+                if n.next is not user:
+                    user.prepend(n)


 def multiplex_fw_bw_graph(
-    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule
+    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule, overlap_with_annotations: bool = True
 ) -> fx.GraphModule:
     """
     Multiplexes forward and backward graphs into a single unified graph module.
@@ -32,67 +63,116 @@ def multiplex_fw_bw_graph(
     Note:
         The function preserves node metadata during the merging process.
     """
-    # Mapping to track correspondence between backward graph nodes and new nodes
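+    # Tag both graphs up front so the loop below can alternate comm and compute
+    # regions between them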
+    if overlap_with_annotations:
+        _add_compute_annotations(fw_gm, "forward")
+        _add_compute_annotations(bw_gm, "backward")
+        _move_wait_tensors_to_compute_region(fw_gm, "forward")
+        _move_wait_tensors_to_compute_region(bw_gm, "backward")
+
+    # Mapping to track correspondence between forward graph nodes and new nodes
     old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {}

-    # Start with a deep copy of the forward graph as the base
-    multiplexed_gm = copy.deepcopy(fw_gm)
+    # Start with a deep copy of the backward graph as the base
+    multiplexed_gm = copy.deepcopy(bw_gm)

-    # Collect all placeholder nodes from the backward graph
+    # Collect all placeholder nodes from both graphs
     bw_placeholders = bw_gm.graph.find_nodes(op="placeholder")
     fw_placeholders = fw_gm.graph.find_nodes(op="placeholder")
+    insert_point = multiplexed_gm.graph.find_nodes(op="placeholder")[-1]

-    # Insert backward placeholders at the beginning of the multiplexed graph
-    # Reversed order ensures correct execution sequence
-    with multiplexed_gm.graph.inserting_before():
-        for n in reversed(bw_placeholders):
+    # Insert forward placeholders after the backward placeholders of the multiplexed graph
+    for n in fw_placeholders:
+        with multiplexed_gm.graph.inserting_after(insert_point):
             new_placeholder = multiplexed_gm.graph.placeholder(n.name)
-            new_placeholder.meta = n.meta
+            new_placeholder.meta = copy.copy(n.meta)
             new_placeholder.target = new_placeholder.name
             old_node_to_new_node[n] = new_placeholder
+            insert_point = new_placeholder

-    # Find the last placeholder and the output node in the multiplexed graph
-    multiplxed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
-    assert len(multiplxed_gm_placeholders) == (
-        len(fw_placeholders) + len(bw_placeholders)
+    multiplexed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
+    assert len(multiplexed_gm_placeholders) == len(fw_placeholders) + len(
+        bw_placeholders
     )
-    insert_point = multiplxed_gm_placeholders[-1]
-
-    # Copy all computation nodes from backward graph into multiplexed graph
-    fw_outputs = fw_gm.graph.find_nodes(op="output")
-    bw_outputs = bw_gm.graph.find_nodes(op="output")
-    assert len(bw_outputs) == 1 and len(fw_outputs) == 1
-    bw_graph_op_node = bw_outputs[0]
-    for n in bw_gm.graph.nodes:
-        if n.op == "placeholder":
-            continue
-        if n.op == "output":
-            continue
-        with multiplexed_gm.graph.inserting_after(insert_point):
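+    # Walk the forward graph's non-placeholder nodes in execution order; its
+    # placeholders were already inserted above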
+    fw_nodes_iter = iter(fw_gm.graph.nodes)
+    fw_nodes_iter = dropwhile(lambda n: n.op == "placeholder", fw_nodes_iter)
+    # Initialize the forward node to be the first non-placeholder node
+    fn = next(fw_nodes_iter)
+    if overlap_with_annotations:
+        # Interleave forward and backward nodes to create the overlap pattern:
+        # bw_compute (if any) -> bw_comm -> fw_compute (if any) -> fw_comm -> [repeat]
+        # This allows bw_comm to overlap with fw_compute, and fw_comm to overlap with bw_compute
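+        # Illustrative schedule with hypothetical node names, one iteration:
+        #   bw_matmul -> bw_all_gather (async) -> fw_matmul -> fw_all_gather (async)
+        # so each in-flight collective runs under the other graph's compute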
+        bw_in_comm = False
+        for bn in multiplexed_gm.graph.nodes:
+            if bn.op == "placeholder" or bn.op == "output":
+                continue
+            # Track when we enter a backward comm region
+            if "comm_region" in bn.meta["custom"] and not bw_in_comm:
+                bw_in_comm = True
+            # When we transition from bw_comm to bw_compute, insert forward nodes
+            elif "compute_region" in bn.meta["custom"] and bw_in_comm:
+                bw_in_comm = False
+                fw_in_comm = False
+                insert_point = bn
+                # Insert forward nodes before this bw_compute node
+                # Note: we cannot reorder nodes within a graph, only choose
+                # their relative order between the two graphs
+                while fn.op != "output":
+                    if "comm_region" in fn.meta["custom"] and not fw_in_comm:
+                        fw_in_comm = True
+                    elif "compute_region" in fn.meta["custom"] and fw_in_comm:
+                        # Stop when we reach the next fw_compute after fw_comm.
+                        # This ensures we insert one fw_compute + fw_comm cycle
+                        # per bw_comm -> bw_compute transition. If fw starts with
+                        # comm (no compute before it), we still insert it to
+                        # overlap with a future bw_compute
+                        fw_in_comm = False
+                        break
+                    with multiplexed_gm.graph.inserting_before(insert_point):
+                        # Copy node and remap its arguments using the node mapping
+                        new_node = multiplexed_gm.graph.node_copy(
+                            fn, lambda x: old_node_to_new_node[x]
+                        )
+                        new_node.meta = copy.copy(fn.meta)
+                        old_node_to_new_node[fn] = new_node
+                    fn = next(fw_nodes_iter)
+    # Insert any remaining forward nodes at the end
+    # If overlap_with_annotations is False, this concatenates all fw nodes after bw nodes
+    insert_point = multiplexed_gm.graph.find_nodes(op="output")[-1]
+    while fn.op != "output":
+        with multiplexed_gm.graph.inserting_before(insert_point):
             # Copy node and remap its arguments using the node mapping
             new_node = multiplexed_gm.graph.node_copy(
-                n, lambda x: old_node_to_new_node[x]
+                fn, lambda x: old_node_to_new_node[x]
             )
-            new_node.meta = n.meta
-            old_node_to_new_node[n] = new_node
-            insert_point = new_node
+            new_node.meta = copy.copy(fn.meta)
+            old_node_to_new_node[fn] = new_node
+        fn = next(fw_nodes_iter)

-    # Collect output arguments from backward graph, remapping to new nodes
-    bw_op_node_args = [
+    # Collect output arguments from forward graph, remapping to new nodes
+    fw_outputs = fw_gm.graph.find_nodes(op="output")
+    multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
+    assert len(multiplexed_graph_outputs) == 1 and len(fw_outputs) == 1
+    fw_graph_op_node = fw_outputs[0]
+    fw_op_node_args = [
         old_node_to_new_node[n] if n is not None else None
-        for n in bw_graph_op_node.args[0]
+        for n in fw_graph_op_node.args[0]
     ]

-    # Collect output arguments from multiplexed graph (will contain only fwd_outs)
-    multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
-    assert len(multiplexed_graph_outputs) == 1
+    # Collect output arguments from multiplexed graph (will contain only bwd_outs)
     multiplexed_graph_op_node = multiplexed_graph_outputs[0]
-    fw_op_node_args = list(multiplexed_graph_op_node.args[0])
+    bw_op_node_args = list(multiplexed_graph_op_node.args[0])

     # Update output node args to prepend backward outputs before forward outputs
     multiplexed_graph_op_node.args = (tuple(bw_op_node_args + fw_op_node_args),)
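+    # The multiplexed graph's output is now (*bw_outs, *fw_outs)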

     multiplexed_gm.graph.eliminate_dead_code()
     multiplexed_gm.graph.lint()
     multiplexed_gm.recompile()
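+    # Dump the multiplexed graph to the structured trace log (viewable with tlparse)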
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "autoparallel_multiplexed_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: multiplexed_gm.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        ),
+    )
     return multiplexed_gm