@@ -89,8 +89,8 @@ def calculate_mse_min_max(
8989 from compressed_tensors .quantization .utils import generate_gparam
9090
9191 if (is_fp4 (self .quantization_args )) and global_scale is None :
92- # If the quantization scheme is fp4 and global_scale is still None
93- # i.e it has not yet been optimized, then we are should first get
92+ # If the quantization scheme is fp4 and global_scale is still None
93+ # i.e. it has not yet been optimized, then we should first get
9494 # the global scale and then optimize the local scales.
9595 # Local scales are set by the absolute min and max.
9696 iteration_global_scale = generate_gparam (
@@ -99,7 +99,8 @@ def calculate_mse_min_max(
9999 iteration_min_val = absolute_min_val
100100 iteration_max_val = absolute_max_val
101101 else :
102- # Otherwise, we are optimizing local scales and use the shrinked min and max
102+ # Otherwise, we are optimizing local scales and use the shrinked
103+ # min and max
103104 iteration_min_val = shrinked_min_val
104105 iteration_max_val = shrinked_max_val
105106 iteration_global_scale = global_scale
@@ -152,8 +153,9 @@ def calculate_updated_min_max(
152153 Updates the mse-clipped min and max values of the observed tensor using
153154 a moving average smoothed by the averaging_constant.
154155
155- - Weights: global and local scales use MSE-optimized values.
156- - Activations: global scale uses MSE-optimized values, local scales use min–max.
156+ - Weights: global and local scales use MSE-optimized values.
157+ - Activations: global scale uses MSE-optimized values, local scales use
158+ min–max.
157159
158160 :param observed: observed tensor to calculate quantization parameters for
159161 :param reduce_dims: optional tuple of dimensions to reduce along,
@@ -165,7 +167,8 @@ def calculate_updated_min_max(
165167 :return: updated min and max values derived from the observed value
166168 """
167169
168- # Skip local scales updates for dynamic activations (this will happen at runtime)
170+ # Skip local scales updates for dynamic activations (this will happen at
171+ # runtime)
169172 if self .is_activation and is_local :
170173 # Activations local scales: min–max
171174 min_val = torch .amin (observed , dim = reduce_dims , keepdims = True )
@@ -255,19 +258,22 @@ def reset(self):
255258 self .min_val = {}
256259 self .max_val = {}
257260
258-
261+
259262 def calculate_gparam (self , observed : Tensor ) -> torch .Tensor :
260263 """
261264 Generate a global scale using the observed min and max from MSE optimization.
262265
263- - Weights: global scale is computed with standard MSE optimization.
264- - Activations: global scale is computed with dynamic MSE-based scaling.
266+ - Weights: global scale is computed with standard MSE optimization.
267+ - Activations: global scale is computed with dynamic MSE-based scaling.
265268
266269 :param observed: observed tensor to calculate quantization parameters for
267270 :return: updated global scale derived from the observed tensor
268271 """
269272 from compressed_tensors .quantization .utils import generate_gparam
270273
271- updated_min_val , updated_max_val = self .calculate_updated_min_max (observed = observed )
274+ updated_min_val , updated_max_val = self .calculate_updated_min_max (
275+ observed = observed
276+ )
272277
273- return generate_gparam (updated_min_val = updated_min_val , updated_max_val = updated_max_val )
278+ return generate_gparam (
279+ updated_min_val = updated_min_val , updated_max_val = updated_max_val )
0 commit comments