27 changes: 15 additions & 12 deletions test/quantization/test_qat.py
@@ -653,13 +653,13 @@ def test_qat_4w_primitives(self):

         self._assert_close_4w(qat_out, ptq_out)

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     def test_qat_4w_linear(self):
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
         from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear

         group_size = 128
-        device = torch.device("cuda")
+        device = torch.device(_DEVICE)
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         qat_linear = Int4WeightOnlyQATLinear(
@@ -694,7 +694,11 @@ def test_qat_4w_quantizer_gradients(self):
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
+    @unittest.skipIf(
+        _DEVICE is torch.device("xpu"),
+        "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770",
+    )
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
@@ -711,8 +715,7 @@ def test_qat_4w_quantizer(self):
             inner_k_tiles=inner_k_tiles,
         )
         ptq_quantizer = Int4WeightOnlyQuantizer(
-            groupsize=group_size,
-            inner_k_tiles=inner_k_tiles,
+            groupsize=group_size, inner_k_tiles=inner_k_tiles, device=device
         )
         qat_model = qat_quantizer.prepare(m)
         ptq_model = ptq_quantizer.quantize(m2)
@@ -1893,12 +1896,12 @@ def _test_quantize_api_against_ptq(
         torch.manual_seed(self.SEED)

         if module_type == "linear":
-            m = M().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+            m = M().to(dtype).to(_DEVICE)
+            example_inputs = (m.example_inputs()[0].to(dtype).to(_DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
         elif module_type == "embedding":
-            m = M3().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].cuda(),)
+            m = M3().to(dtype).to(_DEVICE)
+            example_inputs = (m.example_inputs()[0].to(_DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
         else:
             raise ValueError(f"Unknown module type {module_type}")
@@ -1973,7 +1976,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
             target_convert_sqnr=float("inf"),
         )

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     def test_quantize_api_int8_int4(self):
         """
         Test the following:
@@ -1986,7 +1989,7 @@ def test_quantize_api_int8_int4(self):
             target_convert_sqnr=float("inf"),
         )

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, weight_granularity, dtype",
         [
@@ -2011,7 +2014,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
             dtype=dtype,
         )

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, granularity, dtype, module_type",
         [
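For context, the diff replaces the `_CUDA_IS_AVAILABLE` flag with a module-level `_DEVICE` helper that resolves to the available GPU backend and is `None` when no GPU is present, which is why the new decorators skip on `_DEVICE is None`. The helper's definition is not part of this hunk, so the sketch below is only an illustration of the assumed pattern, not the file's actual code.

```python
# Hypothetical sketch of the _DEVICE helper assumed by the skipIf decorators above.
# The real definition lives elsewhere in test_qat.py; names and fallback order here
# are assumptions.
import torch

if torch.cuda.is_available():
    _DEVICE = torch.device("cuda")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    _DEVICE = torch.device("xpu")
else:
    # Tests decorated with skipIf(_DEVICE is None, ...) are skipped on CPU-only hosts.
    _DEVICE = None
```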