
Commit 2b1fb92

fix moduledict with AP meta-pytorch/autoparallel#260
1 parent acd9588

1 file changed

torchtitan/components/optimizer.py

Lines changed: 3 additions & 14 deletions
@@ -351,18 +351,9 @@ def should_manual_allreduce(tokens_per_expert_by_layer):
             tokens_per_expert_by_layer, torch.distributed.tensor.DTensor
         )

-    def get_transformer_blocks(model_part):
-        if isinstance(model_part.layers, nn.ModuleDict):
-            # regular torchtitan
-            blocks = model_part.layers.values()
-        else:
-            # TODO: fix autoparallel to preserve the module dict
-            blocks = model_part.layers.children()
-        return blocks
-
     def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool:
         for model_part in model_parts:
-            for transformer_block in get_transformer_blocks(model_part):
+            for transformer_block in model_part.layers.values():
                 if is_moe_block(transformer_block):
                     # Assumption: load_balance_coeff is set universally on all moe blocks.
                     return bool(transformer_block.moe.load_balance_coeff)
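For context, the deleted get_transformer_blocks helper only existed because autoparallel used to replace the layers ModuleDict with a different container, so callers could not rely on .values(). A minimal sketch of the two iteration paths it bridged (a toy Block module, not torchtitan's real class):

import torch.nn as nn

class Block(nn.Module):
    # Toy stand-in for a transformer block, not torchtitan's real class.
    pass

layers = nn.ModuleDict({"0": Block(), "1": Block()})

# ModuleDict.values() and Module.children() yield the same submodules in the
# same order here, so once autoparallel preserves the ModuleDict, the plain
# .values() call covers both cases and the isinstance() fallback can go.
assert list(layers.values()) == list(layers.children())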
@@ -384,8 +375,7 @@ def _update_expert_bias(
         # default compute stream. Need to assess if this is OK performance-wise.
         tokens_per_expert_list = []
         for model_part in model_parts:
-            blocks = get_transformer_blocks(model_part)
-            for transformer_block in blocks:
+            for transformer_block in model_part.layers.values():
                 if not is_moe_block(transformer_block):
                     continue
                 if transformer_block.moe.load_balance_coeff is None:
@@ -414,8 +404,7 @@ def _update_expert_bias(
         moe_layer_idx = 0
         with torch.no_grad():
             for model_part in model_parts:
-                blocks = get_transformer_blocks(model_part)
-                for transformer_block in blocks:
+                for transformer_block in model_part.layers.values():
                     if not is_moe_block(transformer_block):
                         continue
                     moe = transformer_block.moe
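As a rough, self-contained illustration of the call pattern after this change (toy MoE/Block/Model classes and a simplified is_moe_block, none of which are torchtitan's real definitions), the balancing-hook predicate now walks model_part.layers.values() directly:

import torch.nn as nn

class MoE(nn.Module):
    # Hypothetical stand-in: only carries the load-balance coefficient.
    def __init__(self, load_balance_coeff=1e-3):
        super().__init__()
        self.load_balance_coeff = load_balance_coeff

class MoEBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = MoE()

class DenseBlock(nn.Module):
    pass

def is_moe_block(block: nn.Module) -> bool:
    # Simplified check for this sketch; torchtitan's real helper differs.
    return hasattr(block, "moe")

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # layers stays an nn.ModuleDict, which autoparallel now preserves.
        self.layers = nn.ModuleDict({"0": DenseBlock(), "1": MoEBlock()})

def should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool:
    # Mirrors the post-commit loop: iterate the ModuleDict values directly.
    for model_part in model_parts:
        for transformer_block in model_part.layers.values():
            if is_moe_block(transformer_block):
                return bool(transformer_block.moe.load_balance_coeff)
    return False

print(should_register_moe_balancing_hook([Model()]))  # -> True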
