@@ -48,9 +48,6 @@ def __init__(self,
         # TODO(fhzhang): figure out whether we need to actually enable this.
         # strategy_dict = {"tensor_parallelism": 4, "expert_parallelism": 2}

-        # TODO(fhzhang): remove these once we confirm that the values we get from config are good.
-        # self.hidden_size: int = 5120
-        # vocab_size = 202048
         self.vocab_size = model_config.get_vocab_size()
         self.hidden_size = model_config.get_hidden_size()

@@ -118,7 +115,7 @@ def __init__(self,
             router_act="sigmoid",
             rngs=self.rng,
             activation_ffw_td=('data', None),
-            ed_sharding=(None, 'expert'),
+            ed_sharding=(None, None),
             random_init=force_random_weights)

         moe_ffw = MoE(
@@ -132,8 +129,8 @@ def __init__(self,
             rngs=self.rng,
             activation_ffw_td=('data', None),
             activation_ffw_ted=('data', 'expert', None),
-            edf_sharding=('expert', None, 'model'),
-            efd_sharding=('expert', 'model', None),
+            edf_sharding=('model', None, None),
+            efd_sharding=('model', None, None),
             random_init=force_random_weights) if is_moe_layer else None

         dense_ffw = DenseFFW(
@@ -196,6 +193,7 @@ def __init__(self,
             rngs=self.rng,
             with_scale=True,
             dtype=dtype,
+            activation_ffw_td=('data', None),
         )

         pre_mlp_norm = RMSNorm(
@@ -205,6 +203,7 @@ def __init__(self,
             with_scale=True,
             dtype=dtype,
             random_init=force_random_weights,
+            activation_ffw_td=('data', None),
         )

         block = SharedExpertsTransformerBlock(
@@ -344,8 +343,7 @@ def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
             "o_proj": (hidden_size, attn_heads, attn_head_dim),
         }

         # Set the mappings from loaded parameter keys to standardized names.
-
         # 1. EXPERT_MAPPINGS_FUSED: Used for non-quantized (e.g., BF16) checkpoints.
         #    - This format typically comes from standard checkpoints where 'gate' and 'up' projection weights might be combined (FUSED) into a single tensor.
         #    - Expert weights are usually stacked, with the expert dimension (E) being the first dimension.
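As a rough illustration of the fused layout described in the comments above, the sketch below shows how a stacked gate/up tensor with the expert dimension (E) first can be split back into separate projections. The shapes, variable names, and split order are assumptions for illustration only, not the repository's actual EXPERT_MAPPINGS_FUSED contents.

import numpy as np

# Hypothetical sizes: E experts, hidden size D, feed-forward size F.
E, D, F = 4, 8, 16

# Fused checkpoint tensor: gate and up projection weights combined into a
# single stacked tensor, with the expert dimension first (assumed layout).
fused_gate_up = np.arange(E * D * 2 * F, dtype=np.float32).reshape(E, D, 2 * F)

# Splitting the last axis recovers per-expert gate and up projections.
gate_proj, up_proj = np.split(fused_gate_up, 2, axis=-1)
assert gate_proj.shape == up_proj.shape == (E, D, F)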
@@ -513,7 +511,6 @@ def load_weights(self, model_for_loading: nnx.Module):
             is_scale = loaded_name.endswith(".weight_scale")

             if is_unfused_expert:
-                # if layer_num is not None:
                 mapped_name = self.map_loaded_to_standardized_name(
                     loaded_name)
                 model_weight = get_param(model_params, mapped_name)
@@ -604,19 +601,10 @@ def load_weights(self, model_for_loading: nnx.Module):
                         f"does not match model shape for {loaded_name}: {model_weight.array.scale.value.shape}!"
                     )

-                if buffer_key.endswith("kernel_down_proj_EFD_scale"):
-                    # The model's default sharding may incorrectly place the 'model' split on an axis that is not divisible by the mesh size (8),
-                    # so this explicitly enforces ('expert', None, 'model') to ensure correct Tensor and Expert Parallelism.
-                    correct_sharding_names = ('expert', None, 'model')
-                    model_weight.array.scale.value = shard_put(
-                        aggregated_weight,
-                        correct_sharding_names,
-                        mesh=model_for_loading.mesh)
-                else:
-                    model_weight.array.scale.value = shard_put(
-                        aggregated_weight,
-                        model_weight.array.scale.sharding,
-                        mesh=model_for_loading.mesh)
+                model_weight.array.scale.value = shard_put(
+                    aggregated_weight,
+                    model_weight.array.scale.sharding,
+                    mesh=model_for_loading.mesh)

             elif aggregated_weight.itemsize < 2:  # check model weight elem nbits < 16
                 loaded_name = f"{base_mapped_name}.array.qvalue.value"
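For readers unfamiliar with how a sharding tuple such as ('model', None, None) turns into device placement, here is a minimal JAX sketch. shard_put_sketch, the toy mesh, and the shapes are assumptions for illustration; the repository's shard_put helper and mesh come from its own runtime configuration.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy 1 x N mesh over whatever devices are available, with a 'data' and a
# 'model' axis. The real mesh shape and axis names come from the model config.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("data", "model"))

def shard_put_sketch(array, sharding_names, mesh):
    # Stand-in for shard_put: place `array` on `mesh`, partitioning each array
    # axis along the named mesh axis (None means replicated over that axis).
    return jax.device_put(array, NamedSharding(mesh, P(*sharding_names)))

# An (E, D, F) expert weight partitioned along its first axis over 'model',
# matching edf_sharding = ('model', None, None) above. Assumes the device
# count divides the sharded axis.
w = np.zeros((16, 128, 256), dtype=np.float32)
w_sharded = shard_put_sketch(w, ("model", None, None), mesh)
print(w_sharded.sharding)

With each scale parameter carrying its own declared sharding, the simplified loader above can pass model_weight.array.scale.sharding straight to shard_put instead of special-casing kernel_down_proj_EFD_scale.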