
Commit 649912e

[Torchax] Add ability to load MoE bias
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 62763b5 commit 649912e

3 files changed: +70 -23 lines changed

tpu_inference/layers/vllm/fused_moe.py

Lines changed: 12 additions & 13 deletions

@@ -272,7 +272,7 @@ def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
     )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)


-def jax_fused_moe_func(
+def fused_moe_func(
     hidden_states: jax.Array,
     w1: jax.Array,
     w2: jax.Array,
@@ -368,11 +368,10 @@ def jax_fused_moe_func(
     return x


-def jax_fused_moe_func_padded(hidden_states: jax.Array, w1: jax.Array,
-                              w2: jax.Array, gating_output: jax.Array,
-                              topk: int, global_num_experts: int,
-                              renormalize: bool, reduce_results: bool,
-                              mesh: Mesh, use_ep: bool):
+def fused_moe_func_padded(hidden_states: jax.Array, w1: jax.Array,
+                          w2: jax.Array, gating_output: jax.Array, topk: int,
+                          global_num_experts: int, renormalize: bool,
+                          reduce_results: bool, mesh: Mesh, use_ep: bool):
     # TODO(fanhongmin@google.com): Once the jax runner pads the input, we no longer need this.
     hidden_size = hidden_states.shape[-1]
     num_tokens = hidden_states.size // hidden_size
@@ -387,13 +386,13 @@ def jax_fused_moe_func_padded(hidden_states: jax.Array, w1: jax.Array,
         reps = (n_repeats, ) + (1, ) * (gating_output.ndim - 1)
         expanded_gating_output = jnp.tile(gating_output, reps)

-        expanded_x = jax_fused_moe_func(expanded_hidden_states, w1, w2,
-                                        expanded_gating_output, topk,
-                                        global_num_experts, renormalize,
-                                        reduce_results, mesh, use_ep)
+        expanded_x = fused_moe_func(expanded_hidden_states, w1, w2,
+                                    expanded_gating_output, topk,
+                                    global_num_experts, renormalize,
+                                    reduce_results, mesh, use_ep)
         x = expanded_x[:hidden_states.shape[0]]
         return x
     else:
-        return jax_fused_moe_func(hidden_states, w1, w2, gating_output, topk,
-                                  global_num_experts, renormalize,
-                                  reduce_results, mesh, use_ep)
+        return fused_moe_func(hidden_states, w1, w2, gating_output, topk,
+                              global_num_experts, renormalize, reduce_results,
+                              mesh, use_ep)
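
For reference, the renamed padded entry point is consumed by wrapping it in jax.jit with the non-array arguments marked static, the same pattern that appears in unquantized.py below. The following is a minimal sketch of that wiring; the mesh construction and the concrete argument values are illustrative assumptions, not part of this commit.

import functools

import jax
import numpy as np
from jax.sharding import Mesh

from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded

# Single-axis mesh named "model"; the axis name matches the shardings used
# elsewhere in this commit, but building it here is an assumption.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

_fused_moe_func = functools.partial(
    jax.jit(fused_moe_func_padded,
            static_argnames=[
                "topk", "global_num_experts", "renormalize",
                "reduce_results", "mesh", "use_ep"
            ]),
    topk=2,                  # placeholder values
    global_num_experts=8,
    renormalize=True,
    reduce_results=True,
    mesh=mesh,
    use_ep=False,
)
# The array arguments (hidden_states, w1, w2, gating_output) are supplied
# when the partial is invoked inside the MoE method's apply().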

tpu_inference/layers/vllm/quantization/unquantized.py

Lines changed: 53 additions & 10 deletions

@@ -24,7 +24,7 @@
     QuantizationConfig, QuantizeMethodBase)

 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
-from tpu_inference.layers.vllm.fused_moe import jax_fused_moe_func_padded
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
@@ -191,8 +191,12 @@ def select_gemm_impl(
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)

-        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+
+        if self.moe.has_bias:
+            w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+            w2_bias = t2j(layer.w2_bias, use_dlpack=False)

         if self.use_kernel and layer.use_ep:
             # Kernel expects:
@@ -208,25 +212,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
             w13_reshaped = w13_weight.reshape(num_experts, 2,
                                               intermediate_size, hidden_size)
-            w13_weight = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))

             # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
             w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))

             # Apply EP sharding
             w13_weight = jax.device_put(
-                w13_weight,
+                w13_weight_transposed,
                 Format(Layout((0, 1, 2, 3)),
                        NamedSharding(self.mesh, P("model", None, None, None))))
-            w2_weight_transposed = jax.device_put(
+            w2_weight = jax.device_put(
                 w2_weight_transposed,
                 Format(Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P("model", None, None))))

-            layer.w13_weight = Parameter(torch_view(w13_weight),
-                                         requires_grad=False)
-            layer.w2_weight = Parameter(torch_view(w2_weight_transposed),
-                                        requires_grad=False)
+            if self.moe.has_bias:
+                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+
+                # Apply EP sharding
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+
         else:
             # Original logic for non-kernel path
             if layer.use_ep:
@@ -238,6 +251,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                     w2_weight,
                     Format(Layout((0, 1, 2)),
                            NamedSharding(self.mesh, P("model", None, None))))
+
+                if self.moe.has_bias:
+                    w13_bias = jax.device_put(
+                        w13_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
+                    w2_bias = jax.device_put(
+                        w2_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
+
             else:
                 intermediate_size = w13_weight.shape[1] // 2
                 assert intermediate_size == w2_weight.shape[-1]
@@ -255,11 +279,27 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                     Format(Layout((0, 1, 2)),
                            NamedSharding(self.mesh, P(None, None, "model"))))

+                if self.moe.has_bias:
+                    w13_bias = jax.device_put(
+                        w13_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P(None, "model"))))
+                    w2_bias = jax.device_put(
+                        w2_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P(None, None))))
+
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
         layer.w2_weight = Parameter(torch_view(w2_weight),
                                     requires_grad=False)

+        if self.moe.has_bias:
+            layer.w13_bias = Parameter(torch_view(w13_bias),
+                                       requires_grad=False)
+            layer.w2_bias = Parameter(torch_view(w2_bias),
+                                      requires_grad=False)
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -290,6 +330,9 @@ def apply(
         if scoring_func != "softmax":
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
+        # TODO(kyuyeunk): Remove this check once MoE bias support has landed.
+        if self.moe.has_bias:
+            raise NotImplementedError("Bias is not currently supported.")

         if self.use_kernel and layer.use_ep:
             output = fused_ep_moe(
@@ -305,7 +348,7 @@ def apply(
         else:
             # Use the original implementation
             _fused_moe_func = functools.partial(
-                jax.jit(jax_fused_moe_func_padded,
+                jax.jit(fused_moe_func_padded,
                         static_argnames=[
                             "topk", "global_num_experts", "renormalize",
                             "reduce_results", "mesh", "use_ep"

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 5 additions & 0 deletions

@@ -86,6 +86,11 @@ def load_weights(self):
         assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
         vllm_config_for_load.device_config.device = "cpu"

+        # When expert parallelism is enabled, vLLM loads weights in a sharding-
+        # aware manner. Since tpu-inference has its own sharding logic, this
+        # may cause errors. Therefore, we disable it during weight loading.
+        vllm_config_for_load.parallel_config.enable_expert_parallel = False
+
         if os.getenv("JAX_RANDOM_WEIGHTS", False):
             vllm_config_for_load.load_config.load_format = "dummy"
             use_random_weights = True
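
A rough illustration (an assumption, not from this commit) of the conflict the comment describes: with enable_expert_parallel set, vLLM's FusedMoE loader keeps only each rank's local slice of experts, whereas tpu-inference expects the full expert dimension on CPU and applies its own sharding in process_weights_after_loading above.

# Rough shape arithmetic only; numbers are placeholders.
global_num_experts = 64
ep_size = 8

# What vLLM's EP-aware loader would materialize per rank.
experts_per_rank_with_vllm_ep = global_num_experts // ep_size   # 8

# What tpu-inference's process_weights_after_loading expects to reshard.
experts_expected_on_cpu = global_num_experts                    # 64

assert experts_per_rank_with_vllm_ep != experts_expected_on_cpu
# With enable_expert_parallel = False, vLLM loads all experts on CPU and
# tpu-inference then partitions them itself via jax.device_put.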
