     QuantizationArgs,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.utils import is_fp4
+from compressed_tensors.quantization.utils import is_fp4, strict_divide
 from compressed_tensors.registry.registry import RegistryMixin
 from loguru import logger
 from torch import FloatTensor, IntTensor, Tensor
@@ -128,8 +128,6 @@ def get_qparams(
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
-            group_size = self.quantization_args.group_size
-
             if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
                 # re-calculate scale and zero point, update the stored value
                 self._scale, self._zero_point = self.calculate_qparams(observed)
@@ -138,49 +136,43 @@ def get_qparams(
                 QuantizationStrategy.TENSOR_GROUP,
                 QuantizationStrategy.GROUP,
             ):
-                rows = observed.shape[0]
-                columns = observed.shape[1]
-                num_groups = int(ceil(columns / group_size))
-                if num_groups * group_size != columns:
-                    logger.bind(log_once=True).warning(
-                        "Attempting to quantize a module weight whose columns "
-                        f"({columns}) are not divisible by group_size ({group_size}). "
-                        "This scheme is not supported by vLLM, please consider "
-                        "adjusting the group_size for modules with this number of "
-                        "columns",
-                    )
+                # should be identical implementation to first half of
+                # `_process_quantization`

-                self._scale = torch.empty(
-                    (rows, num_groups), dtype=observed.dtype, device=observed.device
-                )
+                # get shapes
+                assert observed.ndim >= 2
+                rows, columns = observed.shape[-2:]
+                group_size = self.quantization_args.group_size
+                num_groups = strict_divide(columns, group_size)
+
+                # FP4: cast zp type
                 if is_fp4(quantization_args=self.quantization_args):
                     zp_dtype = FP8_E4M3_DATA.dtype
                 else:
                     zp_dtype = self.quantization_args.pytorch_dtype()

+                # allocate qparams
+                self._scale = torch.empty(
+                    (rows, num_groups), dtype=observed.dtype, device=observed.device
+                )
                 self._zero_point = torch.empty(
                     (rows, num_groups), dtype=zp_dtype, device=observed.device
                 )

-                # support column-order (default) quantization as well as other orderings
-                # such as activation ordering. Below checks if g_idx has initialized
-                is_column_order = g_idx is None or -1 in g_idx
-                if is_column_order:
-                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
-                else:
-                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
-                    group_sizes = group_sizes[torch.argsort(group_indices)]
-
-                    observed = observed.index_select(-1, g_idx)
+                # permute groups
+                if g_idx is not None:
+                    perm = torch.argsort(g_idx)
+                    observed = observed.index_select(-1, perm)

                 # TODO: experiment with vectorizing for loop for performance
+                # reduce all dims except the second-to-last one
                 end = 0
-                for group_index, group_count in enumerate(group_sizes):
+                for group_index in range(num_groups):
                     start = end
-                    end = start + group_count
+                    end = start + group_size
                     scale, zero_point = self.get_qparams_along_dim(
-                        observed[:, start:end],
-                        0,
+                        observed[..., start:end],
+                        dim=-2,
                         tensor_id=group_index,
                         global_scale=global_scale,
                     )
@@ -189,8 +181,8 @@ def get_qparams(
                     self._zero_point[:, group_index] = zero_point.squeeze(1)

             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+                # reduce all dims except the second-to-last one
+                self._scale, self._zero_point = self.get_qparams_along_dim(observed, -2)

             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
                 # use dim 1, assume the obsersed.shape = [batch, token, hidden]
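
Note on the GROUP/TENSOR_GROUP rewrite above: the branch now asserts at least two dims, sizes the qparams from the trailing two dims, requires the column count to divide evenly into groups, permutes columns by `argsort(g_idx)` when activation ordering is in use, and reduces each group along every dim except the second-to-last. A minimal standalone sketch of that flow, assuming `strict_divide` errors on uneven division and substituting a toy symmetric min/max reduction for the observer's `calculate_qparams` (`strict_divide_sketch` and `groupwise_qparams_sketch` are hypothetical names, not part of the library):

```python
import torch


def strict_divide_sketch(numerator: int, denominator: int) -> int:
    # assumption: strict_divide raises when the division leaves a remainder
    if numerator % denominator != 0:
        raise ValueError(f"{numerator} is not divisible by {denominator}")
    return numerator // denominator


def groupwise_qparams_sketch(observed: torch.Tensor, group_size: int, g_idx=None):
    # mirror the new GROUP/TENSOR_GROUP branch: one (scale, zero_point)
    # per (row, group), reducing every dim except the second-to-last
    rows, columns = observed.shape[-2:]
    num_groups = strict_divide_sketch(columns, group_size)

    scales = torch.empty((rows, num_groups), dtype=observed.dtype)
    zero_points = torch.zeros((rows, num_groups), dtype=torch.int8)

    # activation ordering: argsort(g_idx) gathers columns that share a group id
    if g_idx is not None:
        observed = observed.index_select(-1, torch.argsort(g_idx))

    for group_index in range(num_groups):
        start = group_index * group_size
        group = observed[..., start : start + group_size]
        # stand-in for calculate_qparams: symmetric min/max per row
        max_abs = group.abs().amax(dim=-1)   # reduce the group's columns
        while max_abs.ndim > 1:              # reduce any leading batch dims
            max_abs = max_abs.amax(dim=0)
        scales[:, group_index] = max_abs / 127.0

    return scales, zero_points


if __name__ == "__main__":
    w = torch.randn(4, 16)
    s, zp = groupwise_qparams_sketch(w, group_size=8)
    print(s.shape, zp.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```

Working off the trailing two dims appears to be what lets the same path serve both 2-D weights and batched activation tensors.
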
@@ -203,7 +195,7 @@ def get_qparams(
             elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
                 # Block-wise quantization: one scale/zero_point per block of shape
                 # [block_rows, block_cols]
-                rows, cols = observed.shape[:2]
+                rows, cols = observed.shape[-2:]
                 bs = self.quantization_args.block_structure
                 if not (
                     isinstance(bs, (list, tuple))
@@ -255,15 +247,20 @@ def get_qparams(

     def get_qparams_along_dim(
         self,
-        observed,
+        observed: torch.Tensor,
         dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
         global_scale: Optional[Tensor] = None,
     ):
+        # cast to set
         if isinstance(dim, int):
             dim = [dim]
         dim = set(dim)

+
+        # convert negative dims
+        dim = [d if d >= 0 else observed.ndim + d for d in dim]
+        # reduce all dimensions except the one passed as an argument to this function
         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
         return self.calculate_qparams(
             observed,
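
The `get_qparams_along_dim` change normalizes negative dims before building `reduce_dims`, so callers can pass `dim=-2` regardless of how many leading batch dims `observed` carries. A small sketch of that reduction pattern under the same assumption (a plain min/max in place of `calculate_qparams`; `reduce_except` is a hypothetical helper, not library code):

```python
import torch


def reduce_except(observed: torch.Tensor, dim):
    """Return (min, max) keeping only `dim`, reducing all other dims."""
    if isinstance(dim, int):
        dim = [dim]
    # normalize negative dims, mirroring the new conversion in the diff
    dim = {d if d >= 0 else observed.ndim + d for d in dim}
    reduce_dims = tuple(i for i in range(observed.ndim) if i not in dim)
    return (
        observed.amin(dim=reduce_dims, keepdim=True),
        observed.amax(dim=reduce_dims, keepdim=True),
    )


if __name__ == "__main__":
    x = torch.randn(2, 8, 16)
    lo, hi = reduce_except(x, dim=-2)  # keep the channel dim (index 1)
    print(lo.shape, hi.shape)          # torch.Size([1, 8, 1]) twice
```
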