
Commit 622d07a

[Torchax] Support bias and swiglu in MoE
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 6eef16d commit 622d07a
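
For reference, the tests added in this commit drive FusedMoE with per-expert biases (w13_bias, w2_bias) and with a selectable activation ("silu" or "swigluoai"). As a rough mental model only, a naive single-token reference for such a biased MoE with a SwiGLU-style activation could be sketched as below; naive_moe_forward is a hypothetical helper written purely for illustration (it is not part of this commit), it assumes the usual gate/up ordering inside the fused w1 projection, and the actual "swigluoai" variant applies additional clamping and scaling not shown here.

import torch
import torch.nn.functional as F


def naive_moe_forward(a, w1, w2, w1_bias, w2_bias, score, topk):
    # Hypothetical reference (not part of this commit) for an MoE layer with
    # per-expert biases and a SwiGLU-style activation.
    #   a:       (num_tokens, hidden_size)
    #   w1:      (num_experts, 2 * intermediate_size, hidden_size), gate/up fused
    #   w2:      (num_experts, hidden_size, intermediate_size)
    #   w1_bias: (num_experts, 2 * intermediate_size)
    #   w2_bias: (num_experts, hidden_size)
    #   score:   (num_tokens, num_experts) router logits
    probs = torch.softmax(score.float(), dim=-1)
    weights, experts = torch.topk(probs, topk, dim=-1)
    out = torch.zeros_like(a, dtype=torch.float32)
    for t in range(a.shape[0]):
        for k in range(topk):
            e = experts[t, k]
            # First projection plus per-expert bias, then split into gate/up halves.
            h = a[t].float() @ w1[e].float().T + w1_bias[e].float()
            gate, up = h.chunk(2, dim=-1)
            # Plain SwiGLU: silu(gate) * up. The "swigluoai" variant layers extra
            # clamping/scaling on top of this (omitted here).
            h = F.silu(gate) * up
            # Second projection plus bias, weighted by the routing probability.
            out[t] += weights[t, k] * (h @ w2[e].float().T + w2_bias[e].float())
    return out.to(a.dtype)

In the tests below, vLLM's FusedMoE performs this routing and projection internally; the bias test only needs to attach w13_bias / w2_bias and construct the layer with has_bias=True, while the activation test switches the layer's activation between "silu" and "swigluoai".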

File tree

4 files changed: +397 -105 lines changed


tests/layers/vllm/test_unquantized.py

Lines changed: 128 additions & 0 deletions
@@ -485,6 +485,134 @@ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
     )


+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+@pytest.mark.parametrize("num_tokens", [8])
+@pytest.mark.parametrize("intermediate_size", [1024])
+@pytest.mark.parametrize("hidden_size", [128])
+@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("topk", [2])
+def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
+                        num_experts, topk):
+    os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
+    torch.manual_seed(42)
+    dtype = torch.bfloat16
+
+    a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+    w1 = torch.randn(
+        (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+    w2 = torch.randn(
+        (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+    w1_bias = torch.randn(
+        (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+    w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+    score = torch.randn((num_tokens, num_experts), dtype=dtype)
+
+    engine_args = EngineArgs(
+        model="Qwen/Qwen2-1.5B-Instruct",
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = dtype
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        vllm_fused_moe = FusedMoE(
+            num_experts=num_experts,
+            top_k=topk,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            reduce_results=False,
+            renormalize=False,
+            tp_size=1,
+            dp_size=1,
+            quant_config=quant_config,
+            has_bias=True,
+        )
+    vllm_fused_moe.w13_weight.data = w1
+    vllm_fused_moe.w2_weight.data = w2
+    vllm_fused_moe.w13_bias.data = w1_bias
+    vllm_fused_moe.w2_bias.data = w2_bias
+
+    jax_a = torch_view(t2j(a, use_dlpack=False))
+    jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
+    score = torch_view(t2j(score))
+    score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
+
+    with torchax.default_env(), set_forward_context(None, vllm_config):
+        assert isinstance(vllm_fused_moe.quant_method,
+                          VllmUnquantizedFusedMoEMethod)
+        vllm_fused_moe.quant_method.process_weights_after_loading(
+            vllm_fused_moe)
+        vllm_fused_moe(jax_a, score)
+
+
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+@pytest.mark.parametrize("num_tokens", [8])
+@pytest.mark.parametrize("intermediate_size", [1024])
+@pytest.mark.parametrize("hidden_size", [128])
+@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("topk", [2])
+@pytest.mark.parametrize("activation", ["silu", "swigluoai"])
+def test_fused_moe_activation(mesh, num_tokens, intermediate_size, hidden_size,
+                              num_experts, topk, activation):
+    os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
+    torch.manual_seed(42)
+    dtype = torch.bfloat16
+
+    a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+    w1 = torch.randn(
+        (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+    w2 = torch.randn(
+        (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+    score = torch.randn((num_tokens, num_experts), dtype=dtype)
+
+    engine_args = EngineArgs(
+        model="Qwen/Qwen2-1.5B-Instruct",
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = dtype
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        vllm_fused_moe = FusedMoE(
+            num_experts=num_experts,
+            top_k=topk,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            reduce_results=False,
+            renormalize=False,
+            tp_size=1,
+            dp_size=1,
+            quant_config=quant_config,
+            activation=activation,
+        )
+    vllm_fused_moe.w13_weight.data = w1
+    vllm_fused_moe.w2_weight.data = w2
+
+    jax_a = torch_view(t2j(a, use_dlpack=False))
+    jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
+    score = torch_view(t2j(score))
+    score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
+
+    with torchax.default_env(), set_forward_context(None, vllm_config):
+        assert isinstance(vllm_fused_moe.quant_method,
+                          VllmUnquantizedFusedMoEMethod)
+        vllm_fused_moe.quant_method.process_weights_after_loading(
+            vllm_fused_moe)
+        vllm_fused_moe(jax_a, score)
+
+
 @pytest.mark.parametrize("use_ep", [True])
 @pytest.mark.parametrize("mesh",
                          [test_utils.get_spmd_mesh(jax.local_device_count())])
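
Assuming a standard pytest setup for this repository, the new cases can be selected by name from the modified file, for example: pytest tests/layers/vllm/test_unquantized.py -k "fused_moe_bias or fused_moe_activation".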
