Fix ModuleDict wrapping #260
Conversation
fmassa left a comment:
This generally LGTM and I was also thinking about doing something like that!
I wonder if we could (or should?) simplify/generalize the implementation to keep the original subclass information around as well?
```python
        ref_submod = getattr(ref_curr_mod, attr_name)
        if isinstance(ref_submod, torch.nn.ModuleDict):
            setattr(curr_mod, attr_name, torch.nn.ModuleDict())
        else:
            setattr(curr_mod, attr_name, torch.nn.Module())
    else:
        setattr(curr_mod, attr_name, torch.nn.Module())
else:
    setattr(curr_mod, attr_name, torch.nn.Module())
```
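For context, a minimal sketch of where this check sits (hypothetical helper name `assign_submodule`, not the exact PR code): interim containers are created while walking a dotted FQN, and the container now mirrors nn.ModuleDict whenever the reference module uses one.

```python
import torch

def assign_submodule(root: torch.nn.Module, ref_root: torch.nn.Module,
                     fqn: str, leaf: torch.nn.Module) -> None:
    # Walk "a.b.c", creating missing interim containers along the way.
    curr_mod, ref_curr_mod = root, ref_root
    *prefix, last = fqn.split(".")
    for attr_name in prefix:
        if not hasattr(curr_mod, attr_name):
            ref_submod = getattr(ref_curr_mod, attr_name, None)
            if isinstance(ref_submod, torch.nn.ModuleDict):
                # Mirror the reference container so dict-style APIs keep working.
                setattr(curr_mod, attr_name, torch.nn.ModuleDict())
            else:
                setattr(curr_mod, attr_name, torch.nn.Module())
        curr_mod = getattr(curr_mod, attr_name)
        ref_curr_mod = getattr(ref_curr_mod, attr_name, torch.nn.Module())
    setattr(curr_mod, last, leaf)

# Tiny usage check: "layers" is rebuilt as a ModuleDict because the
# reference model uses one.
ref = torch.nn.Module()
ref.layers = torch.nn.ModuleDict({"0": torch.nn.Linear(2, 2)})
root = torch.nn.Module()
assign_submodule(root, ref, "layers.0", torch.nn.Linear(2, 2))
assert isinstance(root.layers, torch.nn.ModuleDict)
```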
I wonder if we would want to keep the whole original class structure around (maybe with an nn.Module subclass indicating that the class has been AutoParallelized).

Something like

```python
cls = type(ref_submod)
new_inst = ref_submod.__new__(cls)
new_inst.__dict__ = ref_submod.__dict__.copy()
setattr(curr_mod, attr_name, new_inst)
```

or, if we want a subclass,

```python
cls = type(ref_submod)
new_cls = type(f"AutoP[{cls.__name__}]", (cls,), ref_submod.__dict__.copy())
new_inst = new_cls.__new__(new_cls)
new_inst.__dict__ = ref_submod.__dict__.copy()
setattr(curr_mod, attr_name, new_inst)
```

(but we need to cache those new classes to avoid creating too many redundant classes maybe?)
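A minimal sketch of the caching the comment alludes to (hypothetical names, not part of the PR): generated `AutoP[...]` classes are memoized per original class, so repeated submodules of the same type reuse one subclass instead of minting a new one each time.

```python
import torch

_AUTOP_CLS_CACHE: dict = {}

def _autop_subclass(cls: type) -> type:
    # Reuse one dynamically created subclass per original class.
    if cls not in _AUTOP_CLS_CACHE:
        _AUTOP_CLS_CACHE[cls] = type(f"AutoP[{cls.__name__}]", (cls,), {})
    return _AUTOP_CLS_CACHE[cls]

def wrap_like(ref_submod: torch.nn.Module) -> torch.nn.Module:
    # Instantiate the cached subclass and share a copy of the reference
    # module's attribute dict, mirroring the suggestion above.
    new_cls = _autop_subclass(type(ref_submod))
    new_inst = new_cls.__new__(new_cls)
    new_inst.__dict__ = ref_submod.__dict__.copy()
    return new_inst

# Wrapping two ModuleDicts reuses the same generated class.
a, b = torch.nn.ModuleDict(), torch.nn.ModuleDict()
assert type(wrap_like(a)) is type(wrap_like(b))
```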
Stacked PRs:
Claude'd
Before this PR, we directly used `_assign_attr`, which defaults to always using nn.Module: https://github.com/pytorch/pytorch/blob/064f80dfa0482f6bd365a2f7db2e9c2f9f3ea88c/torch/export/unflatten.py#L102

For Torchtitan, the optimizer has a flag-accessing pattern that relies on the ModuleDict APIs, which we had to patch until now (https://github.com/pytorch/torchtitan/blob/d54a6d4c92e063189c91b433e3b3d85d79dfb657/torchtitan/components/optimizer.py#L354-L361). After this PR, we'll use nn.ModuleDict when applicable.
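To illustrate the difference (a standalone sketch, not Torchtitan's actual code): the dict-style APIs the optimizer relies on exist on nn.ModuleDict but not on a bare nn.Module used as an interim container.

```python
import torch

# nn.ModuleDict supports membership tests, iteration, and .items():
layers = torch.nn.ModuleDict({"0": torch.nn.Linear(4, 4), "1": torch.nn.Linear(4, 4)})
assert "0" in layers
for name, mod in layers.items():
    print(name, type(mod).__name__)

# A bare nn.Module holding the same children only exposes attribute access:
plain = torch.nn.Module()
plain.add_module("0", torch.nn.Linear(4, 4))
# "0" in plain    # TypeError: nn.Module does not support membership tests
# plain.items()   # AttributeError: 'Module' object has no attribute 'items'
```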