@@ -66,6 +66,94 @@ def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
6666# mypy: ignore-errors
6767
6868
def force_save_fsdp_all_gather(graph: torch.fx.Graph) -> None:
    """
    Force save all_gather nodes from simple fsdp in the graph.

    For each graph placeholder, walk the straight-line chain of single-user /
    single-input nodes; if the chain contains an all_gather whose wait_tensor
    result ultimately feeds the backward graph, tag that wait_tensor with
    ``CheckpointPolicy.MUST_SAVE`` so the partitioner keeps it instead of
    recomputing the all_gather.

    This pass should be added in torch._inductor.config.joint_custom_post_pass
    """
    nodes_to_save: list[torch.fx.Node] = []
    primal_origins: list[torch.fx.Node] = []
    primals = graph.find_nodes(op="placeholder")
    for primal in primals:
        # 1. Find the last all_gather on the linear chain from this placeholder.
        node = primal
        last_ag_chain_node = None
        while len(node.users) == 1:
            user = next(iter(node.users))
            if len(user.all_input_nodes) > 1:
                # Chain ends at the first multi-input consumer.
                break
            node = user
            if is_all_gather_into_tensor(node):
                last_ag_chain_node = node
        if last_ag_chain_node is None:
            continue
        # 2. Find the wait_tensor consuming the last all_gather.
        last_ag_wait_node = next(iter(last_ag_chain_node.users))
        assert is_wait_tensor(last_ag_wait_node)
        assert is_wait_tensor_from_fsdp(last_ag_wait_node)
        # 3. Continue the linear chain from the last wait_tensor.
        w = last_ag_wait_node
        while True:
            if len(w.users) != 1:
                # Capture this pattern (a split fanning out to getitems that
                # all reconverge into one cat):
                # %wait_tensor_5 : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_gather_into_tensor_5,), kwargs = {}) # noqa: E501
                # %split : [num_users=4] = call_function[target=torch.ops.aten.split.Tensor](args = (%wait_tensor_5, 576), kwargs = {})
                # %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
                # %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
                # %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 2), kwargs = {})
                # %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 3), kwargs = {})
                # %cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%getitem_2, %getitem_3, %getitem_4, %getitem_5], 1), kwargs = {}) # noqa: E501
                if (
                    w.op == "call_function"
                    and w.target == torch.ops.aten.split.Tensor
                ):
                    if all(
                        split_user.op == "call_function"
                        and split_user.target == operator.getitem
                        and len(split_user.users) == 1
                        for split_user in w.users
                    ):
                        getitem_users = [
                            next(iter(getitem_node.users))
                            for getitem_node in w.users
                        ]
                        potential_cat_op = getitem_users[0]
                        if (
                            all(
                                potential_cat_op == getitem_user
                                for getitem_user in getitem_users
                            )
                            and potential_cat_op.op == "call_function"
                            and potential_cat_op.target
                            == torch.ops.aten.cat.default
                        ):
                            # All getitems reconverge into the same cat:
                            # treat it as part of the linear chain.
                            w = potential_cat_op
                            continue
                break
            user = next(iter(w.users))
            if len(user.all_input_nodes) > 1:
                break
            w = user
        # 4. `w` is now the last node in this chain.
        last_wait_chain_user = w
        # 5./6. Save the wait_tensor only if the chain's tail feeds backward.
        if any(
            _has_tag_is_backward(downstream_user)
            for downstream_user in last_wait_chain_user.users
        ):
            nodes_to_save.append(last_ag_wait_node)
            primal_origins.append(primal)
    logger.info("force_save_fsdp_all_gather, primal_origins: %s", primal_origins)
    logger.info("force_save_fsdp_all_gather, nodes_to_save: %s", nodes_to_save)

    for node in nodes_to_save:
        node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
        node.meta["ac_graph_id"] = AP_AC_GRAPH_ID
156+
69157def force_recompute_fsdp_all_gather (graph : torch .fx .Graph ) -> None :
70158 """
71159 Force recompute all_gather nodes from simple fsdp in the graph.
@@ -354,6 +442,8 @@ def ac_joint_pass(
354442):
355443 if reshard_after_forward :
356444 force_recompute_fsdp_all_gather (graph )
445+ else :
446+ force_save_fsdp_all_gather (graph )
357447 mark_nodes_as_must_save_to_stage_recomputation (
358448 graph , stage_size_in_GiB = ac_stage_size_in_GiB
359449 )
0 commit comments