Commit 3a94b7e

initial commit on compressed-tensors quantization support for fp8
1 parent 60c14f5 commit 3a94b7e

5 files changed (+546, -104 lines changed)
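
For context, a minimal usage sketch of what this change enables, assuming a checkpoint that declares fp8 weights. The model name and prompt are placeholders; the only part taken from this commit is that quantization="fp8" now resolves to VllmCompressedTensorsConfig instead of raising NotImplementedError.

# Hypothetical usage sketch (placeholder model name); with this commit,
# quantization="fp8" is dispatched to VllmCompressedTensorsConfig on TPU.
from vllm import LLM, SamplingParams

llm = LLM(
    model="some-org/some-fp8-model",  # placeholder fp8-quantized checkpoint
    quantization="fp8",
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)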

tpu_inference/layers/vllm/quantization/__init__.py

Lines changed: 16 additions & 11 deletions

@@ -2,33 +2,38 @@
 
 from jax.sharding import Mesh
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.quantization.base_config import \
-    QuantizationConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 
 from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
-from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
-    VllmCompressedTensorsConfig  # noqa: E501
-from tpu_inference.layers.vllm.quantization.unquantized import \
-    VllmUnquantizedConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import (
+    VllmCompressedTensorsConfig,
+)  # noqa: E501
+from tpu_inference.layers.vllm.quantization.unquantized import VllmUnquantizedConfig
 
 
-def get_tpu_quantization_config(vllm_config: VllmConfig,
-                                mesh: Mesh) -> QuantizationConfig:
+def get_tpu_quantization_config(
+    vllm_config: VllmConfig, mesh: Mesh
+) -> QuantizationConfig:
     model_config = copy.deepcopy(vllm_config.model_config)
     # TODO(kyuyeunk): Add support for "tpu_int8".
     method_to_config: dict[str, str] = {
         None: VllmUnquantizedConfig,
         "compressed-tensors": VllmCompressedTensorsConfig,
         "awq": VllmAWQConfig,
+        "fp8": VllmCompressedTensorsConfig,
     }
+    # import sys
 
+    # sys.stdin = open(0)
+    # breakpoint()
     if model_config.quantization not in method_to_config:
-        raise NotImplementedError
+        raise NotImplementedError(
+            f"{model_config.quantization} quantization method not supported."
+        )
     quant_config = method_to_config[model_config.quantization]
     assert issubclass(quant_config, JaxCommonConfig)
     quant_config.set_configs(vllm_config, mesh)
 
     model_config.quantization = quant_config.get_name()
-    return VllmConfig.get_quantization_config(model_config,
-                                              vllm_config.load_config)
+    return VllmConfig.get_quantization_config(model_config, vllm_config.load_config)
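
To make the dispatch above easier to follow, here is a small self-contained sketch of the same method_to_config lookup. The config classes are stand-ins rather than the real vLLM/tpu_inference classes; only the "fp8" entry and the error message mirror the diff.

# Minimal sketch of the method_to_config dispatch pattern used above.
# The config classes here are stand-ins, not the real vLLM/tpu_inference classes.
from typing import Optional


class UnquantizedConfig:
    name = "unquantized"


class CompressedTensorsConfig:
    name = "jax-compressed-tensors"


METHOD_TO_CONFIG: dict[Optional[str], type] = {
    None: UnquantizedConfig,
    "compressed-tensors": CompressedTensorsConfig,
    # With this commit, checkpoints that report "fp8" reuse the
    # compressed-tensors path rather than raising NotImplementedError.
    "fp8": CompressedTensorsConfig,
}


def resolve(method: Optional[str]) -> type:
    if method not in METHOD_TO_CONFIG:
        raise NotImplementedError(f"{method} quantization method not supported.")
    return METHOD_TO_CONFIG[method]


assert resolve("fp8") is CompressedTensorsConfig
assert resolve(None) is UnquantizedConfig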

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 47 additions & 36 deletions

@@ -6,40 +6,45 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase
-from vllm.model_executor.layers.quantization import \
-    register_quantization_config
-from vllm.model_executor.layers.quantization.base_config import \
-    QuantizeMethodBase  # noqa: E501
+from vllm.model_executor.layers.quantization import register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase  # noqa: E501
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
-    CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
-    CompressedTensorsLinearMethod, CompressedTensorsScheme)
+    CompressedTensorsConfig,
+    CompressedTensorsKVCacheMethod,
+    CompressedTensorsLinearMethod,
+    CompressedTensorsScheme,
+)
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import (
+    CompressedTensorsW8A8Fp8MoEMethod,
+)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    find_matched_target, is_activation_quantization_format,
-    should_ignore_layer)
+    find_matched_target,
+    is_activation_quantization_format,
+    should_ignore_layer,
+)
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
-from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
-    VllmCompressedTensorsW8A8Fp8
-from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
-    VllmCompressedTensorsW8A8Int8
-from tpu_inference.layers.vllm.quantization.unquantized import \
-    VllmUnquantizedConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import (
+    VllmCompressedTensorsW8A8Fp8,
+)
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import (
+    VllmCompressedTensorsW8A8Int8,
+)
+from tpu_inference.layers.vllm.quantization.unquantized import VllmUnquantizedConfig
 
 P = PartitionSpec
 logger = init_logger(__name__)
 
 
 @register_quantization_config("jax-compressed-tensors")
 class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
-
     @classmethod
     def get_name(cls) -> str:
         return "jax-compressed-tensors"
 
-    def get_scheme(self,
-                   layer: torch.nn.Module,
-                   layer_name: Optional[str] = None
-                   ) -> Optional["CompressedTensorsScheme"]:
+    def get_scheme(
+        self, layer: torch.nn.Module, layer_name: Optional[str] = None
+    ) -> Optional["CompressedTensorsScheme"]:
         """
         compressed-tensors supports non uniform in the following way:
 
@@ -60,24 +65,30 @@ def get_scheme(self,
             layer_name=layer_name,
             module=layer,
             targets=self.target_scheme_map.keys(),
-            fused_mapping=self.packed_modules_mapping)
+            fused_mapping=self.packed_modules_mapping,
+        )
 
         scheme_dict = self.target_scheme_map[matched_target]
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
         format = scheme_dict.get("format")
 
         if weight_quant is None:
-            logger.warning_once("Acceleration for non-quantized schemes is "
-                                "not supported by Compressed Tensors. "
-                                "Falling back to UnquantizedLinearMethod")
+            logger.warning_once(
+                "Acceleration for non-quantized schemes is "
+                "not supported by Compressed Tensors. "
+                "Falling back to UnquantizedLinearMethod"
+            )
             return None
 
         # TODO(kyuyeunk): Add support for different act_quant_format
-        act_quant_format = is_activation_quantization_format(  # noqa: F841
-            format
-        ) if format is not None else is_activation_quantization_format(
-            self.quant_format)
+        act_quant_format = (
+            is_activation_quantization_format(  # noqa: F841
+                format
+            )
+            if format is not None
+            else is_activation_quantization_format(self.quant_format)
+        )
 
         linear_config = self.get_linear_config(layer)
         if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -94,28 +105,28 @@ def get_scheme(self,
                 input_symmetric=input_quant.symmetric,
                 jax_config=linear_config,
             )
-        raise NotImplementedError(
-            "No compressed-tensors compatible scheme was found.")
+        raise NotImplementedError("No compressed-tensors compatible scheme was found.")
 
     def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional[QuantizeMethodBase]:
-        if should_ignore_layer(prefix,
-                               ignore=self.ignore,
-                               fused_mapping=self.packed_modules_mapping):
+        if should_ignore_layer(
+            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+        ):
             return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
         if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            if scheme is None:
-                return VllmUnquantizedConfig.get_quant_method(
-                    self, layer, prefix)
+                return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
            layer.scheme = scheme
            return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            raise NotImplementedError(
-                "FusedMoE quantization is currently not supported.")
+            print("HERE", layer)
+            return CompressedTensorsW8A8Fp8MoEMethod(
+                self, layer.quant_config, self.mesh
+            )
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
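
For reference, a small numeric sketch of the per-tensor fp8 (e4m3) quantize/dequantize round trip that a W8A8-Fp8 scheme such as VllmCompressedTensorsW8A8Fp8 builds on. This is illustrative plain JAX, not the actual tpu_inference kernels, and the max-abs scale choice is an assumption.

# Illustrative per-tensor fp8 (e4m3) quantize/dequantize, not the actual
# tpu_inference implementation.
import jax.numpy as jnp

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


def quantize_fp8(w: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    # Per-tensor scale chosen so the largest weight maps to the fp8 max value.
    scale = jnp.max(jnp.abs(w)) / FP8_E4M3_MAX
    w_fp8 = (w / scale).astype(jnp.float8_e4m3fn)
    return w_fp8, scale


def dequantize_fp8(w_fp8: jnp.ndarray, scale: jnp.ndarray) -> jnp.ndarray:
    return w_fp8.astype(jnp.float32) * scale


w = jnp.array([[0.1, -2.5], [3.0, 0.02]], dtype=jnp.float32)
w_fp8, scale = quantize_fp8(w)
print(dequantize_fp8(w_fp8, scale))  # close to w, within fp8 rounding error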
