Skip to content

Commit 60ae810

Browse files
committed
Fix format errors
Signed-off-by: Shubhra Pandit <shubhra.pandit@gmail.com>
1 parent 4051ef0 commit 60ae810

File tree

1 file changed

+17
-11
lines changed
  • src/llmcompressor/observers

1 file changed

+17
-11
lines changed

src/llmcompressor/observers/mse.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def calculate_mse_min_max(
8989
from compressed_tensors.quantization.utils import generate_gparam
9090

9191
if(is_fp4(self.quantization_args)) and global_scale is None:
92-
# If the quantization scheme is fp4 and global_scale is still None
93-
# i.e it has not yet been optimized, then we are should first get
92+
# If the quantization scheme is fp4 and global_scale is still None
93+
# i.e it has not yet been optimized, then we are should first get
9494
# the global scale and then optimize the local scales.
9595
# Local scales are set to by the absolute min and max.
9696
iteration_global_scale = generate_gparam(
@@ -99,7 +99,8 @@ def calculate_mse_min_max(
9999
iteration_min_val = absolute_min_val
100100
iteration_max_val = absolute_max_val
101101
else:
102-
# Otherwise, we are optimizing local scales and use the shrinked min and max
102+
# Otherwise, we are optimizing local scales and use the shrinked
103+
# min and max
103104
iteration_min_val = shrinked_min_val
104105
iteration_max_val = shrinked_max_val
105106
iteration_global_scale = global_scale
@@ -152,8 +153,9 @@ def calculate_updated_min_max(
152153
Updates the mse-clipped min and max values of the observed tensor using
153154
a moving average smoothed by the averaging_constant.
154155
155-
- Weights: global and local scales use MSE-optimized values.
156-
- Activations: global scale uses MSE-optimized values, local scales use min–max.
156+
- Weights: global and local scales use MSE-optimized values.
157+
- Activations: global scale uses MSE-optimized values, local scales use
158+
min–max.
157159
158160
:param observed: observed tensor to calculate quantization parameters for
159161
:param reduce_dims: optional tuple of dimensions to reduce along,
@@ -165,7 +167,8 @@ def calculate_updated_min_max(
165167
:return: updated min and max values derived from the observed value
166168
"""
167169

168-
# Skip local scales updates for dynamic activations (this will happen at runtime)
170+
# Skip local scales updates for dynamic activations (this will happen at
171+
# runtime)
169172
if self.is_activation and is_local:
170173
# Activations local scales: min–max
171174
min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
@@ -255,19 +258,22 @@ def reset(self):
255258
self.min_val = {}
256259
self.max_val = {}
257260

258-
261+
259262
def calculate_gparam(self, observed: Tensor) -> torch.Tensor:
260263
"""
261264
Generate a global scale using the observed min and max from MSE optimization.
262265
263-
- Weights: global scale is computed with standard MSE optimization.
264-
- Activations: global scale is computed with dynamic MSE-based scaling.
266+
- Weights: global scale is computed with standard MSE optimization.
267+
- Activations: global scale is computed with dynamic MSE-based scaling.
265268
266269
:param observed: observed tensor to calculate quantization parameters for
267270
:return: updated global scale derived from the observed tensor
268271
"""
269272
from compressed_tensors.quantization.utils import generate_gparam
270273

271-
updated_min_val, updated_max_val = self.calculate_updated_min_max(observed=observed)
274+
updated_min_val, updated_max_val = self.calculate_updated_min_max(
275+
observed=observed
276+
)
272277

273-
return generate_gparam(updated_min_val=updated_min_val, updated_max_val=updated_max_val)
278+
return generate_gparam(
279+
updated_min_val=updated_min_val, updated_max_val=updated_max_val)

0 commit comments

Comments
 (0)