
Commit a14b215

[Llama4-Maverick/Optimization] Refactor: Standardize Sharding and Parallelism Configs in FFW/MoE Layers (#1067)
1 parent 0d7e995 commit a14b215

1 file changed: +10 -22 lines


tpu_inference/models/jax/llama4.py

Lines changed: 10 additions & 22 deletions
@@ -48,9 +48,6 @@ def __init__(self,
         # TODO(fhzhang): figure out whether we need to actually enable this.
         # strategy_dict = {"tensor_parallelism": 4, "expert_parallelism": 2}
 
-        # TODO(fhzhang): remove these once we confirm that the values we get from config are good.
-        # self.hidden_size: int = 5120
-        # vocab_size = 202048
         self.vocab_size = model_config.get_vocab_size()
         self.hidden_size = model_config.get_hidden_size()
 
@@ -118,7 +115,7 @@ def __init__(self,
             router_act="sigmoid",
             rngs=self.rng,
             activation_ffw_td=('data', None),
-            ed_sharding=(None, 'expert'),
+            ed_sharding=(None, None),
             random_init=force_random_weights)
 
         moe_ffw = MoE(
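
For readers new to these tuples: each entry names the mesh axis that the corresponding tensor dimension is split over, and None means replicated, so the new ed_sharding=(None, None) replicates the router table on every device. A minimal sketch of that convention, assuming a standard JAX mesh with ('data', 'model') axes; the helper name, mesh layout, and toy shape below are illustrative, not this repo's API:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy 2-D device mesh; axis names mirror the tuples in the diff.
mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ('data', 'model'))

def to_sharding(mesh, axis_names):
    # ('data', None) -> split dim 0 over 'data', replicate dim 1;
    # (None, None)   -> fully replicated, as in the new ed_sharding.
    return NamedSharding(mesh, P(*axis_names))

router_ed = jax.device_put(jnp.zeros((16, 512)), to_sharding(mesh, (None, None)))
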
@@ -132,8 +129,8 @@ def __init__(self,
             rngs=self.rng,
             activation_ffw_td=('data', None),
             activation_ffw_ted=('data', 'expert', None),
-            edf_sharding=('expert', None, 'model'),
-            efd_sharding=('expert', 'model', None),
+            edf_sharding=('model', None, None),
+            efd_sharding=('model', None, None),
             random_init=force_random_weights) if is_moe_layer else None
 
         dense_ffw = DenseFFW(
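
A plausible reading of this change (hedged; the commit title only says "standardize"): a tensor dimension can only be split over a mesh axis whose size divides it, and the old specs placed the 'model' split on the D or F axis, which the workaround removed later in load_weights says is not always divisible by the mesh size. Splitting the expert dimension E over 'model' sidesteps that. A small self-contained check of the divisibility rule, with assumed sizes:

# Assumed toy sizes: E experts, D hidden, F ffw dim; axis sizes are examples.
E, D, F = 128, 5120, 8192
mesh_axis_size = {'data': 1, 'model': 8, 'expert': 1}

def splits_evenly(shape, spec):
    return all(dim % mesh_axis_size[ax] == 0
               for dim, ax in zip(shape, spec) if ax is not None)

assert splits_evenly((E, D, F), ('model', None, None))   # new edf_sharding
assert splits_evenly((E, F, D), ('model', None, None))   # new efd_sharding
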
@@ -196,6 +193,7 @@ def __init__(self,
             rngs=self.rng,
             with_scale=True,
             dtype=dtype,
+            activation_ffw_td=('data', None),
         )
 
         pre_mlp_norm = RMSNorm(
@@ -205,6 +203,7 @@ def __init__(self,
             with_scale=True,
             dtype=dtype,
             random_init=force_random_weights,
+            activation_ffw_td=('data', None),
         )
 
         block = SharedExpertsTransformerBlock(
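
The two activation_ffw_td=('data', None) additions annotate the norm's (tokens, d_model) activations rather than its weights. In JAX this kind of annotation typically lowers to a sharding constraint inside jit; a minimal sketch under that assumption (the mesh, names, and RMSNorm body below are illustrative, not the repo's RMSNorm):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ('data', 'model'))

@jax.jit
def rmsnorm_td(x, scale, eps=1e-6):
    # Keep the (tokens, d_model) activations split over 'data', matching
    # activation_ffw_td=('data', None).
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('data', None)))
    rms = jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x * rms * scale

out = rmsnorm_td(jnp.ones((8, 512)), jnp.ones((512,)))
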
@@ -344,8 +343,7 @@ def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
             "o_proj": (hidden_size, attn_heads, attn_head_dim),
         }
 
-        # Set the mappings from loaded parameter keys to standardized names.
-
+        # Set the mappings from loaded parameter keys to standardized names.
         # 1. EXPERT_MAPPINGS_FUSED: Used for non-quantized (e.g., BF16) checkpoints.
         #    - This format typically comes from standard checkpoints where 'gate' and 'up' projection weights might be combined (FUSED) into a single tensor.
         #    - Expert weights are usually stacked, with the expert dimension (E) being the first dimension.
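
To make the FUSED layout concrete: gate and up projections arrive concatenated in one tensor with the expert dimension first, and a loader splits them into per-projection stacks. A toy illustration (the shapes and the axis chosen for the split are assumptions, not the checkpoint spec):

import jax.numpy as jnp

E, D, F = 4, 8, 16
gate_up = jnp.zeros((E, D, 2 * F))           # fused gate+up, expert-stacked
gate, up = jnp.split(gate_up, 2, axis=-1)    # two (E, D, F) expert stacks
assert gate.shape == up.shape == (E, D, F)
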
@@ -513,7 +511,6 @@ def load_weights(self, model_for_loading: nnx.Module):
             is_scale = loaded_name.endswith(".weight_scale")
 
             if is_unfused_expert:
-                # if layer_num is not None:
                 mapped_name = self.map_loaded_to_standardized_name(
                     loaded_name)
                 model_weight = get_param(model_params, mapped_name)
@@ -604,19 +601,10 @@ def load_weights(self, model_for_loading: nnx.Module):
                         f"does not match model shape for {loaded_name}: {model_weight.array.scale.value.shape}!"
                     )
 
-                if buffer_key.endswith("kernel_down_proj_EFD_scale"):
-                    # The model's default sharding may incorrectly place the 'model' split on an axis that is not divisible by the mesh size (8),
-                    # so this explicitly enforces ('expert', None, 'model') to ensure correct Tensor and Expert Parallelism.
-                    correct_sharding_names = ('expert', None, 'model')
-                    model_weight.array.scale.value = shard_put(
-                        aggregated_weight,
-                        correct_sharding_names,
-                        mesh=model_for_loading.mesh)
-                else:
-                    model_weight.array.scale.value = shard_put(
-                        aggregated_weight,
-                        model_weight.array.scale.sharding,
-                        mesh=model_for_loading.mesh)
+                model_weight.array.scale.value = shard_put(
+                    aggregated_weight,
+                    model_weight.array.scale.sharding,
+                    mesh=model_for_loading.mesh)
 
             elif aggregated_weight.itemsize < 2:  # check model weight elem nbits < 16
                 loaded_name = f"{base_mapped_name}.array.qvalue.value"
