
Conversation

@xmfan
Member

@xmfan xmfan commented Nov 21, 2025

Stacked PRs:


Claude'd

Before this PR, we directly used _assign_attr, which always defaults to nn.Module: https://github.com/pytorch/pytorch/blob/064f80dfa0482f6bd365a2f7db2e9c2f9f3ea88c/torch/export/unflatten.py#L102

In Torchtitan, the optimizer has a flag-accessing pattern that relies on the ModuleDict APIs, which we had to patch around until now (https://github.com/pytorch/torchtitan/blob/d54a6d4c92e063189c91b433e3b3d85d79dfb657/torchtitan/components/optimizer.py#L354-L361). After this PR, we'll use nn.ModuleDict when applicable, as sketched below.
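As a rough illustration (module and class names below are made up, not taken from this PR or Torchtitan), here is why rebuilding the hierarchy with plain nn.Module containers breaks the ModuleDict access pattern, and what the behavior looks like once nn.ModuleDict is used where applicable:

import torch.nn as nn

# Hypothetical reference module whose submodule container is an nn.ModuleDict.
class Block(nn.Module):
    def forward(self, x):
        return x

ref = nn.Module()
ref.layers = nn.ModuleDict({"0": Block(), "1": Block()})

# Before this PR: the rebuilt hierarchy uses plain nn.Module containers,
# so dict-style access on `layers` no longer works.
rebuilt_old = nn.Module()
rebuilt_old.layers = nn.Module()
setattr(rebuilt_old.layers, "0", Block())
setattr(rebuilt_old.layers, "1", Block())
# "0" in rebuilt_old.layers   -> TypeError: argument of type 'Module' is not iterable
# rebuilt_old.layers.keys()   -> AttributeError: 'Module' object has no attribute 'keys'

# After this PR: when the reference attribute is an nn.ModuleDict, the rebuilt
# container is an nn.ModuleDict too, so the ModuleDict APIs keep working.
rebuilt_new = nn.Module()
rebuilt_new.layers = nn.ModuleDict()
rebuilt_new.layers["0"] = Block()
rebuilt_new.layers["1"] = Block()
assert "0" in rebuilt_new.layers
assert list(rebuilt_new.layers.keys()) == ["0", "1"]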

xmfan added a commit that referenced this pull request Nov 21, 2025
stack-info: PR: #260, branch: xmfan/stack/24
@meta-cla meta-cla bot added the CLA Signed label Nov 21, 2025
xmfan added a commit that referenced this pull request Nov 21, 2025
stack-info: PR: #260, branch: xmfan/stack/24
@xmfan xmfan requested a review from fmassa November 21, 2025 01:57
xmfan added a commit that referenced this pull request Nov 21, 2025
stack-info: PR: #260, branch: xmfan/stack/24
@xmfan xmfan marked this pull request as draft November 21, 2025 02:17
xmfan added a commit that referenced this pull request Nov 21, 2025
stack-info: PR: #260, branch: xmfan/stack/24
xmfan added a commit to pytorch/torchtitan that referenced this pull request Nov 21, 2025
@xmfan xmfan requested review from bdhirsh and wconstab November 21, 2025 17:56
@xmfan xmfan marked this pull request as ready for review November 21, 2025 18:00
Contributor

@fmassa fmassa left a comment


This generally LGTM and I was also thinking about doing something like that!

I wonder if we could (or should?) simplify/generalize the implementation to keep the original subclass information around as well?

Comment on lines +91 to +99
        ref_submod = getattr(ref_curr_mod, attr_name)
        if isinstance(ref_submod, torch.nn.ModuleDict):
            setattr(curr_mod, attr_name, torch.nn.ModuleDict())
        else:
            setattr(curr_mod, attr_name, torch.nn.Module())
    else:
        setattr(curr_mod, attr_name, torch.nn.Module())
else:
    setattr(curr_mod, attr_name, torch.nn.Module())
Contributor

@fmassa fmassa Dec 1, 2025


I wonder if we would want to keep the whole original class structure around (maybe with a nn.Module subclass indicating that the class has been AutoParallelized).
Something like

cls = type(ref_submod)
new_inst = ref_submod.__new__(cls)
new_inst.__dict__ = ref_submod.__dict__.copy()
setattr(curr_mod, attr_name, new_inst)

or if we want a subclass

cls = type(ref_submod)
new_cls = type(f"AutoP[{cls.__name__}]", (cls,), ref_submod.__dict__.copy())
new_inst = new_cls.__new__(new_cls)
new_inst.__dict__ = ref_submod.__dict__.copy()
setattr(curr_mod, attr_name, new_inst)

(but we need to cache those new classes to avoid creating too many redundant classes maybe?)
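One hedged sketch of that caching idea (the helper name and cache are hypothetical, not part of this PR): create the AutoP[...] subclass once per original class and reuse it across instances, so isinstance checks against the original class still pass and only one extra class exists per module type.

import torch.nn as nn

# Cache of dynamically created wrapper classes, keyed by the original class,
# so repeated submodules of the same type reuse a single subclass.
_autop_class_cache: dict[type, type] = {}

def _autop_wrap(ref_submod: nn.Module) -> nn.Module:
    cls = type(ref_submod)
    new_cls = _autop_class_cache.get(cls)
    if new_cls is None:
        # Empty class namespace; per-instance state is copied onto the instance below.
        new_cls = type(f"AutoP[{cls.__name__}]", (cls,), {})
        _autop_class_cache[cls] = new_cls
    # Bypass __init__ and copy the reference module's state (parameters,
    # buffers, submodules, flags) onto the new instance.
    new_inst = new_cls.__new__(new_cls)
    new_inst.__dict__ = ref_submod.__dict__.copy()
    return new_inst

# Usage sketch: setattr(curr_mod, attr_name, _autop_wrap(ref_submod))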
