 )
 from torchao.utils import (
     _is_fbgemm_gpu_genai_available,
+    get_current_accelerator_device,
     is_fbcode,
     is_sm_at_least_89,
 )

 # TODO: put this in a common test utils file
 _CUDA_IS_AVAILABLE = torch.cuda.is_available()
+_DEVICE = get_current_accelerator_device()


 class Sub(torch.nn.Module):
@@ -347,7 +349,7 @@ def _set_ptq_weight(
             group_size,
         )
         q_weight = torch.ops.aten._convert_weight_to_int4pack(
-            q_weight.to("cuda"),
+            q_weight.to(_DEVICE),
             qat_linear.inner_k_tiles,
         )
         ptq_linear.weight = q_weight
@@ -600,13 +602,13 @@ def _assert_close_4w(self, val, ref):
         print(mean_err)
         self.assertTrue(mean_err < 0.05)

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when gpu is not available")
     def test_qat_4w_primitives(self):
         n_bit = 4
         group_size = 32
         inner_k_tiles = 8
         scales_precision = torch.bfloat16
-        device = torch.device("cuda")
+        device = torch.device(_DEVICE)
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256, dtype=dtype, device=device)
@@ -699,7 +701,7 @@ def test_qat_4w_quantizer(self):

         group_size = 32
         inner_k_tiles = 8
-        device = torch.device("cuda")
+        device = torch.device(_DEVICE)
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         m = M().to(device).to(dtype)
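The new get_current_accelerator_device helper imported from torchao.utils is not shown in this hunk. Below is a minimal sketch of the behavior the tests rely on, inferred from the call sites: it should return a device for whatever accelerator is present, and None when there is none, which is what makes the `@unittest.skipIf(_DEVICE is None, ...)` guard work. This is an illustration under those assumptions, not the actual torchao implementation.

import torch

def get_current_accelerator_device():
    # Hypothetical sketch: pick the first available accelerator backend,
    # or return None so device-dependent tests can be skipped.
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.xpu only exists in builds with Intel GPU support.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return None

Returning a torch.device (rather than a bare string) keeps the call sites above, such as torch.device(_DEVICE) and q_weight.to(_DEVICE), working unchanged, since both accept a device object.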