From cf4ab79d5485164ae367b1a6a0eae3acf7408787 Mon Sep 17 00:00:00 2001 From: Kyuyeun Kim Date: Fri, 7 Nov 2025 06:13:30 +0000 Subject: [PATCH] [Torchax] Support bias and swiglu in MoE Signed-off-by: Kyuyeun Kim --- tests/layers/vllm/test_unquantized.py | 128 +++++++++ tpu_inference/layers/vllm/fused_moe.py | 256 +++++++++++++----- .../layers/vllm/quantization/unquantized.py | 106 ++++++-- .../models/vllm/vllm_model_wrapper.py | 5 + 4 files changed, 391 insertions(+), 104 deletions(-) diff --git a/tests/layers/vllm/test_unquantized.py b/tests/layers/vllm/test_unquantized.py index 507a6e23d..78ecfd935 100644 --- a/tests/layers/vllm/test_unquantized.py +++ b/tests/layers/vllm/test_unquantized.py @@ -485,6 +485,134 @@ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size, ) +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +@pytest.mark.parametrize("num_tokens", [8]) +@pytest.mark.parametrize("intermediate_size", [1024]) +@pytest.mark.parametrize("hidden_size", [128]) +@pytest.mark.parametrize("num_experts", [8]) +@pytest.mark.parametrize("topk", [2]) +def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size, + num_experts, topk): + os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1' + torch.manual_seed(42) + dtype = torch.bfloat16 + + a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10 + w1 = torch.randn( + (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10 + w2 = torch.randn( + (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10 + w1_bias = torch.randn( + (num_experts, 2 * intermediate_size), dtype=dtype) / 10 + w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10 + score = torch.randn((num_tokens, num_experts), dtype=dtype) + + engine_args = EngineArgs( + model="Qwen/Qwen2-1.5B-Instruct", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.model_config.dtype = dtype + + quant_config = get_tpu_quantization_config(vllm_config, mesh) + with set_current_vllm_config(vllm_config): + vllm_fused_moe = FusedMoE( + num_experts=num_experts, + top_k=topk, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=False, + renormalize=False, + tp_size=1, + dp_size=1, + quant_config=quant_config, + has_bias=True, + ) + vllm_fused_moe.w13_weight.data = w1 + vllm_fused_moe.w2_weight.data = w2 + vllm_fused_moe.w13_bias.data = w1_bias + vllm_fused_moe.w2_bias.data = w2_bias + + jax_a = torch_view(t2j(a, use_dlpack=False)) + jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) + score = torch_view(t2j(score)) + score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) + + with torchax.default_env(), set_forward_context(None, vllm_config): + assert isinstance(vllm_fused_moe.quant_method, + VllmUnquantizedFusedMoEMethod) + vllm_fused_moe.quant_method.process_weights_after_loading( + vllm_fused_moe) + vllm_fused_moe(jax_a, score) + + +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +@pytest.mark.parametrize("num_tokens", [8]) +@pytest.mark.parametrize("intermediate_size", [1024]) +@pytest.mark.parametrize("hidden_size", [128]) +@pytest.mark.parametrize("num_experts", [8]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("activation", ["silu", "swigluoai"]) +def test_fused_moe_activation(mesh, num_tokens, intermediate_size, 
hidden_size, + num_experts, topk, activation): + os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1' + torch.manual_seed(42) + dtype = torch.bfloat16 + + a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10 + w1 = torch.randn( + (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10 + w2 = torch.randn( + (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10 + score = torch.randn((num_tokens, num_experts), dtype=dtype) + + engine_args = EngineArgs( + model="Qwen/Qwen2-1.5B-Instruct", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.model_config.dtype = dtype + + quant_config = get_tpu_quantization_config(vllm_config, mesh) + with set_current_vllm_config(vllm_config): + vllm_fused_moe = FusedMoE( + num_experts=num_experts, + top_k=topk, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=False, + renormalize=False, + tp_size=1, + dp_size=1, + quant_config=quant_config, + activation=activation, + ) + vllm_fused_moe.w13_weight.data = w1 + vllm_fused_moe.w2_weight.data = w2 + + jax_a = torch_view(t2j(a, use_dlpack=False)) + jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) + score = torch_view(t2j(score)) + score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) + + with torchax.default_env(), set_forward_context(None, vllm_config): + assert isinstance(vllm_fused_moe.quant_method, + VllmUnquantizedFusedMoEMethod) + vllm_fused_moe.quant_method.process_weights_after_loading( + vllm_fused_moe) + vllm_fused_moe(jax_a, score) + + @pytest.mark.parametrize("use_ep", [True]) @pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh(jax.local_device_count())]) diff --git a/tpu_inference/layers/vllm/fused_moe.py b/tpu_inference/layers/vllm/fused_moe.py index 390ae7ea4..fa9a45288 100644 --- a/tpu_inference/layers/vllm/fused_moe.py +++ b/tpu_inference/layers/vllm/fused_moe.py @@ -12,22 +12,42 @@ P = PartitionSpec +def activation_fn(activation: str, x1, x2): + match activation: + case "silu": + return jax.nn.silu(x1) * x2 + case "swigluoai": + return _swigluoai(x1, x2) + case _: + raise NotImplementedError( + f"FusedMoE does not support {activation} activation") + + +def _swigluoai(x1, x2, alpha=1.702, limit=7.0): + x1 = jnp.clip(x1, a_max=limit) + x2 = jnp.clip(x2, a_min=-limit, a_max=limit) + + gated_activation = x1 * jax.nn.sigmoid(alpha * x1) + + return gated_activation * (x2 + 1) + + def _round_up_to_multiple_of_128_within_limit(x: int, limit: int) -> int: """ - Rounds the given integer `x` up to the nearest multiple of 128, without exceeding - the specified `limit`. + Rounds the given integer `x` up to the nearest multiple of 128, without + exceeding the specified `limit`. If `x` is less than or equal to 128, returns 128. - If `x` is less than `limit`, returns the smallest multiple of 128 greater than or - equal to `x`. - If `x` is greater than or equal to `limit`, searches for the largest multiple of - 128 less than or equal to `limit` (down to 512) that divides `x` evenly, and - returns it. + If `x` is less than `limit`, returns the smallest multiple of 128 greater + than or equal to `x`. + If `x` is greater than or equal to `limit`, searches for the largest + multiple of 128 less than or equal to `limit` (down to 512) that divides `x` + evenly, and returns it. If no such candidate is found, returns `limit`. Args: x (int): The integer to round up. 
-        limit (int): The upper bound (must be a multiple of 128 and at least 128).
+        limit (int): The upper bound (must be a multiple of 128).
 
     Returns:
         int: The rounded value according to the rules above.
@@ -64,22 +84,29 @@ def _get_tiling_size_for_gmm_kernel(m: int, k: int, n: int,
 
     # TODO(Chengji): increase the upper limit tiling size of m when we can set
     # the vmem size to be used for gmm kernel.
-    # NOTE: In average each expert has m // g tokens, but as it might be unbalanced,
-    # here we doubled the token size when choosing tiling size of m. 2m//g can be
-    # either greater or less than 512. If there are 32 tokens and topk=2,
-    # m=topk * num_tokens=64, in this case, 2*m//g will be less than 512.
+    # NOTE: On average each expert has m // g tokens, but as the load might be
+    # unbalanced, we double the token count when choosing the tiling size of m.
+    # 2m//g can be either greater or less than 512. If there are 32 tokens and
+    # topk=2, m=topk * num_tokens=64, in this case, 2*m//g will be less than
+    # 512.
     tm = _round_up_to_multiple_of_128_within_limit(2 * m // g, 512)
     tm = min(tm, m)  # there's a requirement that m % tm == 0
-    # k/n correspond to n_input_features/n_output_features in the matmul so they are
-    # normally greater than 2048, unless the num shards is large.
+    # k/n correspond to n_input_features/n_output_features in the matmul so
+    # they are normally greater than 2048, unless the number of shards is large.
     tk = _round_up_to_multiple_of_128_within_limit(k, 2048)
     tn = _round_up_to_multiple_of_128_within_limit(n, 2048)
     return tm, tk, tn
 
 
 def tensor_sharded_gmm_merged_column_parallel(
-        lhs: jax.Array, rhs: jax.Array, group_sizes: jax.Array,
-        transpose_rhs: bool, mesh: Mesh, intermediate_size: int) -> jax.Array:
+    lhs: jax.Array,
+    rhs: jax.Array,
+    rhs_bias: jax.Array | None,
+    group_sizes: jax.Array,
+    transpose_rhs: bool,
+    mesh: Mesh,
+    intermediate_size: int,
+) -> jax.Array:
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
@@ -101,6 +128,10 @@ def tensor_sharded_gmm_merged_column_parallel(
         check_rep=False,
     )(lhs, rhs, group_sizes)
 
+    if rhs_bias is not None:
+        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
+        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
+
     n_shards = mesh.shape["model"]
 
     output_sizes = [intermediate_size, intermediate_size]
@@ -111,6 +142,7 @@ def tensor_sharded_gmm_merged_column_parallel(
 def tensor_sharded_gmm_row_parallel(
     lhs: jax.Array,
     rhs: jax.Array,
+    rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
     transpose_rhs: bool,
     mesh: Mesh,
@@ -132,7 +164,7 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
         r = _gmm(lhs, rhs, group_sizes)
         return jax.lax.psum(r, axis_name="model")
 
-    return shard_map(
+    gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
         in_specs=(P(None, "model"), P(None, None, "model"), P()),
@@ -140,6 +172,12 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
         check_rep=False,
     )(lhs, rhs, group_sizes)
 
+    if rhs_bias is not None:
+        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
+        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
+
+    return gmm_result
+
 
 def expert_sharded_gmm(
     lhs: jax.Array,
@@ -161,19 +199,24 @@ def expert_sharded_gmm(
         group_offset, NamedSharding(mesh, P("model")))
 
     def _gmm(lhs, rhs, group_sizes, group_offset):
-        # Group offset for this shard. `group_offset` is sharded, and in this sharded
-        # function, it has only 1 element and `group_offset.shape` is (1,) but gmm kernel requires
-        # the group_offset to be a ()-shaped array, so we group_offset[0].
+        # Group offset for this shard. `group_offset` is sharded, and in this
+        # sharded function, it has only 1 element and `group_offset.shape` is
+        # (1,), but the gmm kernel requires the group_offset to be a ()-shaped
+        # array, so we take group_offset[0].
         group_offset_of_shard = group_offset[0]
-        return gmm(lhs=lhs,
-                   rhs=rhs,
-                   group_sizes=group_sizes,
-                   preferred_element_type=lhs.dtype,
-                   tiling=(tm, tk, tn),
-                   transpose_rhs=transpose_rhs,
-                   group_offset=group_offset_of_shard)
-
-    # The result from gmm on each shard has the same shape, but only the rows for this shard has non-zero values. Taking below as an working example:
+        gmm_res = gmm(
+            lhs=lhs,
+            rhs=rhs,
+            group_sizes=group_sizes,
+            preferred_element_type=lhs.dtype,
+            tiling=(tm, tk, tn),
+            transpose_rhs=transpose_rhs,
+            group_offset=group_offset_of_shard,
+        )
+        return gmm_res
+
+    # The result from gmm on each shard has the same shape, but only the rows
+    # for this shard have non-zero values. Take the below as a working example:
     # A, A, A, A    0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0
     # A, A, A, A    0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0
     # A, A, A, A    0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0
@@ -189,7 +232,7 @@ def _gmm(lhs, rhs, group_sizes, group_offset):
     #  0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0    D, D, D, D
     #  0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0    D, D, D, D
     #   shard-0       shard-1       shard-2       shard-3
-    # The shard 0,1,2,3 each has 3 (A rows), 2 (B rows), 5 (C rows) and 4 (D rows).
+    # Shards 0-3 have 3 (A rows), 2 (B rows), 5 (C rows) and 4 (D rows).
     gmm_res = shard_map(
         _gmm,
         mesh=mesh,
@@ -198,9 +241,10 @@ def _gmm(lhs, rhs, group_sizes, group_offset):
         check_rep=False,
     )(lhs, rhs, group_sizes, group_offset)
 
-    # For i-th shard, it is responsible groups (AKA experts) from i*num_experts_per_shard to (i+1)*num_experts_per_shard
-    # We sum them up to get total rows in that shard, and that is the size for shard to send to its peers. This is also
-    # the number of non-zero rows from the gmm results.
+    # The i-th shard is responsible for groups (AKA experts) from
+    # i*num_experts_per_shard to (i+1)*num_experts_per_shard. We sum them up to
+    # get the total rows in that shard, which is the size the shard sends to
+    # its peers. This is also the number of non-zero rows from the gmm results.
     # In the working example, send_sizes would be [3, 2, 5, 4]
     send_sizes = jnp.array([
         group_sizes[i * num_experts_per_shard:(i + 1) *
@@ -222,17 +266,21 @@ def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
                            recv_sizes):
         output = jnp.zeros_like(operand)
 
-        # input_offsets, send_sizes and output_offsets are sharded and there is only 1 elemnt in each shard, we
-        # are taking the 0-th element from them just so that jnp.repeat generates the arrays with correct shape.
+        # input_offsets, send_sizes and output_offsets are sharded and there is
+        # only 1 element in each shard; we take the 0-th element from them
+        # just so that jnp.repeat generates the arrays with the correct shape.
         input_offsets_of_shard = jnp.repeat(input_offsets[0], ep_size)
         send_sizes_of_shard = jnp.repeat(send_sizes[0], ep_size)
         output_offsets_of_shard = jnp.repeat(output_offsets[0], ep_size)
 
-        # recv_sizes is replicated across shards, because all the shards receive the same data and write to the
-        # output in the same way (same output_offsets and same recv_sizes) and thus generates replicated output.
+ # recv_sizes is replicated across shards, because all the shards receive + # the same data and write to the output in the same way (same + # output_offsets and same recv_sizes) and thus generates replicated + # output. recv_sizes_of_shard = recv_sizes - # In the working example, for each shard, the values of the offsets and sizes would be: + # In the working example, for each shard, the values of the offsets and + # sizes would be: # shard-0 shard-1 shard-2 shard-3 # input_offsets_of_shard [0, 0, 0, 0] [3, 3, 3, 3] [5, 5, 5, 5] [10,10,10,10] # send_sizes_of_shard [3, 3, 3, 3] [2, 2, 2, 2] [5, 5, 5, 5] [4, 4, 4, 4 ] @@ -246,8 +294,8 @@ def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets, recv_sizes_of_shard, axis_name="model") - # Use ragged_all_to_all to send the result from gmm for each expert to all the shards. - # In the working example, the result would be: + # Use ragged_all_to_all to send the result from gmm for each expert to all + # the shards. In the working example, the result would be: # A, A, A, A A, A, A, A A, A, A, A A, A, A, A # A, A, A, A A, A, A, A A, A, A, A A, A, A, A # A, A, A, A A, A, A, A A, A, A, A A, A, A, A @@ -272,10 +320,12 @@ def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets, )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes) -def jax_fused_moe_func( +def fused_moe_func( hidden_states: jax.Array, w1: jax.Array, w2: jax.Array, + w1_bias: jax.Array | None, + w2_bias: jax.Array | None, gating_output: jax.Array, topk: int, global_num_experts: int, @@ -283,6 +333,7 @@ def jax_fused_moe_func( reduce_results: bool, mesh: Mesh, use_ep: bool, + activation: str, ): """ Args: @@ -292,6 +343,9 @@ def jax_fused_moe_func( gating_output: [*, num_experts] """ # adapted from https://github.com/vllm-project/vllm/blob/29fa5cac1cd731026f59084d93a822921507573c/vllm/model_executor/layers/fused_moe/moe_pallas.py#L26 + if use_ep and (w1_bias is not None or w2_bias is not None): + raise NotImplementedError( + "Bias is not supported when using expert parallelism.") orig_shape = hidden_states.shape hidden_size = hidden_states.shape[-1] num_tokens = hidden_states.size // hidden_size @@ -322,41 +376,50 @@ def jax_fused_moe_func( x = hidden_states[token_indices_sorted] if use_ep: - x = expert_sharded_gmm(x, - w1, - group_sizes, - transpose_rhs=True, - mesh=mesh, - num_experts=global_num_experts, - ep_size=ep_size) + x = expert_sharded_gmm( + x, + w1, + group_sizes, + transpose_rhs=True, + mesh=mesh, + num_experts=global_num_experts, + ep_size=ep_size, + ) x1, x2 = x[..., :intermediate_size], x[..., intermediate_size:] else: x1, x2 = tensor_sharded_gmm_merged_column_parallel( x, w1, + w1_bias, group_sizes, transpose_rhs=True, mesh=mesh, - intermediate_size=intermediate_size) + intermediate_size=intermediate_size, + ) - x = jax.nn.silu(x1) * x2 + x = activation_fn(activation, x1, x2) if use_ep: - x = expert_sharded_gmm(x, - w2, - group_sizes, - transpose_rhs=True, - mesh=mesh, - num_experts=global_num_experts, - ep_size=ep_size) + x = expert_sharded_gmm( + x, + w2, + group_sizes, + transpose_rhs=True, + mesh=mesh, + num_experts=global_num_experts, + ep_size=ep_size, + ) else: x = jax.lax.with_sharding_constraint( x, NamedSharding(mesh, P(None, "model"))) - x = tensor_sharded_gmm_row_parallel(x, - w2, - group_sizes, - transpose_rhs=True, - mesh=mesh) + x = tensor_sharded_gmm_row_parallel( + x, + w2, + w2_bias, + group_sizes, + transpose_rhs=True, + mesh=mesh, + ) x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size) x = 
x * jnp.expand_dims(topk_weights, axis=-1) @@ -368,11 +431,33 @@ def jax_fused_moe_func( return x -def jax_fused_moe_func_padded(hidden_states: jax.Array, w1: jax.Array, - w2: jax.Array, gating_output: jax.Array, - topk: int, global_num_experts: int, - renormalize: bool, reduce_results: bool, - mesh: Mesh, use_ep: bool): +@functools.partial( + jax.jit, + static_argnames=( + "topk", + "global_num_experts", + "renormalize", + "reduce_results", + "mesh", + "use_ep", + "activation", + ), +) +def fused_moe_func_padded( + hidden_states: jax.Array, + w1: jax.Array, + w2: jax.Array, + w1_bias: jax.Array | None, + w2_bias: jax.Array | None, + gating_output: jax.Array, + topk: int, + global_num_experts: int, + renormalize: bool, + reduce_results: bool, + mesh: Mesh, + use_ep: bool, + activation: str, +): # TODO(fanhongmin@google.com): Once the jax runner pads the input, we no longer need this. hidden_size = hidden_states.shape[-1] num_tokens = hidden_states.size // hidden_size @@ -387,13 +472,36 @@ def jax_fused_moe_func_padded(hidden_states: jax.Array, w1: jax.Array, reps = (n_repeats, ) + (1, ) * (gating_output.ndim - 1) expanded_gating_output = jnp.tile(gating_output, reps) - expanded_x = jax_fused_moe_func(expanded_hidden_states, w1, w2, - expanded_gating_output, topk, - global_num_experts, renormalize, - reduce_results, mesh, use_ep) + expanded_x = fused_moe_func( + expanded_hidden_states, + w1, + w2, + w1_bias, + w2_bias, + expanded_gating_output, + topk, + global_num_experts, + renormalize, + reduce_results, + mesh, + use_ep, + activation, + ) x = expanded_x[:hidden_states.shape[0]] return x else: - return jax_fused_moe_func(hidden_states, w1, w2, gating_output, topk, - global_num_experts, renormalize, - reduce_results, mesh, use_ep) + return fused_moe_func( + hidden_states, + w1, + w2, + w1_bias, + w2_bias, + gating_output, + topk, + global_num_experts, + renormalize, + reduce_results, + mesh, + use_ep, + activation, + ) diff --git a/tpu_inference/layers/vllm/quantization/unquantized.py b/tpu_inference/layers/vllm/quantization/unquantized.py index 984c541c0..7881332f7 100644 --- a/tpu_inference/layers/vllm/quantization/unquantized.py +++ b/tpu_inference/layers/vllm/quantization/unquantized.py @@ -1,4 +1,3 @@ -import functools import os from typing import Any, Callable, Optional, Union @@ -24,7 +23,7 @@ QuantizationConfig, QuantizeMethodBase) from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe -from tpu_inference.layers.vllm.fused_moe import jax_fused_moe_func_padded +from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded from tpu_inference.layers.vllm.linear_common import ( reorder_concatenated_tensor_for_sharding, slice_sharded_tensor_for_concatenation, torch_to_jax_param) @@ -191,8 +190,26 @@ def select_gemm_impl( def process_weights_after_loading(self, layer: torch.nn.Module) -> None: assert isinstance(layer, FusedMoE) - w2_weight = t2j(layer.w2_weight, use_dlpack=False) w13_weight = t2j(layer.w13_weight, use_dlpack=False) + w2_weight = t2j(layer.w2_weight, use_dlpack=False) + + if self.moe.has_bias: + w13_bias = t2j(layer.w13_bias, use_dlpack=False) + w2_bias = t2j(layer.w2_bias, use_dlpack=False) + + if layer.activation == "swigluoai": + # When using swigluoai, vLLM splits gmm output in a interleaved way. + # However, interleaved split is not performant on TPU. Therefore, + # we preprocess the weight so that splitting gmm output by middle + # can still get the same result. 
+ w1_weight = w13_weight[:, ::2, :] + w3_weight = w13_weight[:, 1::2, :] + w13_weight = jnp.concat([w1_weight, w3_weight], axis=1) + + if self.moe.has_bias: + w1_bias = w13_bias[:, ::2] + w3_bias = w13_bias[:, 1::2] + w13_bias = jnp.concat([w1_bias, w3_bias], axis=1) if self.use_kernel and layer.use_ep: # Kernel expects: @@ -208,25 +225,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size) w13_reshaped = w13_weight.reshape(num_experts, 2, intermediate_size, hidden_size) - w13_weight = jnp.transpose(w13_reshaped, (0, 1, 3, 2)) + w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2)) # Transpose w2_weight to (num_experts, intermediate_size, hidden_size) w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1)) # Apply EP sharding w13_weight = jax.device_put( - w13_weight, + w13_weight_transposed, Format(Layout((0, 1, 2, 3)), NamedSharding(self.mesh, P("model", None, None, None)))) - w2_weight_transposed = jax.device_put( + w2_weight = jax.device_put( w2_weight_transposed, Format(Layout((0, 1, 2)), NamedSharding(self.mesh, P("model", None, None)))) - layer.w13_weight = Parameter(torch_view(w13_weight), - requires_grad=False) - layer.w2_weight = Parameter(torch_view(w2_weight_transposed), - requires_grad=False) + if self.moe.has_bias: + w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size) + + # Apply EP sharding + w13_bias = jax.device_put( + w13_bias, + Format(Layout((0, 1, 2)), + NamedSharding(self.mesh, P("model", None, None)))) + w2_bias = jax.device_put( + w2_bias, + Format(Layout((0, 1)), + NamedSharding(self.mesh, P("model", None)))) + else: # Original logic for non-kernel path if layer.use_ep: @@ -238,6 +264,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w2_weight, Format(Layout((0, 1, 2)), NamedSharding(self.mesh, P("model", None, None)))) + + if self.moe.has_bias: + w13_bias = jax.device_put( + w13_bias, + Format(Layout((0, 1)), + NamedSharding(self.mesh, P("model", None)))) + w2_bias = jax.device_put( + w2_bias, + Format(Layout((0, 1)), + NamedSharding(self.mesh, P("model", None)))) + else: intermediate_size = w13_weight.shape[1] // 2 assert intermediate_size == w2_weight.shape[-1] @@ -255,10 +292,26 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: Format(Layout((0, 1, 2)), NamedSharding(self.mesh, P(None, None, "model")))) - layer.w13_weight = Parameter(torch_view(w13_weight), - requires_grad=False) - layer.w2_weight = Parameter(torch_view(w2_weight), - requires_grad=False) + if self.moe.has_bias: + w13_bias = reorder_concatenated_tensor_for_sharding( + w13_bias, output_sizes, n_shards, dim=1) + w13_bias = jax.device_put( + w13_bias, + Format(Layout((0, 1)), + NamedSharding(self.mesh, P(None, "model")))) + w2_bias = jax.device_put( + w2_bias, + Format(Layout((0, 1)), + NamedSharding(self.mesh, P(None, None)))) + + layer.w13_weight = Parameter(torch_view(w13_weight), + requires_grad=False) + layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False) + + if self.moe.has_bias: + layer.w13_bias = Parameter(torch_view(w13_bias), + requires_grad=False) + layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False) def apply( self, @@ -284,9 +337,6 @@ def apply( logical_replica_count: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert isinstance(layer, FusedMoE) - if activation != "silu": - raise NotImplementedError( - "Only silu is 
supported for activation function.")
         if scoring_func != "softmax":
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
@@ -304,24 +354,20 @@ def apply(
             )
         else:
             # Use the original implementation
-            _fused_moe_func = functools.partial(
-                jax.jit(jax_fused_moe_func_padded,
-                        static_argnames=[
-                            "topk", "global_num_experts", "renormalize",
-                            "reduce_results", "mesh", "use_ep"
-                        ]),
+            output = fused_moe_func_padded(
+                jax_view(x),
+                jax_view(layer.w13_weight),
+                jax_view(layer.w2_weight),
+                jax_view(layer.w13_bias) if self.moe.has_bias else None,
+                jax_view(layer.w2_bias) if self.moe.has_bias else None,
+                jax_view(router_logits),
                 topk=top_k,
                 global_num_experts=global_num_experts,
                 renormalize=renormalize,
                 reduce_results=layer.reduce_results,
                 mesh=self.mesh,
-                use_ep=layer.use_ep)
-
-            output = _fused_moe_func(
-                jax_view(x),
-                jax_view(layer.w13_weight),
-                jax_view(layer.w2_weight),
-                jax_view(router_logits),
+                use_ep=layer.use_ep,
+                activation=activation,
             )
 
         return torch_view(output)
diff --git a/tpu_inference/models/vllm/vllm_model_wrapper.py b/tpu_inference/models/vllm/vllm_model_wrapper.py
index 2a2513689..afe8552f0 100644
--- a/tpu_inference/models/vllm/vllm_model_wrapper.py
+++ b/tpu_inference/models/vllm/vllm_model_wrapper.py
@@ -86,6 +86,11 @@ def load_weights(self):
         assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
         vllm_config_for_load.device_config.device = "cpu"
 
+        # When expert parallelism is enabled, vLLM loads weights in a
+        # sharding-aware manner. Since tpu-inference has its own sharding
+        # logic, this may cause errors, so we disable it during weight loading.
+        vllm_config_for_load.parallel_config.enable_expert_parallel = False
+
         if os.getenv("JAX_RANDOM_WEIGHTS", False):
             vllm_config_for_load.load_config.load_format = "dummy"
            use_random_weights = True
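
Reviewer note (not part of the patch): below is a minimal sketch, using a toy single-expert matmul in place of the gmm kernel, to sanity-check that the swigluoai weight de-interleaving done in process_weights_after_loading matches vLLM's interleaved split of the gmm output. The _swigluoai_ref helper mirrors the patch's _swigluoai but uses positional jnp.clip arguments for portability; all names and shapes here are illustrative assumptions.

import jax
import jax.numpy as jnp


def _swigluoai_ref(x1, x2, alpha=1.702, limit=7.0):
    # Mirrors the patch's _swigluoai: clamp the gate from above, clamp the
    # linear term on both sides, then apply the sigmoid-gated product.
    x1 = jnp.minimum(x1, limit)
    x2 = jnp.clip(x2, -limit, limit)
    return x1 * jax.nn.sigmoid(alpha * x1) * (x2 + 1)


key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
hidden_size, intermediate_size = 16, 8
x = jax.random.normal(k1, (4, hidden_size))
# Rows of w13 interleave w1 and w3 channels, as laid out by vLLM's loader.
w13 = jax.random.normal(k2, (2 * intermediate_size, hidden_size))

# Reference: interleaved split of the (toy) gmm output.
y = x @ w13.T
ref = _swigluoai_ref(y[:, ::2], y[:, 1::2])

# Patch's approach: de-interleave the weight once, then split down the middle.
w13_deinterleaved = jnp.concatenate([w13[::2, :], w13[1::2, :]], axis=0)
y2 = x @ w13_deinterleaved.T
out = _swigluoai_ref(y2[:, :intermediate_size], y2[:, intermediate_size:])

assert jnp.allclose(ref, out, atol=1e-6)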