
Commit 5c04ffa

initial commit on compressed-tensors quantization support for fp8 (#1011)
Signed-off-by: Han Qi <hanq@google.com>
1 parent 028298c commit 5c04ffa
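
For context, a minimal sketch of how the entry point touched by this commit would be reached, using only identifiers visible in the diffs below (EngineArgs, get_tpu_quantization_config). The mesh construction and engine settings here are illustrative assumptions, not taken from the commit (the tests use a test_utils.get_spmd_mesh helper instead):

import jax
import numpy as np
from jax.sharding import Mesh
from vllm.engine.arg_utils import EngineArgs

from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config

# Engine config for an fp8 compressed-tensors checkpoint (model name as in the test below).
engine_args = EngineArgs(model='BCCard/Qwen3-30B-A3B-FP8-Dynamic', max_model_len=64)
vllm_config = engine_args.create_engine_config()

# Illustrative 1D SPMD mesh over the local devices; axis name is an assumption.
mesh = Mesh(np.array(jax.devices()), axis_names=('model',))

# Resolves the checkpoint's 'compressed-tensors' method to VllmCompressedTensorsConfig.
quant_config = get_tpu_quantization_config(vllm_config, mesh)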

4 files changed: +398 -12 lines

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
+import os
+import tempfile
+
+import jax
+import jax.numpy as jnp
+import pytest
+import torch
+import torch.nn.functional as F
+import torchax
+import utils as test_utils
+from compressed_tensors.quantization import QuantizationArgs
+from jax.sharding import PartitionSpec
+from vllm.config import set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.layers.fused_moe import FusedMoE
+# yapf: disable
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig, FusedMoEParallelConfig)
+
+from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+    VllmCompressedTensorsConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    VllmCompressedTensorsW8A8Fp8MoEMethod
+
+# yapf: enable
+
+P = PartitionSpec
+
+os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
+
+MODEL = 'BCCard/Qwen3-30B-A3B-FP8-Dynamic'
+
+
+@pytest.fixture(autouse=True)
+def setup_environment():
+    # This is a fake config used to init the dist env.
+    # RowParallelLinear needs the dist env to be initialized.
+    engine_args = EngineArgs(
+        model=MODEL,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+
+    vllm_config = engine_args.create_engine_config()
+
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        ensure_model_parallel_initialized(1, 1)
+
+
+def _ref_math_in_bf16(w1, w2, w3, x, router_logits, top_k):
+    seqlen = x.shape[0]
+    expert_weights = F.softmax(router_logits, dim=-1)
+    expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1)
+    expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
+
+    # cond ffn
+    # e = total num of experts
+    # t = seqlen
+    # o = config.intermediate_size
+    # i = config.dim
+    x1 = torch.einsum("ti, eoi -> teo", x, w1)
+    x1 = F.silu(x1)
+    x3 = torch.einsum("ti, eoi -> teo", x, w3)
+    expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), w2)
+
+    seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
+    expert_outs = expert_outs[seq_indexes, expert_indices]
+    out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
+    return out
+
+
+def test_fused_moe_method():
+    mesh = test_utils.get_spmd_mesh(jax.local_device_count())
+
+    engine_args = EngineArgs(
+        model=MODEL,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+
+    num_experts = 8
+    top_k = 2
+    hidden_size = 128
+    intermediate_size = hidden_size * 2
+
+    with set_current_vllm_config(vllm_config):
+        layer = FusedMoE(num_experts=num_experts,
+                         top_k=top_k,
+                         hidden_size=hidden_size,
+                         intermediate_size=intermediate_size)
+    quant_config = VllmCompressedTensorsConfig(
+        target_scheme_map={
+            'Linear': {
+                'weights':
+                QuantizationArgs(num_bits=8,
+                                 type='float',
+                                 symmetric=True,
+                                 group_size=None,
+                                 strategy='channel',
+                                 block_structure=None,
+                                 dynamic=False,
+                                 actorder=None,
+                                 observer='minmax',
+                                 observer_kwargs={}),
+                'input_activations':
+                QuantizationArgs(num_bits=8,
+                                 type='float',
+                                 symmetric=True,
+                                 group_size=None,
+                                 strategy='token',
+                                 block_structure=None,
+                                 dynamic=True,
+                                 actorder=None,
+                                 observer=None,
+                                 observer_kwargs={}),
+                'format':
+                None
+            }
+        },
+        ignore=[],
+        quant_format='compressed-tensors',
+        sparsity_scheme_map={},
+        sparsity_ignore_list=[],
+    )
+    moe = FusedMoEConfig(
+        num_experts=8,
+        experts_per_token=2,
+        hidden_dim=hidden_size,
+        num_local_experts=8,
+        moe_parallel_config=FusedMoEParallelConfig(
+            tp_size=1,
+            dp_size=1,
+            ep_size=1,
+            tp_rank=0,
+            dp_rank=0,
+            ep_rank=0,
+            use_ep=False,
+            all2all_backend='',
+        ),
+        in_dtype=torch.bfloat16,
+    )
+    method = VllmCompressedTensorsW8A8Fp8MoEMethod(quant_config, moe, mesh)
+    method.create_weights(layer,
+                          num_experts,
+                          hidden_size,
+                          intermediate_size,
+                          params_dtype=torch.float8_e4m3fn)
+    method.process_weights_after_loading(layer)
+
+    seqlen = 10
+    with torchax.default_env():
+        x = torch.ones((seqlen, hidden_size), dtype=torch.bfloat16).to('jax')
+        router_logits = torch.randn((seqlen, num_experts),
+                                    dtype=torch.bfloat16).to('jax')
+        result = method.apply(layer,
+                              x,
+                              router_logits,
+                              top_k=2,
+                              renormalize=True)
+
+        result_reference = _ref_math_in_bf16(
+            layer.w13_weight.to(torch.bfloat16) * layer.w13_weight_scale,
+            layer.w2_weight.to(torch.bfloat16) * layer.w2_weight_scale,
+            layer.w3_weight.to(torch.bfloat16) * layer.w3_weight_scale, x,
+            router_logits, top_k)
+
+        assert jnp.allclose(result.jax(), result_reference.jax())
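
The _ref_math_in_bf16 helper above mirrors the routing math the fused MoE method is expected to reproduce. As a standalone illustration (shapes chosen arbitrarily, not part of the commit), the softmax / top-k / renormalize step it relies on, and which renormalize=True requests from method.apply, looks like this:

import torch
import torch.nn.functional as F

seqlen, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(seqlen, num_experts)

# Softmax over all experts, then keep only the top_k weights per token.
expert_weights = F.softmax(router_logits, dim=-1)
expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1)

# Renormalize so the selected weights sum to 1 per token.
expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)

assert torch.allclose(expert_weights.sum(dim=-1), torch.ones(seqlen))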

tpu_inference/layers/vllm/quantization/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -22,9 +22,10 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
         "compressed-tensors": VllmCompressedTensorsConfig,
         "awq": VllmAWQConfig,
     }
-
     if model_config.quantization not in method_to_config:
-        raise NotImplementedError
+        raise NotImplementedError(
+            f"{model_config.quantization} quantization method not supported."
+            f" Supported methods are {method_to_config.keys()}")
     quant_config = method_to_config[model_config.quantization]
     assert issubclass(quant_config, JaxCommonConfig)
     quant_config.set_configs(vllm_config, mesh)
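
A sketch of the sharpened error path above; the mapping below is a stand-in that mirrors the two entries visible in this hunk, not the real dict inside get_tpu_quantization_config:

method_to_config = {
    "compressed-tensors": object,  # stands in for VllmCompressedTensorsConfig
    "awq": object,  # stands in for VllmAWQConfig
}

requested = "gptq"  # hypothetical unsupported quantization method
if requested not in method_to_config:
    raise NotImplementedError(
        f"{requested} quantization method not supported."
        f" Supported methods are {method_to_config.keys()}")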

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 7 additions & 10 deletions
@@ -14,10 +14,11 @@
     CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
     CompressedTensorsLinearMethod, CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    find_matched_target, is_activation_quantization_format,
-    should_ignore_layer)
+    find_matched_target, should_ignore_layer)
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    VllmCompressedTensorsW8A8Fp8MoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -60,12 +61,12 @@ def get_scheme(self,
             layer_name=layer_name,
             module=layer,
             targets=self.target_scheme_map.keys(),
-            fused_mapping=self.packed_modules_mapping)
+            fused_mapping=self.packed_modules_mapping,
+        )
 
         scheme_dict = self.target_scheme_map[matched_target]
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
-        format = scheme_dict.get("format")
 
         if weight_quant is None:
             logger.warning_once("Acceleration for non-quantized schemes is "
@@ -74,10 +75,6 @@ def get_scheme(self,
             return None
 
         # TODO(kyuyeunk): Add support for different act_quant_format
-        act_quant_format = is_activation_quantization_format(  # noqa: F841
-            format
-        ) if format is not None else is_activation_quantization_format(
-            self.quant_format)
 
         linear_config = self.get_linear_config(layer)
         if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -114,8 +111,8 @@ def get_quant_method(
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            raise NotImplementedError(
-                "FusedMoE quantization is currently not supported.")
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                self, layer.quant_config, self.mesh)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
