
Conversation

@namgyu-youn (Contributor) commented Oct 24, 2025

Summary:
Introduce a new tensor subclass API. The main features are:

  • Int8Tensor: the main API, which handles quantization and dequantization operations
  • Utility ops: tensor slicing and index selection

This API is integrated into the existing configs (Int8WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig) behind the version flag; it is not the default.
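For reference, a minimal usage sketch of opting into the new path (the model and shapes here are illustrative, not from this PR):

import torch
from torchao.quantization import quantize_, Int8WeightOnlyConfig

# Toy module; any model containing nn.Linear layers works the same way.
model = torch.nn.Sequential(torch.nn.Linear(512, 256)).to(torch.bfloat16).cuda()

# version=2 opts into the new Int8Tensor-based workflow; without it the
# config keeps the existing default path.
quantize_(model, Int8WeightOnlyConfig(version=2))

out = model(torch.randn(32, 512, dtype=torch.bfloat16, device="cuda"))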

Related Issue/PR:
This is a reopened PR for #3038.

Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Performance:
The following are the results of https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py with a batch size of 32:

API   With torch.compile   Without torch.compile
Old   65.47 ms             234.39 ms
New   63.30 ms             239.30 ms

@pytorch-bot (bot) commented Oct 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3241

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label on Oct 24, 2025
)

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_quantization_shapes(self, dtype):
@jerryzh168 (Contributor) Oct 24, 2025:

this seems to be a combination of two tests, one for dynamic quant and one for static quant; can you use something like this:

@common_utils.parametrize("mode", ["dynamic", "weight-only"])

also I feel it might be better to not add static quant in this PR, and instead add both the tensor support and the config support for static quant in a separate PR
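A minimal sketch of what the merged test could look like (the test name, shapes, and assertions are assumptions for illustration, not the PR's actual code):

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
def test_quantization_shapes(self, dtype, mode):
    config = (
        Int8DynamicActivationInt8WeightConfig(version=2)
        if mode == "dynamic"
        else Int8WeightOnlyConfig(version=2)
    )
    linear = torch.nn.Linear(128, 64, dtype=dtype, device="cuda")
    quantize_(linear, config)

    x = torch.randn(32, 128, dtype=dtype, device="cuda")
    y = linear(x)
    # Quantization should not change output shape or dtype.
    self.assertEqual(y.shape, (32, 64))
    self.assertEqual(y.dtype, dtype)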

@namgyu-youn (Contributor Author):

Okay. I wasn't sure about removing the static flags earlier (although they aren't fully implemented), but a smaller PR is always better, I feel. I will remove static_scale and all the related support.

if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None:
    # INT8 × INT8 (static)
    scale = act_quant_kwargs.static_scale
    zero_point = torch.zeros_like(scale, dtype=torch.int8)
@jerryzh168 (Contributor) Oct 24, 2025:

I think the user should specify static_zero_point as well

but again, it's better to do this in a separate PR, since the current state is only half of the static quant feature (no config)

@jerryzh168 (Contributor) left a review comment:

I think we should

  1. split the static quant support into a separate PR
  2. follow what https://github.com/pytorch/ao/blob/main/torchao/dtypes/uintx/plain_layout.py does for the quantized linear implementation

this should be a refactor PR, not a refactor plus extra modifications plus feature implementations, I think

aten = torch.ops.aten

# Unsupported case for now, this would be 1 scale per data element
# Per-tensor quantization (scalar scale)
Contributor:

is this change related?

@namgyu-youn (Contributor Author) Oct 31, 2025:

It was updated to support more granularities. Without this change, we can't use per-tensor (0-D scale) or per-row (1-D scale).

Edit: the above comment is incorrect and this change is unrelated (#3241).

Collaborator:

So maybe it's better to move this util function to a common place?

Contributor:

this can be moved to torchao/quantization/quantize_/common/utils.py I think

@namgyu-youn (Contributor Author):

Okay, then I will move this to torchao/quantization/quantize_/common/utils.py after this PR.


@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize("has_bias", [True, False])
def test_weight_only_linear_with_bias(self, dtype, has_bias):
Contributor:

this can probably be merged into the linear variants test as well

@jerryzh168 (Contributor) left a review comment:

thanks, I think the tensor changes look good, but we need a linear_variants test to make sure we cover different aspects (e.g. compile), see comments inline

can you also do an e2e perf check with https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py to make sure the performance is the same before and after the change for the ViT model?

also, adding a kernel check might be useful to make sure we don't regress things:

def test_expected_gpu_kernel_fbgemm(self):

@namgyu-youn (Contributor Author) commented Oct 31, 2025

Updated logs:

@Xia-Weiwen (Collaborator) commented Nov 3, 2025

Hi @namgyu-youn Do you plan to submit another PR for static quantization? We also need static quantization for SmoothQuant. So, we are wondering if you have a plan or we should consider adding it ourselves. Thanks. CC @cyxlily

@namgyu-youn (Contributor Author)

> Hi @namgyu-youn Do you plan to submit another PR for static quantization? We also need static quantization for SmoothQuant. So, we are wondering if you have a plan or we should consider adding it ourselves. Thanks. CC @cyxlily

Yeah, static quantization support using static/dynamic flags is planned; I hope to show it to your team in the foreseeable future.

Also, for SmoothQuant, I think validating its support for the new quantization APIs (listed below, with a minimal usage sketch after the list) has higher priority. Could you look into it?

  • W4A16-INT: Int4WeightOnlyConfig(group_size=32, version=2)
  • W4A16-FP: Float8WeightOnlyConfig(version=2)
  • W8A8-FP-dynamic: Float8DynamicActivationFloat8WeightConfig(version=2)
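A minimal smoke-test sketch of that validation (the model, shapes, and direct quantize_ call are assumptions for illustration; a real check would go through the SmoothQuant/AWQ flow):

import torch
from torchao.quantization import (
    quantize_,
    Int4WeightOnlyConfig,
    Float8WeightOnlyConfig,
    Float8DynamicActivationFloat8WeightConfig,
)

configs = [
    Int4WeightOnlyConfig(group_size=32, version=2),
    Float8WeightOnlyConfig(version=2),
    Float8DynamicActivationFloat8WeightConfig(version=2),
]

for config in configs:
    model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16).cuda()
    quantize_(model, config)
    # Forward pass as a basic sanity check that the config applies cleanly.
    model(torch.randn(8, 256, dtype=torch.bfloat16, device="cuda"))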

@Xia-Weiwen (Collaborator)

> Yeah, static quantization support using static/dynamic flags is planned; I hope to show it to your team in the foreseeable future.

Thanks. Looking forward to it. If there is anything we can help with, please let us know.

> Also, for SmoothQuant, I think validating its support for the new quantization APIs (below) has higher priority. Could you look into it?
>
>   • W4A16-INT: Int4WeightOnlyConfig(group_size=32, version=2)
>   • W4A16-FP: Float8WeightOnlyConfig(version=2)
>   • W8A8-FP-dynamic: Float8DynamicActivationFloat8WeightConfig(version=2)

By "validating them", do you mean adding test cases? And are W4A16 and W8A16 (I guess there is a typo in your comment) really needed for SmoothQuant? For W4A16 , it would be much the same as AWQ. And for W8A16, I think accuracy is generally good enough without SmoothQuant.

@namgyu-youn (Contributor Author)

By "validating them", do you mean adding test cases? And are W4A16 and W8A16 (I guess there is a typo in your comment) really needed for SmoothQuant? For W4A16 , it would be much the same as AWQ. And for W8A16, I think accuracy is generally good enough without SmoothQuant.

Oh yes, it was a typo (W8A16 is right), and W4A16-INT (Int4WeightOnlyConfig(group_size=32, version=2)) is the one of interest. In my experience, and per https://arxiv.org/html/2411.02355v3, W4A16-INT is the most efficient choice for synchronous deployments, while W8A8-INT maximizes throughput in asynchronous settings.

Because the current AWQ/SmoothQuant tests only work with the old APIs (version 1), we can replace them with the new APIs like Int4WeightOnlyConfig(group_size=32, version=2), I guess.

@Xia-Weiwen (Collaborator)

By "validating them", do you mean adding test cases? And are W4A16 and W8A16 (I guess there is a typo in your comment) really needed for SmoothQuant? For W4A16 , it would be much the same as AWQ. And for W8A16, I think accuracy is generally good enough without SmoothQuant.

Oh yes, it was a typo (W8A16 is right), and W4A16-INT (Int4WeightOnlyConfig(group_size=32, version=2)) is of interest. In my last experience and https://arxiv.org/html/2411.02355v3, W4A16-INT is the most efficient choice for synchronous deployments, while W8A8-INT maximize throughput in asynchronous settings.

Because current AWQ/SmoothQuant test is only working with old APIs (version 1), we can replace it with new APIs like Int4WeightOnlyConfig(group_size=32, version=2) I guess.

I see. Thanks. We will evaluate that.

@Xia-Weiwen (Collaborator)

Hi @namgyu-youn May I know if you have a timeline to land this? Thanks.

@namgyu-youn (Contributor Author) commented Nov 11, 2025

    dtype=self.dtype,
    device="cuda",
)
linear.weight.data = self.weight_fp.cuda()
Contributor:

why do we do this? doesn't sound very necessary?

Int8WeightOnlyConfig(version=2),
],
)
def test_per_row_scale_shape(self, dtype, config):
Contributor:

nit: can you merge the checks in this test into the previous test test_int8_linear_variants?

f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, test_data)}"
)

@common_utils.parametrize(
Contributor:

does this have to be parametrized? I think what we need here is to check that the generated code contains a sequence of ops / kernel calls, like this:

FileCheck().check_count(
    "torch.ops.triton.quantize_fp8_row.default(", 1
).check_count(
    "torch.ops.fbgemm.f8f8bf16_rowwise.default(", 1
).check_not(".run(").run(code[0])

I think we can check 1. the quantize op and then 2. the mm op extern_kernels._int_mm, in a single run (see example), that should be enough
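A rough sketch of such a kernel check for the int8 path (assuming the dynamic config lowers to extern_kernels._int_mm under torch.compile; the exact op strings would need to match what inductor actually emits):

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

def test_expected_gpu_kernel(self):
    linear = torch.nn.Linear(128, 64, dtype=torch.bfloat16, device="cuda")
    quantize_(linear, Int8DynamicActivationInt8WeightConfig(version=2))
    compiled = torch.compile(linear, fullgraph=True)

    x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
    _, code = run_and_get_code(compiled, x)

    # Single FileCheck run over the generated code: the dynamic int8 path
    # should end in the int8 matmul extern kernel.
    FileCheck().check("extern_kernels._int_mm").run(code[0])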

Comment on lines 62 to 65
self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype)
self.input_fp = torch.randn(*self.test_shape, dtype=self.dtype)
self.bias = torch.randn(self.test_shape[0], dtype=self.dtype)
self.block_size = list(self.test_shape)
Contributor:

I feel we probably don't need these; it's also easier for people to follow if everything (or most things) is defined in the test itself


# Unsupported case for now, this would be 1 scale per data element
# Per-tensor quantization (scalar scale)
if scale.numel() == 1:
Contributor:

note: I think we can just check for ndim consistently everywhere, after #3324 is fixed

Contributor:

also, isn't the handling for per-tensor and per-row already included in the original code?

if block_size_for_dim == 1:
    # Scale is per-element along this dimension
    # Slice away as normal
    return aten.slice.Tensor(scale, dim, start, end, step)
else:
    # There is blocking in this dimension
    # Calculate which scale elements correspond to the sliced data
    scale_start = start // block_size_for_dim if start is not None else None
    scale_end = (
        (end + block_size_for_dim - 1) // block_size_for_dim
        if end is not None
        else None
    )
    # Error on step > 1
    if step > 1:
        raise NotImplementedError(
            "Slicing with step > 1 is not implemented for scale tensors."
        )
    return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)

@namgyu-youn (Contributor Author) Nov 16, 2025:

Reverted; I miscalculated block_sizes, and this change is unrelated. The original code already handles per-tensor and per-row.

btw, this function might be improved by separating the granularity check I think; please check #3345 for this update.

Comment on lines +178 to +182
if len(act_kwargs.block_size) != input_ndim:
    if input_ndim == 3 and len(act_kwargs.block_size) == 2:
        block_size_updated = [1] + list(act_kwargs.block_size)
    else:
        block_size_updated = list(act_kwargs.block_size)[-input_ndim:]
Contributor:

we kind of changed the meaning of block_size used in PerBlock quant recently, check

elif isinstance(granularity, PerBlock):

when is this code needed?

Contributor:

in principle we shouldn't update block_size here; instead, we should make sure block_size makes sense and is consistent throughout the codebase

Contributor:

this comment is not addressed

quantized_weight = Int8Tensor.from_hp(
    weight,
    block_size,
    act_quant_kwargs=QuantizeTensorToInt8Kwargs(block_size=block_size),
Contributor:

why does activation use the same block_size as weight?

Contributor:

I think we can change the block_size argument to granularity, since the block_size for the activation is unknown until we see its shape
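A rough sketch of that direction (the kwargs class name follows the diff above; the granularity field and the lazy block_size resolution are illustrative assumptions, not the final design):

from dataclasses import dataclass, field

from torchao.quantization.granularity import Granularity, PerRow, PerTensor

@dataclass
class QuantizeTensorToInt8Kwargs:
    # Store the granularity instead of a concrete block_size; the activation
    # shape is only known at runtime inside the linear implementation.
    granularity: Granularity = field(default_factory=PerRow)

def _activation_block_size(shape, granularity):
    # Resolve block_size lazily from the actual activation shape.
    if isinstance(granularity, PerTensor):
        return list(shape)                           # one scale for the whole tensor
    if isinstance(granularity, PerRow):
        return [1] * (len(shape) - 1) + [shape[-1]]  # one scale per row
    raise NotImplementedError(f"Unsupported granularity: {granularity}")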

@jerryzh168 (Contributor) left a review comment:

please rebase, also some comments are not addressed yet I think

@namgyu-youn (Contributor Author)

> please rebase, also some comments are not addressed yet I think

Thanks for the fast review, but please wait until I re-request review. I will address all comments this weekend :)

"""

group_size: Optional[int] = None
granularity: Optional[Union[PerRow, PerTensor]] = PerRow()
Contributor:

nit: the type can just be Optional[Granularity] I think, we can do validation on supported granularity later

layout: Optional[Layout] = PlainLayout()
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
weight_only_decode: bool = False
granularity: Optional[Union[PerRow, PerTensor]] = PerRow()
@jerryzh168 (Contributor) Nov 18, 2025:

same here

also should this one be a tuple? since there is both activation and weight

or follow this:

granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None


assert config.version == 2, f"Unexpected version: {config.version}"
# Compute block_size from granularity for activation quantization kwargs
block_size = get_block_size(weight.shape, config.granularity)
Contributor:

this is incorrect I think, we don't want to calculate the block_size for activation from weight

# TODO: Static quantization support using `static_scale`, `static_zero_point`
"""

block_size: list[int]
Contributor:

this should probably just take granularity

Comment on lines +152 to +155
while scale.ndim < qdata_fp.ndim:
    scale = scale.unsqueeze(-1)

scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, qdata_fp.shape)
Contributor:

are these needed?

scale = scale.unsqueeze(-1)

scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, qdata_fp.shape)
return qdata_fp * scale_expanded.to(output_dtype)
Contributor:

I think we should just call the dequantize_affine op directly

def dequantize_affine(
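A rough sketch of that suggestion (assuming dequantize_affine's current signature; argument names and defaults may differ slightly):

import torch
from torchao.quantization.quant_primitives import dequantize_affine

def _dequantize_int8(qdata, scale, block_size, output_dtype):
    # Let the shared primitive handle scale broadcasting instead of
    # reimplementing it with unsqueeze/expand inside Int8Tensor.
    return dequantize_affine(
        qdata,
        block_size,
        scale,
        None,           # zero_point: the int8 path here is symmetric
        torch.int8,     # dtype of the quantized data
        output_dtype=output_dtype,
    )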

Comment on lines +223 to +225
# can downcast only at the very end
output_dtype = activation_tensor.dtype
y = y.to(output_dtype)
Contributor:

nit: can probably do this after bias addition

y = y.reshape(*activation_tensor.shape[:-1], weight_tensor.qdata.shape[0])
if bias is not None:
    y += bias
return y
Contributor:

can do the cast right before return I think
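A short sketch of the reordering both nits point at (variable names follow the diff above; not the final code):

# int8 matmul produces y in a wider dtype; keep it wide through the epilogue
y = y.reshape(*activation_tensor.shape[:-1], weight_tensor.qdata.shape[0])
if bias is not None:
    y += bias
# downcast once, right before returning
return y.to(activation_tensor.dtype)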

@jerryzh168 (Contributor) left a review comment:

thanks for the updates, the tests look good I think; there are still some comments that need to be addressed in int8_tensor.py
