Commit b1c4909

Enable Zero2: No reshard after forward (#238)

* Enable no reshard after forward (Zero2)
* Linting and logging fixes

1 parent 3088776 commit b1c4909

3 files changed, +103 -5 lines changed

autoparallel/_passes/split_fsdp_collectives.py

Lines changed: 7 additions & 4 deletions
@@ -17,6 +17,7 @@
 from torch._inductor.fx_passes.bucketing import (
     is_all_gather_into_tensor,
     is_reduce_scatter_tensor,
+    is_wait_tensor,
 )


@@ -77,6 +78,7 @@ def split_fsdp_prefetch(
             prefetch_g_outs_map.append(param_g_in)
         else:
             w_n = next(iter(last_ag.users))
+            assert is_wait_tensor(w_n)
             prefetch_g_outs_map.append(w_n)

     prefetch_g_outs = prefetch_g_outs_map
@@ -126,7 +128,7 @@ def split_fsdp_reduce_scatters_epilogue(
     grad_outs_map = []
     for grad_out in grad_outs:
         n = grad_out
-        last_rs = None
+        earliest_rs = None
         while n is not None:
             if len(n.all_input_nodes) != 1:
                 break
@@ -135,12 +137,13 @@ def split_fsdp_reduce_scatters_epilogue(
                 break
             prev_n = n
             n = n_in
+            # Maybe we also need to track all_reduce?
             if is_reduce_scatter_tensor(prev_n):
                 # In AP for mesh dim > 1
                 # The reduction of gradients happen in multiple steps
-                last_rs = n
-        if last_rs is not None:
-            grad_outs_map.append(last_rs)
+                earliest_rs = n
+        if earliest_rs is not None:
+            grad_outs_map.append(earliest_rs)
         else:
             grad_outs_map.append(grad_out)

autoparallel/activation_checkpointing.py

Lines changed: 90 additions & 0 deletions
@@ -66,6 +66,94 @@ def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
 # mypy: ignore-errors


+def force_save_fsdp_all_gather(graph: torch.fx.Graph) -> None:
+    """
+    Force save all_gather nodes from simple fsdp in the graph.
+    This pass should be added in torch._inductor.config.joint_custom_post_pass
+    """
+    nodes_to_save = []
+    primal_origins = []
+    primals = graph.find_nodes(op="placeholder")
+    # 1. Find last all_gather from each placeholder
+    for primal in primals:
+        node = primal
+        last_ag_chain_node = None
+        while True:
+            if len(node.users) != 1:
+                break
+            user = next(iter(node.users))
+            if len(user.all_input_nodes) > 1:
+                break
+            node = user
+            if is_all_gather_into_tensor(node):
+                last_ag_chain_node = node
+        if last_ag_chain_node is not None:
+            # 2. Find last wait_tensor from last all_gather
+            last_ag_wait_node = next(iter(last_ag_chain_node.users))
+            assert is_wait_tensor(last_ag_wait_node)
+            assert is_wait_tensor_from_fsdp(last_ag_wait_node)
+            # 3. Continue the linear chain from the last wait_tensor
+            w = last_ag_wait_node
+            while True:
+                if len(w.users) != 1:
+                    # Capture this pattern:
+                    # %wait_tensor_5 : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_gather_into_tensor_5,), kwargs = {})  # noqa: E501
+                    # %split : [num_users=4] = call_function[target=torch.ops.aten.split.Tensor](args = (%wait_tensor_5, 576), kwargs = {})
+                    # %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
+                    # %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
+                    # %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 2), kwargs = {})
+                    # %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 3), kwargs = {})
+                    # %cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%getitem_2, %getitem_3, %getitem_4, %getitem_5], 1), kwargs = {})  # noqa: E501
+                    if (
+                        w.op == "call_function"
+                        and w.target == torch.ops.aten.split.Tensor
+                    ):
+                        if all(
+                            split_user.op == "call_function"
+                            and split_user.target == operator.getitem
+                            and len(split_user.users) == 1
+                            for split_user in w.users
+                        ):
+                            getitem_users = list(
+                                next(iter(getitem_node.users))
+                                for getitem_node in w.users
+                            )
+                            potential_cat_op = getitem_users[0]
+                            if all(
+                                potential_cat_op == getitem_user
+                                for getitem_user in getitem_users
+                            ) and (
+                                potential_cat_op.op == "call_function"
+                                and potential_cat_op.target
+                                == torch.ops.aten.cat.default
+                            ):
+                                w = potential_cat_op
+                                continue
+                    break
+                user = next(iter(w.users))
+                if len(user.all_input_nodes) > 1:
+                    break
+                w = user
+            # 4. Stores the last node in this chain as `last_wait_chain_user`
+            last_wait_chain_user = w
+            # 5. Check if the last node in this chain is used in backward
+            is_used_in_backward = False
+            for downstream_user in last_wait_chain_user.users:
+                if _has_tag_is_backward(downstream_user):
+                    is_used_in_backward = True
+                    break
+            if is_used_in_backward:
+                # 6. If the last node in this chain is used in backward, only then we save the wait_tensor
+                nodes_to_save.append(last_ag_wait_node)
+                primal_origins.append(primal)
+    logger.info("force_save_fsdp_all_gather, primal_origins: %s", primal_origins)
+    logger.info("force_save_fsdp_all_gather, nodes_to_save: %s", nodes_to_save)
+
+    for node in nodes_to_save:
+        node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
+        node.meta["ac_graph_id"] = AP_AC_GRAPH_ID
+
+
 def force_recompute_fsdp_all_gather(graph: torch.fx.Graph) -> None:
     """
     Force recompute all_gather nodes from simple fsdp in the graph.
@@ -354,6 +442,8 @@ def ac_joint_pass(
 ):
     if reshard_after_forward:
         force_recompute_fsdp_all_gather(graph)
+    else:
+        force_save_fsdp_all_gather(graph)
     mark_nodes_as_must_save_to_stage_recomputation(
         graph, stage_size_in_GiB=ac_stage_size_in_GiB
     )
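
For context, a minimal registration sketch based on the docstring of the new force_save_fsdp_all_gather pass, which says it should be added in torch._inductor.config.joint_custom_post_pass. The import path simply mirrors the file location above; that the hook is invoked with the joint forward+backward torch.fx.Graph is an assumption consistent with the pass signature, not something this commit shows.

import torch._inductor.config as inductor_config

# Assumed import path, mirroring autoparallel/activation_checkpointing.py.
from autoparallel.activation_checkpointing import force_save_fsdp_all_gather

# Run the force-save pass on the joint graph; the all_gather wait_tensor nodes
# it tags with CheckpointPolicy.MUST_SAVE keep their gathered output alive for
# backward instead of being re-gathered, i.e. the no-reshard-after-forward
# (ZeRO-2 style) behavior this commit enables.
inductor_config.joint_custom_post_pass = force_save_fsdp_all_gather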

examples/example_pp_graph_passes.py

Lines changed: 6 additions & 1 deletion
@@ -39,7 +39,12 @@ def _get_pp_module_and_graphs(
 ) -> tuple[torch.nn.Module, GraphCallables, GraphMeta]:

     with AutoParallelPP(
-        model, tracing_input_fn, mesh, dynamic=True, reshard_after_forward=False
+        model,
+        tracing_input_fn,
+        mesh,
+        dynamic=True,
+        compile=False,
+        reshard_after_forward=False,
     ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
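
As a usage note, a hedged sketch of how the reshard_after_forward flag in this example selects between the two activation-checkpointing passes wired up in ac_joint_pass above. Only AutoParallelPP and its keyword arguments appear in this commit; the wrapper function and the import path are illustrative.

from autoparallel.api import AutoParallelPP  # import path assumed, not shown in this diff


def open_autop(model, tracing_input_fn, mesh, zero2: bool = True):
    # reshard_after_forward=False keeps the all-gathered parameters for the
    # backward pass (ZeRO-2 style, as enabled by this commit); True reshards
    # after forward and re-gathers them in backward (ZeRO-3 style).
    return AutoParallelPP(
        model,
        tracing_input_fn,
        mesh,
        dynamic=True,
        compile=False,
        reshard_after_forward=not zero2,
    )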
