Skip to content

Commit 0531833

Browse files
committed
Change the local scale identification method
Signed-off-by: Shubhra Pandit <shubhra.pandit@gmail.com>
1 parent 62fbc12 commit 0531833

File tree

1 file changed

+3
-1
lines changed
  • src/llmcompressor/observers

1 file changed

+3
-1
lines changed

src/llmcompressor/observers/mse.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def calculate_updated_min_max(
145145
reduce_dims: Optional[Tuple[int]] = None,
146146
tensor_id: Optional[Any] = None,
147147
global_scale: Optional[torch.Tensor] = None,
148+
is_local: Optional[bool]= False,
148149
) -> Tuple[FloatTensor, IntTensor]:
149150
"""
150151
Updates the mse-clipped min and max values of the observed tensor using
@@ -164,7 +165,7 @@ def calculate_updated_min_max(
164165
"""
165166

166167
# Skip local scales updates for dynamic activations (this will happen at runtime)
167-
if self.is_activation and reduce_dims is not None:
168+
if self.is_activation and is_local:
168169
# Activations local scales: min–max
169170
min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
170171
max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
@@ -219,6 +220,7 @@ def calculate_qparams(
219220
tensor_id=tensor_id,
220221
reduce_dims=reduce_dims,
221222
global_scale=global_scale,
223+
is_local=True,
222224
)
223225
scale, zero_point = calculate_qparams(
224226
min_vals=updated_min_val,

0 commit comments

Comments
 (0)