diff --git a/docs/source/quantization_overview.rst b/docs/source/quantization_overview.rst index f5c82bfe5f..df0a924b11 100644 --- a/docs/source/quantization_overview.rst +++ b/docs/source/quantization_overview.rst @@ -5,7 +5,7 @@ First we want to lay out the torchao stack:: Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc. --------------------------------------------------------------------------------------------- - Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor + Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Int8Tensor, Float8Tensor --------------------------------------------------------------------------------------------- Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize --------------------------------------------------------------------------------------------- @@ -88,6 +88,8 @@ So in general we structure Tensor subclasses by derived dtype and packing format - scaled int4 - preshuffled (special format to optimize for loading) - float8 act + int4 weight dynamic quantization and int4 weight only quantization + * - Int8Tensor + - plain .. note:: We don't have granularity-specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor; all granularities are implemented in the same Tensor. We typically use a general `block_size` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options. diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py new file mode 100644 index 0000000000..c2f099fcde --- /dev/null +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree.
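+#
+# Illustrative usage sketch (not part of the test suite below): the version=2
+# configs exercised by these tests turn a linear layer's weight into an
+# `Int8Tensor`. Assuming a CUDA device and a bfloat16 layer, the flow is
+# roughly:
+#
+#     import torch
+#     from torchao.quantization import Int8WeightOnlyConfig, quantize_
+#
+#     linear = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
+#     quantize_(linear, Int8WeightOnlyConfig(version=2))
+#     # linear.weight is now an Int8Tensor: int8 `qdata` plus a 1D per-row `scale`
+#     y = linear(torch.randn(32, 128, dtype=torch.bfloat16, device="cuda"))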
+ +import copy +import unittest + +import torch +from torch._inductor.utils import run_and_get_code +from torch.testing import FileCheck +from torch.testing._internal import common_utils + +from torchao.quantization import ( + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + quantize_, +) +from torchao.quantization.granularity import PerRow, PerTensor +from torchao.quantization.utils import compute_error, get_block_size +from torchao.testing.utils import TorchAOIntegrationTestCase + + +# TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged +class ToyTwoLinearModel(torch.nn.Module): + def __init__( + self, + input_dim, + hidden_dim, + output_dim, + has_bias=False, + dtype=None, + device=None, + ): + super().__init__() + self.dtype = dtype + self.device = device + self.linear1 = torch.nn.Linear( + input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device + ) + self.linear2 = torch.nn.Linear( + hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device + ) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@common_utils.instantiate_parametrized_tests +class TestInt8Tensor(TorchAOIntegrationTestCase): + def setUp(self): + super().setUp() + + self.test_shape = (32, 20) + self.dtype = torch.bfloat16 + self.batch_size = 32 + + torch.manual_seed(42) + + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + def test_creation_and_attributes(self, config): + """Test tensor creation, dtypes, and ranges""" + linear = torch.nn.Linear( + self.test_shape[1], + self.test_shape[0], + bias=False, + dtype=self.dtype, + device="cuda", + ) + quantize_(linear, config) + + w = linear.weight + + self.assertEqual(w.shape, self.test_shape) + self.assertEqual(w.qdata.dtype, torch.int8) + self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127)) + + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) + @common_utils.parametrize("compile", [True, False]) + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), # 2D + ((32, 128), 64, 256), # 3D + ], + ) + def test_int8_linear_variants( + self, + dtype: torch.dtype, + config, + compile: bool, + sizes: tuple, + ): + """Test linear operation supports including shape and compile""" + M, N, K = sizes + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval() + model_q = copy.deepcopy(model) + + quantize_(model_q, config) + + self.assertEqual(model_q.linear2.weight.scale.shape, (K,)) + self.assertEqual(model_q.linear2.weight.scale.ndim, 1) + + if compile: + model_q = torch.compile(model_q, fullgraph=True) + + output_fp = model(input_tensor) + output_quantized = model_q(input_tensor) + + assert compute_error(output_fp, output_quantized) > 20, ( + f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}" + ) + + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + @common_utils.parametrize("device", ["cpu", "cuda"]) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_slice(self, config, device, dtype): + """Test tensor slicing with 
per-row quantization""" + tensor_size = 256 + slice_sizes = (64, 128) + + dummy = torch.nn.Linear( + tensor_size, tensor_size, bias=False, dtype=dtype, device=device + ) + quantize_(dummy, config) + + weight1 = dummy.weight.clone().narrow(0, 0, slice_sizes[0]) + weight2 = dummy.weight.clone().narrow(1, 0, slice_sizes[1]) + + self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0])) + self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1])) + self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])) + self.assertEqual(weight2.scale, dummy.weight.scale) + with self.assertRaises(NotImplementedError): + _ = dummy.weight[::2] + + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + ], + ) + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + def test_index_select(self, config, granularity): + """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor.""" + N, K = 256, 512 + x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) + linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda") + linear.weight.data = x + + config = config(version=2, granularity=granularity) + quantize_(linear, config) + + x_int8 = linear.weight + x_int8_0 = x_int8[0] + + # Test dequantization consistency + torch.testing.assert_close( + x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0 + ) + + # Test block_size granularity + if isinstance(granularity, PerRow): + self.assertEqual( + list(get_block_size(x_int8.shape, x_int8.granularity)), [1, K] + ) + elif isinstance(granularity, PerTensor): + self.assertEqual( + list(get_block_size(x_int8.shape, x_int8.granularity)), [N, K] + ) + + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + def test_dequantization_accuracy(self, config): + """Test dequantization accuracy separately""" + linear = torch.nn.Linear( + 256, 512, bias=False, dtype=torch.bfloat16, device="cuda" + ) + weight_fp = copy.deepcopy(linear.weight) + quantize_(linear, config) + + tensor = linear.weight + dequantized = tensor.dequantize() + self.assertEqual(dequantized.shape, weight_fp.shape) + assert compute_error(dequantized, weight_fp) > 20, ( + f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, weight_fp)}" + ) + + def test_available_gpu_kernels(self): + """Check which GPU kernels are used""" + torch.compiler.reset() + + M, K, N = 128, 256, 512 + m = torch.nn.Sequential( + torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16) + ) + + config = Int8DynamicActivationInt8WeightConfig(version=2) + quantize_(m, config) + + m = torch.compile(m) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + out, code = run_and_get_code(m, x) + + # Check expected kernels are present + FileCheck().check_count("triton_per_fused", 1).check_count( + "extern_kernels._int_mm", 1 + ).check_count("triton_poi_fused", 1).run(code[0]) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index 212df9c5db..25e0dfd7aa 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -139,43 +139,50 @@ def _slice_scale_for_dimension( Slice the scale tensor appropriately based on the data tensor slicing. 
This function calculates how the scale should be sliced when the data tensor is sliced along a given dimension, taking into account the block structure. - """ - aten = torch.ops.aten - # Unsupported case for now, this would be 1 scale per data element - if scale.shape == data_shape: - return aten.slice.Tensor(scale, dim, start, end, step) + Example: + If data_shape is [256, 128] and scale shape is [1] (indicating per-tensor scaling), + slicing along any dimension should return the same scale tensor. - # Reconstruct block sizes based on data shape and scale shape - block_sizes = tuple(data_shape[i] // scale.shape[i] for i in range(len(data_shape))) + If data_shape is [256, 128] and scale shape is [256] (indicating per-row scaling), + and we slice data along dim=0 from 64 to 192, the corresponding scale + """ + aten = torch.ops.aten - if dim >= len(block_sizes): - # Slicing beyond the dimensions we care about + # Case 1: Per-tensor quantization (scalar scale) + if scale.numel() <= 1: return scale + # Case 2: Per-row quantization (1D scale) + # Scale is per-element along this dimension + if scale.ndim == 1: + if dim == 0: + return aten.slice.Tensor(scale, 0, start, end, step) + else: + return scale + + # Case 3: Per-block quantization (2D scale) + block_sizes = tuple( + data_shape[i] // scale.shape[i] for i in range(len(scale.shape)) + ) + block_size_for_dim = block_sizes[dim] - if block_size_for_dim == 1: - # Scale is per-element along this dimension - # Slice away as normal - return aten.slice.Tensor(scale, dim, start, end, step) - else: - # There is blocking in this dimension - # Calculate which scale elements correspond to the sliced data - scale_start = start // block_size_for_dim if start is not None else None - scale_end = ( - (end + block_size_for_dim - 1) // block_size_for_dim - if end is not None - else None + if step > 1: + raise NotImplementedError( + "Slicing with step > 1 is not implemented for scale tensors." ) - # Error on Step > 1 - if step > 1: - raise NotImplementedError( - "Slicing with step > 1 is not implemented for scale tensors." 
- ) + # There is blocking in this dimension + # Calculate which scale elements correspond to the sliced data + scale_start = start // block_size_for_dim if start is not None else None + scale_end = ( + (end + block_size_for_dim - 1) // block_size_for_dim + if end is not None + else None + ) - return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) + return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) def _is_rowwise_scaled(x: torch.Tensor) -> bool: diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 577ac40721..77de6732f7 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -100,6 +100,7 @@ Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, + Int8Tensor, IntxOpaqueTensor, IntxUnpackedToInt8Tensor, ) @@ -173,6 +174,7 @@ "IntxOpaqueTensor", "IntxUnpackedToInt8Tensor", "Int4TilePackedTo4dTensor", + "Int8Tensor", "Float8Tensor", "Int4OpaqueTensor", "Float8OpaqueTensor", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 8f48de494d..d1942fca35 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -84,6 +84,7 @@ Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, + Int8Tensor, IntxChooseQParamsAlgorithm, IntxOpaqueTensor, IntxPackingFormat, @@ -1324,14 +1325,18 @@ class Int8WeightOnlyConfig(AOBaseConfig): Configuration for applying int8 weight-only symmetric per-channel quantization to linear layers. Args: - group_size: Optional[int] = None - Controls the granularity of quantization. If None, applies per-channel quantization. - Otherwise, applies per-group quantization with the specified group size. + group_size (version 1) - Controls the granularity of quantization. + If None, applies per-channel quantization. Otherwise, applies per-group quantization with the specified group size. + granularity (version 2) - Quantization granularity. + PerRow() for per-channel quantization, PerTensor() for per-tensor quantization. set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values for better performance with this quantization scheme. 
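+ version (int): the version of the config; version 1 uses AffineQuantizedTensor (planned for deprecation/split), version 2 uses Int8Tensor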
""" group_size: Optional[int] = None + granularity: Optional[Union[PerRow, PerTensor]] = PerRow() set_inductor_config: bool = True + version: int = 1 def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig") @@ -1342,22 +1347,29 @@ def __post_init__(self): def _int8_weight_only_quantize_tensor(weight, config): - mapping_type = MappingType.SYMMETRIC - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - group_size = config.group_size - if group_size is None: - group_size = weight.shape[-1] - block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) - new_weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - eps=eps, - zero_point_dtype=zero_point_dtype, - ) + if config.version == 1: + warnings.warn( + "Config Deprecation: version 1 of Int8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" + ) + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + group_size = config.group_size + if group_size is None: + group_size = weight.shape[-1] + block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + ) + else: + assert config.version == 2, f"Unexpected version: {config.version}" + new_weight = Int8Tensor.from_hp(weight, granularity=config.granularity) return new_weight @@ -1492,12 +1504,15 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): in original precision during decode operations. set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values for better performance with this quantization scheme. 
+ version (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Int8Tensor """ layout: Optional[Layout] = PlainLayout() act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC weight_only_decode: bool = False + granularity: Optional[Union[PerRow, PerTensor]] = PerRow() set_inductor_config: bool = True + version: int = 1 def __post_init__(self): torch._C._log_api_usage_once( @@ -1529,9 +1544,6 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): mapping_type = MappingType.SYMMETRIC weight_zero_point_domain = ZeroPointDomain.NONE - def get_weight_block_size(x): - return tuple([1 for _ in range(x.dim() - 1)] + [x.shape[-1]]) - target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 @@ -1545,19 +1557,44 @@ def get_weight_block_size(x): else: input_quant_func = _int8_asymm_per_token_quant - block_size = get_weight_block_size(weight) - new_weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - eps=eps, - zero_point_dtype=zero_point_dtype, - _layout=layout, - zero_point_domain=weight_zero_point_domain, - ) - new_weight = to_linear_activation_quantized(new_weight, input_quant_func) - return new_weight + if config.version == 1: + warnings.warn( + "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" + ) + if isinstance(config.granularity, PerTensor): + block_size = weight.shape + else: + block_size = tuple( + [1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]] + ) + + quantized_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + _layout=layout, + zero_point_domain=weight_zero_point_domain, + ) + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func + ) + else: + from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( + QuantizeTensorToInt8Kwargs, + ) + + assert config.version == 2, f"Unexpected version: {config.version}" + + quantized_weight = Int8Tensor.from_hp( + weight, + granularity=config.granularity, + act_quant_kwargs=QuantizeTensorToInt8Kwargs(granularity=config.granularity), + ) + + return quantized_weight @register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py index 0adc8c786d..15540e34c8 100644 --- a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py +++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py @@ -39,7 +39,9 @@ def _choose_quant_func_and_quantize_tensor( """ from torchao.quantization.quantize_.workflows import ( Float8Tensor, + Int8Tensor, QuantizeTensorToFloat8Kwargs, + QuantizeTensorToInt8Kwargs, ) if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs): @@ -52,5 +54,11 @@ def _choose_quant_func_and_quantize_tensor( quant_kwargs.hp_value_ub, quant_kwargs.kernel_preference, ) + elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs): + return Int8Tensor.from_hp( + tensor, + granularity=quant_kwargs.granularity, + act_quant_kwargs=quant_kwargs, + ) raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}") diff --git a/torchao/quantization/quantize_/workflows/__init__.py 
b/torchao/quantization/quantize_/workflows/__init__.py index e379327689..90515190e9 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -24,6 +24,10 @@ Int4Tensor, ) from .int4.int4_tile_packed_to_4d_tensor import Int4TilePackedTo4dTensor +from .int8.int8_tensor import ( + Int8Tensor, + QuantizeTensorToInt8Kwargs, +) from .intx.intx_choose_qparams_algorithm import IntxChooseQParamsAlgorithm from .intx.intx_opaque_tensor import ( IntxOpaqueTensor, @@ -41,6 +45,8 @@ "Int4MarlinSparseTensor", "Int4PlainInt32Tensor", "Int4TilePackedTo4dTensor", + "Int8Tensor", + "QuantizeTensorToInt8Kwargs", "Float8OpaqueTensor", "Float8Tensor", "Float8PackingFormat", diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py new file mode 100644 index 0000000000..dceb0964aa --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -0,0 +1,303 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.float8.inference import _slice_scale_for_dimension +from torchao.kernel import int_scaled_matmul +from torchao.quantization.granularity import Granularity, PerRow +from torchao.quantization.quant_primitives import ( + MappingType, + choose_qparams_affine, + dequantize_affine, + quantize_affine, +) +from torchao.quantization.quantize_.common import ( + QuantizeTensorKwargs, + _choose_quant_func_and_quantize_tensor, +) +from torchao.quantization.utils import get_block_size +from torchao.utils import TorchAOBaseTensor, fill_defaults + +__all__ = ["Int8Tensor", "QuantizeTensorToInt8Kwargs"] + +aten = torch.ops.aten + + +@dataclass +class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): + """Tensor kwargs for creating int8 tensor (either activation or weight) + + Args: + granularity: the granularity for the Tensor, currently either PerRow() or PerTensor() + """ + + granularity: Granularity = PerRow() + + +class Int8Tensor(TorchAOBaseTensor): + """ + int8 quantized tensor with plain layout + + Tensor Attributes: + qdata: (N, K) or (B, N, K) int8 quantized weight data (2D or 3D) + scale: scale factors for dequantization + # TODO: Static quantization support using `static_scale` + + Non-Tensor Attributes: + granularity: the granularity for quantization (e.g., PerRow(), PerTensor()) + act_quant_kwargs: flags for dynamic activation quantization + """ + + # TODO: Static quantization support using `static_scale` + tensor_data_names = ["qdata", "scale"] + tensor_attribute_names = ["granularity"] + optional_tensor_attribute_names = ["act_quant_kwargs", "dtype"] + + def __new__( + cls: type, + qdata: torch.Tensor, + scale: torch.Tensor, + granularity: Granularity, + act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, + dtype: Optional[torch.dtype] = None, + ): + kwargs = { + "device": qdata.device, + "dtype": dtype or scale.dtype, + "requires_grad": False, + } + return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs) + + def __init__( + self, + qdata: torch.Tensor, + scale: torch.Tensor, + granularity: Granularity, + act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, + dtype: 
Optional[torch.dtype] = None, + ): + super().__init__() + self.qdata = qdata + self.scale = scale + self.granularity = granularity + self.act_quant_kwargs = act_quant_kwargs + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"act_quant_kwargs={self.act_quant_kwargs}, " + f"qdata={self.qdata}, " + f"scale={self.scale}, " + f"granularity={self.granularity}, " + f"shape={self.shape}, " + f"device={self.device}, " + f"dtype={self.dtype})" + ) + + @classmethod + def from_hp( + cls, + w_hp: torch.Tensor, + granularity: Granularity = PerRow(), + act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, + ): + """Create Int8Tensor from high-precision tensor""" + block_size = get_block_size(w_hp.shape, granularity) + + if w_hp.dim() not in [2, 3] or len(block_size) != w_hp.dim(): + raise ValueError( + f"Expected 2D or 3D tensor with matching block_size dimensions, " + f"got tensor dim={w_hp.dim()}, block_size length={len(block_size)}" + ) + + scale, zero_point = choose_qparams_affine( + input=w_hp, + mapping_type=MappingType.SYMMETRIC, + block_size=block_size, + target_dtype=torch.int8, + quant_min=-128, + quant_max=127, + scale_dtype=w_hp.dtype, + zero_point_dtype=torch.int8, + ) + + int_data = quantize_affine( + w_hp, + block_size=block_size, + scale=scale, + zero_point=zero_point, + output_dtype=torch.int8, + ) + + return cls( + int_data, + scale, + granularity, + act_quant_kwargs=act_quant_kwargs, + dtype=w_hp.dtype, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """Dequantize int8 tensor to floating point""" + if output_dtype is None: + output_dtype = self.dtype + + block_size = get_block_size(self.qdata.shape, self.granularity) + + return dequantize_affine( + input=self.qdata, + block_size=block_size, + scale=self.scale, + zero_point=None, + input_dtype=torch.int8, + quant_min=-128, + quant_max=127, + output_dtype=output_dtype, + ) + + +implements = Int8Tensor.implements +implements_torch_function = Int8Tensor.implements_torch_function + + +@implements(aten.linear.default) +@implements_torch_function(torch.nn.functional.linear) +def _(func, types, args, kwargs): + """INT8 quantization: dynamic activation or weight-only""" + activation_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + if not isinstance(weight_tensor, Int8Tensor): + raise TypeError(f"Expected weight to be Int8Tensor, got {type(weight_tensor)}") + + output_dtype = activation_tensor.dtype + + if weight_tensor.act_quant_kwargs is not None: + # Dynamic activation quantization path + if not isinstance(activation_tensor, Int8Tensor): + activation_tensor = _choose_quant_func_and_quantize_tensor( + activation_tensor, weight_tensor.act_quant_kwargs + ) + + # 1. do the matrix form of dot(X_i, W_j) + # + # 2. 
rescale the output + # + # in cases with large matrices, y_dot_int32 can grow sufficiently + # large that y_dot_int32 * a FP16 scale is greater than the maximum + # value of a FP16, (which results in a value of inf even if multiplying + # by the other scale would bring it within the expected range) + + x_vals_int8 = activation_tensor.qdata + x_scales = activation_tensor.scale + w_vals_int8_t = weight_tensor.qdata.contiguous().t() + w_scales = weight_tensor.scale + + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + x_scales_dtype = x_scales.dtype + # Cast FP16 scale to float to avoid overflow in int_scaled_matmul + intermediate_dtype = ( + torch.float if x_scales_dtype == torch.half else x_scales_dtype + ) + y_dot_scaled = int_scaled_matmul( + tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype) + ) + y_dot_scaled = y_dot_scaled.to(x_scales_dtype) + + y = (y_dot_scaled * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] + ) + + else: + # FP × INT8 (weight-only) + w_vals_int8_t = weight_tensor.qdata.t() + m = torch.mm( + activation_tensor.reshape(-1, activation_tensor.shape[-1]), + w_vals_int8_t.to(activation_tensor.dtype), + ) + y = m * weight_tensor.scale.to(m.dtype) + y = y.reshape(*activation_tensor.shape[:-1], weight_tensor.qdata.shape[0]) + + if bias is not None: + y += bias + + return y.to(output_dtype) + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + """Slice operation for Int8Tensor""" + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + + if step != 1: + raise NotImplementedError( + f"Slicing with step != 1 is not supported, got step={step}" + ) + + if dim not in [0, 1, 2]: + raise ValueError(f"Only dim in [0, 1, 2] supported, got dim={dim}") + + if self.qdata.ndim not in [2, 3]: + raise ValueError(f"Expected qdata to be 2D or 3D, got {self.qdata.ndim}D") + + if end is None or end > self.shape[dim]: + end = self.shape[dim] + + sliced_qdata = aten.slice.Tensor(self.qdata, dim, start, end, step) + sliced_scale = _slice_scale_for_dimension( + self.scale, self.qdata.shape, dim, start, end, step + ) + + return return_and_correct_aliasing( + func, + args, + kwargs, + Int8Tensor( + sliced_qdata, + sliced_scale, + self.granularity, + self.act_quant_kwargs, + dtype=self.dtype, + ), + ) + + +@implements(aten.select.int) +def _(func, types, args, kwargs): + """Select operation for Int8Tensor""" + self, dim, index = args + if dim != 0: + raise NotImplementedError(f"Only dim=0 supported, got dim={dim}") + + selected_qdata = self.qdata[index] + selected_scale = _slice_scale_for_dimension( + self.scale, self.qdata.shape, dim, index, index + 1, step=1 + ).squeeze(0) + + return return_and_correct_aliasing( + func, + args, + kwargs, + Int8Tensor( + selected_qdata, + selected_scale, + self.granularity, + self.act_quant_kwargs, + self.dtype, + ), + ) + + +Int8Tensor.__module__ = "torchao.quantization" +torch.serialization.add_safe_globals([Int8Tensor, QuantizeTensorToInt8Kwargs])
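A minimal end-to-end sketch of the version=2 path added in this diff (module shapes are illustrative; a CUDA device is assumed, as in the tests above):

    import torch
    from torchao.quantization import (
        Int8DynamicActivationInt8WeightConfig,
        Int8WeightOnlyConfig,
        quantize_,
    )

    # Weight-only int8: the weight becomes an Int8Tensor, activations stay in bfloat16.
    m = torch.nn.Linear(256, 512, bias=False, dtype=torch.bfloat16, device="cuda")
    quantize_(m, Int8WeightOnlyConfig(version=2))
    assert m.weight.qdata.dtype == torch.int8  # scale is 1D, one entry per output row

    # Dynamic activation + weight int8: act_quant_kwargs on the weight triggers
    # activation quantization inside the linear override (int_scaled_matmul path).
    m2 = torch.nn.Linear(256, 512, bias=False, dtype=torch.bfloat16, device="cuda")
    quantize_(m2, Int8DynamicActivationInt8WeightConfig(version=2))
    y = m2(torch.randn(32, 256, dtype=torch.bfloat16, device="cuda"))

    # Dequantize back to high precision, e.g. for an accuracy/SQNR check.
    w_dq = m2.weight.dequantize()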