tests/layers/vllm/test_unquantized.py (128 additions & 0 deletions)
@@ -485,6 +485,134 @@ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
)


@pytest.mark.parametrize("mesh", [
test_utils.get_spmd_mesh(1),
test_utils.get_spmd_mesh(jax.local_device_count())
])
@pytest.mark.parametrize("num_tokens", [8])
@pytest.mark.parametrize("intermediate_size", [1024])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("topk", [2])
def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
num_experts, topk):
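    """FusedMoE forward pass with per-expert biases (has_bias=True) on an SPMD mesh."""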
os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
torch.manual_seed(42)
dtype = torch.bfloat16

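    # Random activations and expert weights; w13 fuses the gate and up
    # projections, hence the 2 * intermediate_size leading dimension.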
a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
w1 = torch.randn(
(num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
w2 = torch.randn(
(num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
w1_bias = torch.randn(
(num_experts, 2 * intermediate_size), dtype=dtype) / 10
w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
score = torch.randn((num_tokens, num_experts), dtype=dtype)

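    # Minimal engine args: the Qwen2 checkpoint is only used to build a
    # vLLM config here; create_engine_config does not load model weights.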
engine_args = EngineArgs(
model="Qwen/Qwen2-1.5B-Instruct",
max_model_len=64,
max_num_batched_tokens=64,
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.model_config.dtype = dtype

quant_config = get_tpu_quantization_config(vllm_config, mesh)
with set_current_vllm_config(vllm_config):
vllm_fused_moe = FusedMoE(
num_experts=num_experts,
top_k=topk,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=False,
tp_size=1,
dp_size=1,
quant_config=quant_config,
has_bias=True,
)
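        # Replace the randomly initialized parameters with the test tensors.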
vllm_fused_moe.w13_weight.data = w1
vllm_fused_moe.w2_weight.data = w2
vllm_fused_moe.w13_bias.data = w1_bias
vllm_fused_moe.w2_bias.data = w2_bias

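    # Convert to JAX via torchax and replicate across the mesh (P(None, None)).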
jax_a = torch_view(t2j(a, use_dlpack=False))
jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
score = torch_view(t2j(score))
score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))

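    # Smoke test: finalize the weights, then run one forward pass; there is no
    # output check, only that the bias path executes under the forward context.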
with torchax.default_env(), set_forward_context(None, vllm_config):
assert isinstance(vllm_fused_moe.quant_method,
VllmUnquantizedFusedMoEMethod)
vllm_fused_moe.quant_method.process_weights_after_loading(
vllm_fused_moe)
vllm_fused_moe(jax_a, score)


@pytest.mark.parametrize("mesh", [
test_utils.get_spmd_mesh(1),
test_utils.get_spmd_mesh(jax.local_device_count())
])
@pytest.mark.parametrize("num_tokens", [8])
@pytest.mark.parametrize("intermediate_size", [1024])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("activation", ["silu", "swigluoai"])
def test_fused_moe_activation(mesh, num_tokens, intermediate_size, hidden_size,
num_experts, topk, activation):
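    """FusedMoE forward pass with a parametrized activation ("silu" vs. "swigluoai")."""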
os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
torch.manual_seed(42)
dtype = torch.bfloat16

a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
w1 = torch.randn(
(num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
w2 = torch.randn(
(num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
score = torch.randn((num_tokens, num_experts), dtype=dtype)

engine_args = EngineArgs(
model="Qwen/Qwen2-1.5B-Instruct",
max_model_len=64,
max_num_batched_tokens=64,
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.model_config.dtype = dtype

quant_config = get_tpu_quantization_config(vllm_config, mesh)
with set_current_vllm_config(vllm_config):
vllm_fused_moe = FusedMoE(
num_experts=num_experts,
top_k=topk,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=False,
tp_size=1,
dp_size=1,
quant_config=quant_config,
activation=activation,
)
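        # Same flow as the bias test above; only the activation function differs.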
vllm_fused_moe.w13_weight.data = w1
vllm_fused_moe.w2_weight.data = w2

jax_a = torch_view(t2j(a, use_dlpack=False))
jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
score = torch_view(t2j(score))
score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))

with torchax.default_env(), set_forward_context(None, vllm_config):
assert isinstance(vllm_fused_moe.quant_method,
VllmUnquantizedFusedMoEMethod)
vllm_fused_moe.quant_method.process_weights_after_loading(
vllm_fused_moe)
vllm_fused_moe(jax_a, score)


@pytest.mark.parametrize("use_ep", [True])
@pytest.mark.parametrize("mesh",
[test_utils.get_spmd_mesh(jax.local_device_count())])