 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase
-from vllm.model_executor.layers.quantization import \
-    register_quantization_config
-from vllm.model_executor.layers.quantization.base_config import \
-    QuantizeMethodBase  # noqa: E501
+from vllm.model_executor.layers.quantization import register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase  # noqa: E501
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
-    CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
-    CompressedTensorsLinearMethod, CompressedTensorsScheme)
+    CompressedTensorsConfig,
+    CompressedTensorsKVCacheMethod,
+    CompressedTensorsLinearMethod,
+    CompressedTensorsScheme,
+)
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import (
+    CompressedTensorsW8A8Fp8MoEMethod,
+)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    find_matched_target, is_activation_quantization_format,
-    should_ignore_layer)
+    find_matched_target,
+    is_activation_quantization_format,
+    should_ignore_layer,
+)
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
-from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
-    VllmCompressedTensorsW8A8Fp8
-from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
-    VllmCompressedTensorsW8A8Int8
-from tpu_inference.layers.vllm.quantization.unquantized import \
-    VllmUnquantizedConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import (
+    VllmCompressedTensorsW8A8Fp8,
+)
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import (
+    VllmCompressedTensorsW8A8Int8,
+)
+from tpu_inference.layers.vllm.quantization.unquantized import VllmUnquantizedConfig
 
 P = PartitionSpec
 logger = init_logger(__name__)
 
 
 @register_quantization_config("jax-compressed-tensors")
 class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
-
     @classmethod
     def get_name(cls) -> str:
         return "jax-compressed-tensors"
 
-    def get_scheme(self,
-                   layer: torch.nn.Module,
-                   layer_name: Optional[str] = None
-                   ) -> Optional["CompressedTensorsScheme"]:
+    def get_scheme(
+        self, layer: torch.nn.Module, layer_name: Optional[str] = None
+    ) -> Optional["CompressedTensorsScheme"]:
         """
         compressed-tensors supports non uniform in the following way:
 
@@ -60,24 +65,30 @@ def get_scheme(self,
             layer_name=layer_name,
             module=layer,
             targets=self.target_scheme_map.keys(),
-            fused_mapping=self.packed_modules_mapping)
+            fused_mapping=self.packed_modules_mapping,
+        )
 
         scheme_dict = self.target_scheme_map[matched_target]
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
         format = scheme_dict.get("format")
 
         if weight_quant is None:
-            logger.warning_once("Acceleration for non-quantized schemes is "
-                                "not supported by Compressed Tensors. "
-                                "Falling back to UnquantizedLinearMethod")
+            logger.warning_once(
+                "Acceleration for non-quantized schemes is "
+                "not supported by Compressed Tensors. "
+                "Falling back to UnquantizedLinearMethod"
+            )
             return None
 
         # TODO(kyuyeunk): Add support for different act_quant_format
-        act_quant_format = is_activation_quantization_format(  # noqa: F841
-            format
-        ) if format is not None else is_activation_quantization_format(
-            self.quant_format)
+        act_quant_format = (
+            is_activation_quantization_format(  # noqa: F841
+                format
+            )
+            if format is not None
+            else is_activation_quantization_format(self.quant_format)
+        )
 
         linear_config = self.get_linear_config(layer)
         if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -94,28 +105,28 @@ def get_scheme(self,
                 input_symmetric=input_quant.symmetric,
                 jax_config=linear_config,
             )
-        raise NotImplementedError(
-            "No compressed-tensors compatible scheme was found.")
+        raise NotImplementedError("No compressed-tensors compatible scheme was found.")
 
     def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional[QuantizeMethodBase]:
-        if should_ignore_layer(prefix,
-                               ignore=self.ignore,
-                               fused_mapping=self.packed_modules_mapping):
+        if should_ignore_layer(
+            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+        ):
             return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
             if scheme is None:
-                return VllmUnquantizedConfig.get_quant_method(
-                    self, layer, prefix)
+                return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            raise NotImplementedError(
-                "FusedMoE quantization is currently not supported.")
+            return CompressedTensorsW8A8Fp8MoEMethod(
+                self, layer.quant_config, self.mesh
+            )
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None