@@ -145,6 +145,7 @@ def calculate_updated_min_max(
145145 reduce_dims : Optional [Tuple [int ]] = None ,
146146 tensor_id : Optional [Any ] = None ,
147147 global_scale : Optional [torch .Tensor ] = None ,
148+ is_local : Optional [bool ]= False ,
148149 ) -> Tuple [FloatTensor , IntTensor ]:
149150 """
150151 Updates the mse-clipped min and max values of the observed tensor using
@@ -164,7 +165,7 @@ def calculate_updated_min_max(
164165 """
165166
166167 # Skip local scales updates for dynamic activations (this will happen at runtime)
167- if self .is_activation and reduce_dims is not None :
168+ if self .is_activation and is_local :
168169 # Activations local scales: min–max
169170 min_val = torch .amin (observed , dim = reduce_dims , keepdims = True )
170171 max_val = torch .amax (observed , dim = reduce_dims , keepdims = True )
@@ -219,6 +220,7 @@ def calculate_qparams(
219220 tensor_id = tensor_id ,
220221 reduce_dims = reduce_dims ,
221222 global_scale = global_scale ,
223+ is_local = True ,
222224 )
223225 scale , zero_point = calculate_qparams (
224226 min_vals = updated_min_val ,
0 commit comments