@@ -351,18 +351,9 @@ def should_manual_allreduce(tokens_per_expert_by_layer):
             tokens_per_expert_by_layer, torch.distributed.tensor.DTensor
         )
 
-    def get_transformer_blocks(model_part):
-        if isinstance(model_part.layers, nn.ModuleDict):
-            # regular torchtitan
-            blocks = model_part.layers.values()
-        else:
-            # TODO: fix autoparallel to preserve the module dict
-            blocks = model_part.layers.children()
-        return blocks
-
     def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool:
         for model_part in model_parts:
-            for transformer_block in get_transformer_blocks(model_part):
+            for transformer_block in model_part.layers.values():
                 if is_moe_block(transformer_block):
                     # Assumption: load_balance_coeff is set universally on all moe blocks.
                     return bool(transformer_block.moe.load_balance_coeff)
@@ -384,8 +375,7 @@ def _update_expert_bias(
         # default compute stream. Need to assess if this is OK performance-wise.
         tokens_per_expert_list = []
         for model_part in model_parts:
-            blocks = get_transformer_blocks(model_part)
-            for transformer_block in blocks:
+            for transformer_block in model_part.layers.values():
                 if not is_moe_block(transformer_block):
                     continue
                 if transformer_block.moe.load_balance_coeff is None:
@@ -414,8 +404,7 @@ def _update_expert_bias(
         moe_layer_idx = 0
         with torch.no_grad():
             for model_part in model_parts:
-                blocks = get_transformer_blocks(model_part)
-                for transformer_block in blocks:
+                for transformer_block in model_part.layers.values():
                     if not is_moe_block(transformer_block):
                         continue
                     moe = transformer_block.moe
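
Below is a minimal sketch, not part of the diff itself (the `ModelPart` class and layer sizes are hypothetical), of the iteration pattern the simplified loops rely on: assuming `layers` is always an `nn.ModuleDict` of transformer blocks, `.values()` yields the block modules directly, so the removed `.children()` fallback is no longer needed.

import torch.nn as nn

# Hypothetical stand-in for a pipeline model part: `layers` is an nn.ModuleDict
# keyed by layer index, which is the structure the simplified loops assume.
class ModelPart(nn.Module):
    def __init__(self, num_layers: int) -> None:
        super().__init__()
        self.layers = nn.ModuleDict(
            {str(i): nn.Linear(8, 8) for i in range(num_layers)}
        )

model_part = ModelPart(num_layers=3)

# nn.ModuleDict.values() iterates the stored modules in insertion order,
# matching `for transformer_block in model_part.layers.values()` in the diff.
for transformer_block in model_part.layers.values():
    print(type(transformer_block).__name__)  # Linear, Linear, Linear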