38 changes: 38 additions & 0 deletions test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -466,6 +466,44 @@ def forward(self, x):
sqnr = compute_error(original, quantized)
self.assertTrue(sqnr > 20)

@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
@unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
def test_bmm_weight_in_bkn_layout(self):
# Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
# and contiguous with that shape. Since the `K` dimension is not last, we
# need to specify granularity with `PerRow(1)`.

# only per-row quantization is supported here
granularity = [PerRow(), PerRow(1)]
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)

class Model(torch.nn.Module):
def __init__(self, weight):
super().__init__()
self.weight = weight

def forward(self, x):
return torch.bmm(x, self.weight)

dtype = torch.bfloat16
device = "cuda"

B, M, K, N = 10, 32, 128, 256

input = torch.randn(B, M, K, dtype=dtype, device=device)
weight = torch.randn(B, K, N, dtype=dtype, device=device)
m = Model(weight).eval()
original = m(input)
quantize_(m, config, filter_fn=lambda x, fqn: True)

assert m.weight.scale.shape == (B, 1, N), (
f"unexpected scale shape {m.weight.scale.shape}"
)

quantized = m(input)
sqnr = compute_error(original, quantized)
self.assertTrue(sqnr > 20)

@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize(
"sizes",
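For context on why `PerRow(1)` is needed here: the weight is stored as (B, K, N), so the contraction dimension K sits at axis 1 rather than last, and the scale must reduce over axis 1 to end up with one value per (batch, output-column) pair. A minimal sketch of the expected reduction in plain PyTorch (shapes borrowed from the test above):

```python
import torch

# Hypothetical shapes matching the test above.
B, K, N = 10, 128, 256
weight = torch.randn(B, K, N, dtype=torch.bfloat16)

# PerRow(1) reduces across axis 1 (K), keeping B and N:
# one scale per (batch, output-column) pair.
amax = weight.abs().amax(dim=1, keepdim=True)
assert amax.shape == (B, 1, N)
```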
25 changes: 25 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -10,6 +10,7 @@

import torch

from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
@@ -27,6 +28,7 @@
# TODO: remove test for utils?
from torchao.quantization.utils import (
_quantize_activation_per_token_absmax,
get_block_size,
get_group_qparams_symmetric,
groupwise_affine_dequantize_tensor_from_qparams,
groupwise_affine_quantize_tensor_from_qparams,
@@ -844,6 +846,29 @@ def test_float8_blockwise_scaling(self):
torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)

def test_float8_rowwise_scaling_3d_weight_axis_1(self):
"""
Test rowwise scaling across the K dimension (axis 1) of a weight with
shape (B, K, N) and row-major memory layout.
"""

B, K, N = 8, 16, 32
hp_tensor = torch.randn(B, K, N, dtype=torch.float)

granularity = PerRow(1)
block_size = get_block_size(hp_tensor.shape, granularity)
scale = _choose_scale_float8(
hp_tensor,
float8_dtype=torch.float8_e4m3fn,
block_size=block_size,
hp_value_lb=None,
hp_value_ub=None,
)
data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)

assert scale.shape == (B, 1, N)
assert data.shape == (B, K, N)


if __name__ == "__main__":
unittest.main()
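A rough reference for what this test exercises: `_choose_scale_float8` with a (1, K, 1) block size amounts to an amax reduction over dim 1 mapped onto the float8 range. The sketch below is an illustrative approximation, not torchao's exact implementation (clamping and epsilon handling may differ):

```python
import torch

def ref_rowwise_scale(hp: torch.Tensor, dim: int) -> torch.Tensor:
    # One scale per slice along `dim`, mapping the block amax
    # onto the float8_e4m3fn max representable value (448.0).
    amax = hp.abs().amax(dim=dim, keepdim=True).float()
    return amax / torch.finfo(torch.float8_e4m3fn).max

B, K, N = 8, 16, 32
hp = torch.randn(B, K, N)
scale = ref_rowwise_scale(hp, dim=1)
assert scale.shape == (B, 1, N)

data = (hp / scale).to(torch.float8_e4m3fn)
assert data.shape == (B, K, N)
```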
25 changes: 17 additions & 8 deletions torchao/quantization/granularity.py
@@ -39,12 +39,14 @@ class PerAxis(Granularity):
This granularity type calculates different quantization parameters
along a specified axis of the tensor.

For example, if the input tensor has shape [8, 16] and axis=0, then
the quantization parameters are calculated for each row of the tensor,
giving a total of 8 quantization parameters.
Examples:
* input_tensor shape [A, B], axis 0 -> scale_shape [A, 1]
* input_tensor shape [A, B], axis 1 -> scale_shape [1, B]
* input_tensor shape [A, B, C], axis 1 -> scale_shape [1, B, 1]

Attributes:
axis (int): The axis along which reduction is performed.
axis (int): The axis which is kept; reduction is performed across all
the other axes.
"""

axis: int
@@ -76,12 +78,19 @@ class PerRow(Granularity):
"""
Represents row-wise granularity in quantization.

This is a special case of per-axis quantization and is unique to Float8 matmuls,
where the input is quantized with a block_size of (1, ..., input.shape[-1]) and
the weight is quantized with a block_size of (1, weight.shape[1]).
Examples:
* input_tensor shape [A, B], dim 0 -> scale_shape [1, B]
* input_tensor shape [A, B], dim 1 -> scale_shape [A, 1]
* input_tensor shape [A, B], dim -1 -> scale_shape [A, 1]
* input_tensor shape [A, B, C], dim 1 -> scale_shape [A, 1, C]

Attributes:
dim (int): The dim which is reduced across; all other dims are kept.
"""

pass
# TODO(before land): any BC concerns with loading old checkpoints
# serialized without this arg? investigate this
dim: int = -1


@dataclass(frozen=True)
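To make the new docstring examples concrete, a small sketch of the resulting scale shapes using plain amax reductions (illustrative only):

```python
import torch

x = torch.randn(8, 16, 32)  # shape [A, B, C]

# PerRow(dim): reduce across `dim`, keep all other dims.
scale_default = x.abs().amax(dim=-1, keepdim=True)  # PerRow()  -> (8, 16, 1)
scale_dim1 = x.abs().amax(dim=1, keepdim=True)      # PerRow(1) -> (8, 1, 32)
assert scale_default.shape == (8, 16, 1)
assert scale_dim1.shape == (8, 1, 32)
```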
@@ -180,6 +180,8 @@ def from_hp(
and _is_fbgemm_gpu_genai_available()
and is_sm_at_least_90()
and isinstance(granularity, PerRow)
# fbgemm path only supports quantizing along the last dim
and granularity.dim in (-1, len(hp_tensor.shape) - 1)
and float8_dtype == torch.float8_e4m3fn
and hp_value_lb is None
):
@@ -438,7 +440,7 @@ def _(func, types, args, kwargs):

res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
a_data,
b_data.transpose(-2, -1),
b_data.transpose(-2, -1).contiguous(),
a_scale,
b_scale.transpose(-2, -1),
b_scale,
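The added `.contiguous()` matters because `transpose(-2, -1)` returns a strided view: for a row-major (B, K, N) weight, the transposed (B, N, K) view no longer has K as the innermost (stride-1) axis, presumably required by the fbgemm kernel path. A quick illustration of the layout point (shapes assumed from the bmm test above):

```python
import torch

b = torch.randn(10, 128, 256)   # (B, K, N), row-major
bt = b.transpose(-2, -1)        # (B, N, K) view; K is no longer stride-1
assert not bt.is_contiguous()

bt_c = bt.contiguous()          # materialize a row-major (B, N, K) copy
assert bt_c.is_contiguous()
assert bt_c.stride(-1) == 1
```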
6 changes: 5 additions & 1 deletion torchao/quantization/utils.py
@@ -723,8 +723,12 @@ def get_block_size(
f"Not all shapes in input shape {input_shape} are divisible by block size {block_size}"
)
return block_size
elif isinstance(granularity, (PerRow, PerToken)):
elif isinstance(granularity, PerToken):
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
elif isinstance(granularity, PerRow):
block_size = [1] * len(input_shape)
block_size[granularity.dim] = input_shape[granularity.dim]
return tuple(block_size)
elif isinstance(granularity, PerGroup):
assert input_shape[-1] % granularity.group_size == 0, (
f"Last dimension of input {input_shape[-1]} is not divisible by group size {granularity.group_size}"
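A quick check of the new `PerRow` branch (a sketch; the values follow directly from the code above):

```python
from torchao.quantization.granularity import PerRow, PerToken
from torchao.quantization.utils import get_block_size

shape = (8, 16, 32)

# Default PerRow() keeps the previous behavior: reduce across the last dim.
assert get_block_size(shape, PerRow()) == (1, 1, 32)
# PerRow(1) reduces across dim 1 instead.
assert get_block_size(shape, PerRow(1)) == (1, 16, 1)
# PerToken behavior is unchanged.
assert get_block_size(shape, PerToken()) == (1, 1, 32)
```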
4 changes: 3 additions & 1 deletion torchao/testing/utils.py
@@ -444,7 +444,9 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
# making the weight different
dummy_l.weight = torch.nn.Parameter(
dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
dummy_l.weight
+ 1.0
+ 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
requires_grad=False,
)
quantize_(dummy_l, config)