52 | 52 | _APPLY_VIEW_MM_VIEW_PATTERN = False |
53 | 53 |
54 | 54 |
| 55 | +# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'.
| 56 | +# This installs empty Modules where none exist yet along the subpaths of 'target'.
| 57 | +def assign_attr(
| 58 | +    from_module: torch.nn.Module,
| 59 | +    from_obj: Union[torch.Tensor, torch.ScriptObject, torch.nn.Module],
| 60 | +    to_module: torch.nn.Module,
| 61 | +    target: str,
| 62 | +    attr_kind: _AttrKind,
| 63 | +    persistent: bool = True,
| 64 | +):
| 65 | +    # _assign_attr creates every missing intermediate field as a plain nn.Module;
| 66 | +    # this patch adds support for nn.ModuleDict intermediates (used by torchtitan)
| 67 | +    *prefix, field = target.split(".")
| 68 | +    module_map = {to_module: from_module}
| 69 | +    for item in prefix:
| 70 | +        submod_map: dict[torch.nn.Module, torch.nn.Module] = {}
| 71 | +        for t_module, f_module in module_map.items():
| 72 | +            if not hasattr(t_module, item):
| 73 | +                from_item = getattr(f_module, item, None)
| 74 | +                if isinstance(from_item, torch.nn.ModuleDict):
| 75 | +                    setattr(t_module, item, torch.nn.ModuleDict())
| 76 | +                elif isinstance(from_item, torch.nn.Module):
| 77 | +                    setattr(t_module, item, torch.nn.Module())
| 78 | +                else:
| 79 | +                    raise RuntimeError(
| 80 | +                        f"Unsupported type {type(from_item)} for item {item}"
| 81 | +                    )
| 82 | +            from_children = f_module._modules.items()
| 83 | +            to_children = t_module._modules.items()
| 84 | +            # >= may seem odd, but it's because to_module is being mutated
| 85 | +            assert len(from_children) >= len(to_children)
| 86 | +            new_submods = {}
| 87 | +            for (f_attr_name, f_call), (t_attr_name, t_call) in zip(
| 88 | +                from_children, to_children
| 89 | +            ):
| 90 | +                assert f_attr_name == t_attr_name
| 91 | +                new_submods[t_call] = f_call
| 92 | +
| 93 | +            submod_map.update(new_submods)
| 94 | +        module_map = submod_map
| 95 | +
| 96 | +    _assign_attr(from_obj, to_module, target, attr_kind, persistent)
| 97 | +
| 98 | +
55 | 99 | def _get_decomp_table(): |
56 | 100 |     decomp_table = copy.copy(select_decomp_table())
57 | 101 |     # TODO: removing those as they cause missing DTensor propagation rules
@@ -550,10 +594,14 @@ def _register_params_and_init_weights( |
550 | 594 |         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
551 | 595 |         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
552 | 596 |         for k, v in sharded_param_dict.items():
553 | | -            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
| 597 | +            assign_attr(
| 598 | +                self.model, v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER
| 599 | +            )
554 | 600 |
555 | 601 |         for k, v in sharded_buffer_dict.items():
556 | | -            _assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER)
| 602 | +            assign_attr(
| 603 | +                self.model, v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER
| 604 | +            )
557 | 605 |
558 | 606 |         # Right now we require a convention that the user model provides an init_weights method,
559 | 607 |         # although we could snoop for other methods too.
@@ -644,10 +692,10 @@ def _register_params_and_buffers(self, sharded_param_dict, sharded_buffer_dict): |
644 | 692 |         # e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
645 | 693 |         # create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
646 | 694 |         for k, v in sharded_param_dict.items():
647 | | -            _assign_attr(v, self, k, attr_kind=_AttrKind.PARAMETER)
| 695 | +            assign_attr(self.model, v, self, k, attr_kind=_AttrKind.PARAMETER)
648 | 696 |
649 | 697 |         for k, v in sharded_buffer_dict.items():
650 | | -            _assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER)
| 698 | +            assign_attr(self.model, v, self, k, attr_kind=_AttrKind.BUFFER)
651 | 699 |
652 | 700 |     def forward(self, *args):
653 | 701 |         raise NotImplementedError("This is a placeholder for the pipeline model")
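As a usage illustration (not part of the diff; the helper `_mirror_containers`, the `Tiny` model, and the tensor shapes are hypothetical), here is a minimal sketch of the container-mirroring idea behind `assign_attr`: while walking the dotted prefix of `target`, any intermediate module missing on the destination is created with the same container type as in the source model, so an `nn.ModuleDict` (as used by torchtitan) stays an `nn.ModuleDict` instead of degrading to a plain `nn.Module`.

```python
# Minimal sketch of the container-mirroring idea; names are hypothetical and
# this is not the PR's implementation.
import torch.nn as nn


def _mirror_containers(src: nn.Module, dst: nn.Module, target: str) -> None:
    # Walk the dotted prefix of `target`, creating any missing intermediate
    # attribute on `dst` with the same container type (nn.ModuleDict vs. plain
    # nn.Module) as the corresponding attribute on `src`.
    *prefix, _leaf = target.split(".")
    for name in prefix:
        src = getattr(src, name)
        if not hasattr(dst, name):
            child = nn.ModuleDict() if isinstance(src, nn.ModuleDict) else nn.Module()
            setattr(dst, name, child)
        dst = getattr(dst, name)


class Tiny(nn.Module):  # hypothetical source model that uses nn.ModuleDict
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({"0": nn.Linear(4, 4)})


source, shell = Tiny(), nn.Module()  # shell plays the role of parallel_model
_mirror_containers(source, shell, "layers.0.weight")
assert isinstance(shell.layers, nn.ModuleDict)  # container type is preserved
# The leaf attribute itself ("weight") would then be attached by _assign_attr.
```

The PR's `assign_attr` additionally keeps a map from each destination submodule to its source-side counterpart so deeper levels can be matched up, and then delegates the final leaf assignment to `_assign_attr`.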