Skip to content

Commit 8d546c8

Browse files
authored
[xpu][test] Port 2 test/quantization_{qat, quant_api} UT files to intel XPU (#3351)
* port 2 files to intel XPU * port 2 files to intel XPU * update * update
1 parent 6e21a1f commit 8d546c8

File tree

2 files changed

+70
-57
lines changed

2 files changed

+70
-57
lines changed

test/quantization/test_qat.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,14 @@
9898
)
9999
from torchao.utils import (
100100
_is_fbgemm_gpu_genai_available,
101+
get_current_accelerator_device,
101102
is_fbcode,
102103
is_sm_at_least_89,
103104
)
104105

105106
# TODO: put this in a common test utils file
106107
_CUDA_IS_AVAILABLE = torch.cuda.is_available()
108+
_DEVICE = get_current_accelerator_device()
107109

108110

109111
class Sub(torch.nn.Module):
@@ -347,7 +349,7 @@ def _set_ptq_weight(
347349
group_size,
348350
)
349351
q_weight = torch.ops.aten._convert_weight_to_int4pack(
350-
q_weight.to("cuda"),
352+
q_weight.to(_DEVICE),
351353
qat_linear.inner_k_tiles,
352354
)
353355
ptq_linear.weight = q_weight
@@ -600,13 +602,13 @@ def _assert_close_4w(self, val, ref):
600602
print(mean_err)
601603
self.assertTrue(mean_err < 0.05)
602604

603-
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
605+
@unittest.skipIf(_DEVICE is None, "skipping when gpu is not available")
604606
def test_qat_4w_primitives(self):
605607
n_bit = 4
606608
group_size = 32
607609
inner_k_tiles = 8
608610
scales_precision = torch.bfloat16
609-
device = torch.device("cuda")
611+
device = torch.device(_DEVICE)
610612
dtype = torch.bfloat16
611613
torch.manual_seed(self.SEED)
612614
x = torch.randn(100, 256, dtype=dtype, device=device)
@@ -699,7 +701,7 @@ def test_qat_4w_quantizer(self):
699701

700702
group_size = 32
701703
inner_k_tiles = 8
702-
device = torch.device("cuda")
704+
device = torch.device(_DEVICE)
703705
dtype = torch.bfloat16
704706
torch.manual_seed(self.SEED)
705707
m = M().to(device).to(dtype)

0 commit comments

Comments
 (0)