
Commit 6c68136

Add unit test

Signed-off-by: Han Qi <hanq@google.com>

1 parent a73e34a commit 6c68136

File tree: 2 files changed, +192 −0 lines changed
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
import os
import tempfile

import jax
import jax.numpy as jnp
import pytest
import torch
import torch.nn.functional as F
import torchax
import utils as test_utils
from compressed_tensors.quantization import QuantizationArgs
from jax.sharding import PartitionSpec
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.fused_moe import FusedMoE
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig, FusedMoEParallelConfig)

from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
    VllmCompressedTensorsConfig
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
    VllmCompressedTensorsW8A8Fp8MoEMethod

# yapf: enable

P = PartitionSpec

os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'

MODEL = 'BCCard/Qwen3-30B-A3B-FP8-Dynamic'


@pytest.fixture(autouse=True)
def setup_environment():
    # This is a fake config used to init the dist env.
    # RowParallelLinear needs the dist env to be initialized.
    engine_args = EngineArgs(
        model=MODEL,
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )

    vllm_config = engine_args.create_engine_config()

    with set_current_vllm_config(vllm_config):
        temp_file = tempfile.mkstemp()[1]
        init_distributed_environment(
            1,
            0,
            local_rank=0,
            distributed_init_method=f"file://{temp_file}",
            backend="gloo")
        ensure_model_parallel_initialized(1, 1)

def _ref_math_in_bf16(w1, w2, w3, x, router_logits, top_k):
    seqlen = x.shape[0]
    expert_weights = F.softmax(router_logits, dim=-1)
    expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1)
    expert_weights /= expert_weights.sum(dim=-1, keepdim=True)

    # cond ffn
    # e = total num of experts
    # t = seqlen
    # o = config.intermediate_size
    # i = config.dim
    x1 = torch.einsum("ti, eoi -> teo", x, w1)
    x1 = F.silu(x1)
    x3 = torch.einsum("ti, eoi -> teo", x, w3)
    expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), w2)

    # Gather each token's top-k expert outputs and combine them with the
    # renormalized routing weights.
    seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
    expert_outs = expert_outs[seq_indexes, expert_indices]
    out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
    return out

def test_fused_moe_method():
    mesh = test_utils.get_spmd_mesh(jax.local_device_count())

    engine_args = EngineArgs(
        model=MODEL,
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )
    vllm_config = engine_args.create_engine_config()
    vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False

    # Call tpu_inference code
    vllm_config.model_config.dtype = torch.bfloat16
    quant_config = get_tpu_quantization_config(vllm_config, mesh)

    num_experts = 8
    top_k = 2
    hidden_size = 128
    intermediate_size = hidden_size * 2

    with set_current_vllm_config(vllm_config):
        layer = FusedMoE(num_experts=num_experts,
                         top_k=top_k,
                         hidden_size=hidden_size,
                         intermediate_size=intermediate_size)
        quant_config = VllmCompressedTensorsConfig(
            target_scheme_map={
                'Linear': {
                    'weights':
                    QuantizationArgs(num_bits=8,
                                     type='float',
                                     symmetric=True,
                                     group_size=None,
                                     strategy='channel',
                                     block_structure=None,
                                     dynamic=False,
                                     actorder=None,
                                     observer='minmax',
                                     observer_kwargs={}),
                    'input_activations':
                    QuantizationArgs(num_bits=8,
                                     type='float',
                                     symmetric=True,
                                     group_size=None,
                                     strategy='token',
                                     block_structure=None,
                                     dynamic=True,
                                     actorder=None,
                                     observer=None,
                                     observer_kwargs={}),
                    'format':
                    None
                }
            },
            ignore=[],
            quant_format='compressed-tensors',
            sparsity_scheme_map={},
            sparsity_ignore_list=[],
        )
        moe = FusedMoEConfig(
            num_experts=8,
            experts_per_token=2,
            hidden_dim=hidden_size,
            num_local_experts=8,
            moe_parallel_config=FusedMoEParallelConfig(
                tp_size=1,
                dp_size=1,
                ep_size=1,
                tp_rank=0,
                dp_rank=0,
                ep_rank=0,
                use_ep=False,
                all2all_backend='',
            ),
            in_dtype=torch.bfloat16,
        )
        method = VllmCompressedTensorsW8A8Fp8MoEMethod(quant_config, moe, mesh)
        method.create_weights(layer,
                              num_experts,
                              hidden_size,
                              intermediate_size,
                              params_dtype=torch.float8_e4m3fn)
        method.process_weights_after_loading(layer)

        seqlen = 10
        with torchax.default_env():
            x = torch.ones((seqlen, hidden_size), dtype=torch.bfloat16).to('jax')
            router_logits = torch.randn((seqlen, num_experts),
                                        dtype=torch.bfloat16).to('jax')
            result = method.apply(layer,
                                  x,
                                  router_logits,
                                  top_k=2,
                                  renormalize=True)

            # Dequantize the fp8 weights and rerun the math in bf16.
            result_reference = _ref_math_in_bf16(
                layer.w13_weight.to(torch.bfloat16) * layer.w13_weight_scale,
                layer.w2_weight.to(torch.bfloat16) * layer.w2_weight_scale,
                layer.w3_weight.to(torch.bfloat16) * layer.w3_weight_scale, x,
                router_logits, top_k)

            assert jnp.allclose(result.jax(), result_reference.jax())
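
For intuition, _ref_math_in_bf16 above is standard renormalized top-k routing: softmax the router logits, keep each token's top-k experts, renormalize their weights, run every expert's gated FFN (silu of the w1 projection times the w3 projection, then w2), and combine the selected outputs. A self-contained CPU sketch of the same math, with made-up sizes and no vllm/torchax dependencies:

import torch
import torch.nn.functional as F

t, e, d, o, top_k = 4, 8, 16, 32, 2   # tokens, experts, hidden dim, intermediate dim, top-k
x = torch.randn(t, d)
w1 = torch.randn(e, o, d)             # per-expert gate projection
w3 = torch.randn(e, o, d)             # per-expert up projection
w2 = torch.randn(e, d, o)             # per-expert down projection
router_logits = torch.randn(t, e)

weights = F.softmax(router_logits, dim=-1)
weights, idx = torch.topk(weights, top_k, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)      # renormalize over the chosen experts

x1 = F.silu(torch.einsum("td,eod->teo", x, w1))   # gate
x3 = torch.einsum("td,eod->teo", x, w3)           # up
outs = torch.einsum("teo,edo->ted", x1 * x3, w2)  # down, computed for every expert
outs = outs[torch.arange(t).unsqueeze(1), idx]    # keep each token's top-k experts
out = torch.einsum("tkd,tk->td", outs, weights)   # weighted combine -> (t, d)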

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 7 additions & 0 deletions
@@ -38,6 +38,9 @@ def __init__(self, quant_config: "CompressedTensorsConfig",

         self.mesh = mesh
         self.quant_config = quant_config
+        # import sys
+        # sys.stdin = open(0)
+        # breakpoint()
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
             "weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(

@@ -177,6 +180,10 @@ def apply(
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")

+        # import sys
+        # sys.stdin = open(0)
+        # breakpoint()
+
         # TODO: Use MoE kernel when it supports fp8

         seqlen = x.shape[0]
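
The test's bf16 reference is built by dequantizing the layer's fp8 weights with their per-channel scales (w.to(torch.bfloat16) * w_scale). As a rough sketch of that quantize/dequantize round trip — the shapes and the 448 clamp (the finite max of float8_e4m3fn) are illustrative, not the layer's actual storage layout or scale computation:

import torch

w = torch.randn(8, 32, 16)                          # (experts, out, in)
scale = w.abs().amax(dim=-1, keepdim=True) / 448.0  # per-channel scale
w_fp8 = (w / scale).clamp(-448, 448).to(torch.float8_e4m3fn)   # quantize
w_deq = w_fp8.to(torch.bfloat16) * scale            # dequantize, as in the test
assert torch.allclose(w, w_deq.to(w.dtype), atol=0.1, rtol=0.1)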
