-
Notifications
You must be signed in to change notification settings - Fork 370
introduce new int8 quantization API #3241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
namgyu-youn
wants to merge
13
commits into
pytorch:main
Choose a base branch
from
namgyu-youn:int8-quant-api
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
da7cfea
introduce new int8 quantization API
namgyu-youn 27076d3
refactor ops and update test cases
namgyu-youn cdb1d9f
update granularity slicing support
namgyu-youn 9f1b6c9
add 3D support to api, build linear variants test
namgyu-youn 9301717
add kernel detection test case
namgyu-youn caaba7a
refactor kernel test
namgyu-youn 0f51ee6
update linear variant, kernel detection test
namgyu-youn 305c3a9
update default granularity, kernel test
namgyu-youn 3ab38ba
fix quantization ops
namgyu-youn b516304
merge test cases with cleanup
namgyu-youn 0c2bb76
update `block_size` args to `granularity`
namgyu-youn d11af10
update expected kernel test
namgyu-youn 027afd8
Merge branch 'main' into int8-quant-api
namgyu-youn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
239 changes: 239 additions & 0 deletions
239
test/quantization/quantize_/workflows/int8/test_int8_tensor.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,239 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD 3-Clause license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import copy | ||
| import unittest | ||
|
|
||
| import torch | ||
| from torch._inductor.utils import run_and_get_code | ||
| from torch.testing import FileCheck | ||
| from torch.testing._internal import common_utils | ||
|
|
||
| from torchao.quantization import ( | ||
| Int8DynamicActivationInt8WeightConfig, | ||
| Int8WeightOnlyConfig, | ||
| quantize_, | ||
| ) | ||
| from torchao.quantization.granularity import PerRow, PerTensor | ||
| from torchao.quantization.utils import compute_error | ||
| from torchao.testing.utils import TorchAOIntegrationTestCase | ||
|
|
||
|
|
||
| # TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged | ||
| class ToyTwoLinearModel(torch.nn.Module): | ||
| def __init__( | ||
| self, | ||
| input_dim, | ||
| hidden_dim, | ||
| output_dim, | ||
| has_bias=False, | ||
| dtype=None, | ||
| device=None, | ||
| ): | ||
| super().__init__() | ||
| self.dtype = dtype | ||
| self.device = device | ||
| self.linear1 = torch.nn.Linear( | ||
| input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device | ||
| ) | ||
| self.linear2 = torch.nn.Linear( | ||
| hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device | ||
| ) | ||
|
|
||
| def forward(self, x): | ||
| x = self.linear1(x) | ||
| x = self.linear2(x) | ||
| return x | ||
|
|
||
|
|
||
| @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
| @common_utils.instantiate_parametrized_tests | ||
| class TestInt8Tensor(TorchAOIntegrationTestCase): | ||
| def setUp(self): | ||
| super().setUp() | ||
|
|
||
| self.test_shape = (32, 20) | ||
| self.dtype = torch.bfloat16 | ||
| self.batch_size = 32 | ||
|
|
||
| torch.manual_seed(42) | ||
|
|
||
| @common_utils.parametrize( | ||
| "config", | ||
| [ | ||
| Int8DynamicActivationInt8WeightConfig(version=2), | ||
| Int8WeightOnlyConfig(version=2), | ||
| ], | ||
| ) | ||
| def test_creation_and_attributes(self, config): | ||
| """Test tensor creation, dtypes, and ranges""" | ||
| linear = torch.nn.Linear( | ||
| self.test_shape[1], | ||
| self.test_shape[0], | ||
| bias=False, | ||
| dtype=self.dtype, | ||
| device="cuda", | ||
| ) | ||
| quantize_(linear, config) | ||
|
|
||
| w = linear.weight | ||
|
|
||
| self.assertEqual(w.shape, self.test_shape) | ||
| self.assertEqual(w.qdata.dtype, torch.int8) | ||
| self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127)) | ||
|
|
||
| @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) | ||
| @common_utils.parametrize("compile", [True, False]) | ||
| @common_utils.parametrize( | ||
| "config", | ||
| [ | ||
| Int8DynamicActivationInt8WeightConfig(version=2), | ||
| Int8WeightOnlyConfig(version=2), | ||
| ], | ||
| ) | ||
| @common_utils.parametrize( | ||
| "sizes", | ||
| [ | ||
| ((128,), 256, 128), # 2D | ||
| ((32, 128), 64, 256), # 3D | ||
| ], | ||
| ) | ||
| def test_int8_linear_variants( | ||
| self, | ||
| dtype: torch.dtype, | ||
| config, | ||
| compile: bool, | ||
| sizes: tuple, | ||
| ): | ||
| """Test linear operation supports including shape and compile""" | ||
| M, N, K = sizes | ||
| input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") | ||
| model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval() | ||
| model_q = copy.deepcopy(model) | ||
|
|
||
| quantize_(model_q, config) | ||
|
|
||
| self.assertEqual(model_q.linear2.weight.scale.shape, (K,)) | ||
| self.assertEqual(model_q.linear2.weight.scale.ndim, 1) | ||
|
|
||
| if compile: | ||
| model_q = torch.compile(model_q, fullgraph=True) | ||
|
|
||
| output_fp = model(input_tensor) | ||
| output_quantized = model_q(input_tensor) | ||
|
|
||
| assert compute_error(output_fp, output_quantized) > 20, ( | ||
| f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}" | ||
| ) | ||
|
|
||
| @common_utils.parametrize( | ||
| "config", | ||
| [ | ||
| Int8DynamicActivationInt8WeightConfig(version=2), | ||
| Int8WeightOnlyConfig(version=2), | ||
| ], | ||
| ) | ||
| @common_utils.parametrize("device", ["cpu", "cuda"]) | ||
| @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) | ||
| def test_slice(self, config, device, dtype): | ||
| """Test tensor slicing with per-row quantization""" | ||
| tensor_size = 256 | ||
| slice_sizes = (64, 128) | ||
|
|
||
| dummy = torch.nn.Linear( | ||
| tensor_size, tensor_size, bias=False, dtype=dtype, device=device | ||
| ) | ||
| quantize_(dummy, config) | ||
|
|
||
| weight1 = dummy.weight.clone().narrow(0, 0, slice_sizes[0]) | ||
| weight2 = dummy.weight.clone().narrow(1, 0, slice_sizes[1]) | ||
|
|
||
| self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0])) | ||
| self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1])) | ||
| self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])) | ||
| self.assertEqual(weight2.scale, dummy.weight.scale) | ||
| with self.assertRaises(NotImplementedError): | ||
| _ = dummy.weight[::2] | ||
|
|
||
| @common_utils.parametrize( | ||
| "config", | ||
| [ | ||
| Int8DynamicActivationInt8WeightConfig, | ||
| Int8WeightOnlyConfig, | ||
| ], | ||
| ) | ||
| @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) | ||
| def test_index_select(self, config, granularity): | ||
| """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor.""" | ||
| N, K = 256, 512 | ||
| x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) | ||
| linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda") | ||
| linear.weight.data = x | ||
|
|
||
| config = config(version=2, granularity=granularity) | ||
| quantize_(linear, config) | ||
|
|
||
| x_int8 = linear.weight | ||
| x_int8_0 = x_int8[0] | ||
|
|
||
| # Test dequantization consistency | ||
| torch.testing.assert_close( | ||
| x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0 | ||
| ) | ||
|
|
||
| # Test block_size granularity | ||
| if isinstance(granularity, PerRow): | ||
| self.assertEqual(x_int8.block_size, [1, K]) | ||
| elif isinstance(granularity, PerTensor): | ||
| self.assertEqual(x_int8.block_size, [N, K]) | ||
|
|
||
| @common_utils.parametrize( | ||
| "config", | ||
| [ | ||
| Int8DynamicActivationInt8WeightConfig(version=2), | ||
| Int8WeightOnlyConfig(version=2), | ||
| ], | ||
| ) | ||
| def test_dequantization_accuracy(self, config): | ||
jerryzh168 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Test dequantization accuracy separately""" | ||
| linear = torch.nn.Linear( | ||
| 256, 512, bias=False, dtype=torch.bfloat16, device="cuda" | ||
| ) | ||
| weight_fp = copy.deepcopy(linear.weight) | ||
| quantize_(linear, config) | ||
|
|
||
| tensor = linear.weight | ||
| dequantized = tensor.dequantize() | ||
| self.assertEqual(dequantized.shape, weight_fp.shape) | ||
| assert compute_error(dequantized, weight_fp) > 20, ( | ||
| f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, weight_fp)}" | ||
| ) | ||
|
|
||
| def test_available_gpu_kernels(self): | ||
| """Check which GPU kernels are used""" | ||
| torch.compiler.reset() | ||
|
|
||
| M, K, N = 128, 256, 512 | ||
| m = torch.nn.Sequential( | ||
| torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16) | ||
| ) | ||
|
|
||
| config = Int8DynamicActivationInt8WeightConfig(version=2) | ||
| quantize_(m, config) | ||
|
|
||
| m = torch.compile(m) | ||
| x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) | ||
|
|
||
| out, code = run_and_get_code(m, x) | ||
|
|
||
| # Check expected kernels are present | ||
| FileCheck().check_count("triton_per_fused", 1).check_count( | ||
| "extern_kernels._int_mm", 1 | ||
| ).check_count("triton_poi_fused", 1).run(code[0]) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| common_utils.run_tests() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.