Commit b12845a

Fix ModuleDict wrapping
stack-info: PR: #260, branch: xmfan/stack/24
1 parent 10d8208 commit b12845a

File tree

2 files changed: +90 -4 lines changed


autoparallel/api.py

Lines changed: 52 additions & 4 deletions
@@ -52,6 +52,50 @@
 _APPLY_VIEW_MM_VIEW_PATTERN = False


+# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'.
+# This installs empty Modules where none exist yet if they are subpaths of 'target'.
+def assign_attr(
+    from_module: torch.nn.Module,
+    from_obj: Union[torch.Tensor, torch.ScriptObject, torch.nn.Module],
+    to_module: torch.nn.Module,
+    target: str,
+    attr_kind: _AttrKind,
+    persistent: bool = True,
+):
+    # _assign_attr assumes we should assign every field as nn.Module;
+    # this patch adds support for nn.ModuleDict (used by torchtitan)
+    *prefix, field = target.split(".")
+    module_map = {to_module: from_module}
+    for item in prefix:
+        submod_map: dict[torch.nn.Module, torch.nn.Module] = {}
+        for t_module, f_module in module_map.items():
+            if not hasattr(t_module, item):
+                from_item = getattr(f_module, item, None)
+                if isinstance(from_item, torch.nn.ModuleDict):
+                    setattr(t_module, item, torch.nn.ModuleDict())
+                elif isinstance(from_item, torch.nn.Module):
+                    setattr(t_module, item, torch.nn.Module())
+                else:
+                    raise RuntimeError(
+                        f"Unsupported type {type(from_item)} for item {item}"
+                    )
+            from_children = f_module._modules.items()
+            to_children = t_module._modules.items()
+            # >= may seem odd, but it's because to_module is being mutated
+            assert len(from_children) >= len(to_children)
+            new_submods = {}
+            for (f_attr_name, f_call), (t_attr_name, t_call) in zip(
+                from_children, to_children
+            ):
+                assert f_attr_name == t_attr_name
+                new_submods[t_call] = f_call
+
+            submod_map.update(new_submods)
+        module_map = submod_map
+
+    _assign_attr(from_obj, to_module, target, attr_kind, persistent)
+
+
 def _get_decomp_table():
     decomp_table = copy.copy(select_decomp_table())
     # TODO: removing those as they cause missing DTensor propagation rules
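A minimal usage sketch (not part of this commit) of what the new helper enables. Assumptions: `assign_attr` is importable from `autoparallel.api`, and `_AttrKind` comes from `torch.export.unflatten` (the same place `_assign_attr` lives). When the qualified name crosses an nn.ModuleDict in the source model, the empty container created on the destination is also an nn.ModuleDict, so dict-style access and isinstance checks keep working:

# Hypothetical sketch; import paths are assumptions, not part of the diff.
import torch
from torch.export.unflatten import _AttrKind
from autoparallel.api import assign_attr

src = torch.nn.Module()
src.layers = torch.nn.ModuleDict({"layer_0": torch.nn.Linear(4, 4)})

dst = torch.nn.Module()  # starts out empty, like parallel_model
weight = torch.nn.Parameter(torch.randn(4, 4))

# Walks "layers.layer_0": dst.layers becomes an nn.ModuleDict (mirroring src),
# dst.layers["layer_0"] an empty nn.Module, and the parameter lands at the leaf.
assign_attr(src, weight, dst, "layers.layer_0.weight", attr_kind=_AttrKind.PARAMETER)

assert isinstance(dst.layers, torch.nn.ModuleDict)
assert dst.layers["layer_0"].weight is weight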
@@ -550,10 +594,14 @@ def _register_params_and_init_weights(
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
+            assign_attr(
+                self.model, v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER
+            )

         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER)
+            assign_attr(
+                self.model, v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER
+            )

         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
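Why the call sites now pass `self.model` as the first argument: the keys of `sharded_param_dict` are fully-qualified names (e.g. "layers.0.weight" per the comment above), and with a torchtitan-style nn.ModuleDict the path component is a dict key rather than a list index, so only the source model knows which container type to recreate. A small sketch, assuming the keys follow `named_parameters()` naming:

# Sketch (assumption: sharded_param_dict keys follow nn.Module qualified names).
import torch

m = torch.nn.Module()
m.layers = torch.nn.ModuleDict({"layer_0": torch.nn.Linear(2, 2)})
print([name for name, _ in m.named_parameters()])
# ['layers.layer_0.weight', 'layers.layer_0.bias']
# Rebuilding "layers" as a plain nn.Module would drop the ModuleDict type;
# passing self.model lets assign_attr mirror it on parallel_model.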
@@ -644,10 +692,10 @@ def _register_params_and_buffers(self, sharded_param_dict, sharded_buffer_dict):
         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
         for k, v in sharded_param_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.PARAMETER)
+            assign_attr(self.model, v, self, k, attr_kind=_AttrKind.PARAMETER)

         for k, v in sharded_buffer_dict.items():
-            _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER)
+            assign_attr(self.model, v, self, k, attr_kind=_AttrKind.BUFFER)

     def forward(self, *args):
         raise NotImplementedError("This is a placeholder for the pipeline model")

tests/test_api.py

Lines changed: 38 additions & 0 deletions
@@ -305,3 +305,41 @@ def input_fn():
 # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_1, 10), kwargs = {})
 # %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 30), kwargs = {})
 # return ((add, add_2), (tangents_1, None))
+
+
+def test_torchtitan_module_dict(device_mesh_1d):
+    class ToyModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layers = nn.ModuleDict(
+                {
+                    "layer_0": torch.nn.Linear(10, 10),
+                    "layer_1": torch.nn.Linear(10, 10),
+                    "layer_2": torch.nn.Linear(10, 10),
+                }
+            )
+
+        def forward(self, x):
+            for layer in self.layers.values():
+                x = layer(x)
+            return x
+
+    def input_fn():
+        return torch.rand(10, 10, device="cuda")
+
+    with torch.device("meta"):
+        model = ToyModel()
+
+    with AutoParallel(model, input_fn, device_mesh_1d) as autop:
+        autop.add_parameter_memory_constraint(low=None, high=None)
+
+        x_sharding = (Replicate(),)
+
+        autop.add_input_constraints([x_sharding])
+        autop.add_output_constraints([x_sharding])
+
+        sharding_placement = autop.optimize_placement(verbose=False)
+        parallel_mod = autop.apply_placement(sharding_placement)
+
+    assert isinstance(model.layers, torch.nn.ModuleDict)
+    assert isinstance(parallel_mod.layers, torch.nn.ModuleDict)
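The test exercises the fix end to end: ToyModel keeps its layers in an nn.ModuleDict, is built on the meta device, and both the original and the parallelized module must still expose `layers` as an nn.ModuleDict after `apply_placement`. Assuming the suite runs under pytest (the fixture-style `device_mesh_1d` argument suggests so), it can be run in isolation with something like `pytest tests/test_api.py -k test_torchtitan_module_dict` on a CUDA-capable machine, since `input_fn` allocates on "cuda".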
