Skip to content

Commit d11af10

Browse files
committed
update expected kernel test
1 parent 0c2bb76 commit d11af10

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -212,23 +212,27 @@ def test_dequantization_accuracy(self, config):
212212
f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, weight_fp)}"
213213
)
214214

def test_available_gpu_kernels(self):
    """Check which GPU kernels torch.compile emits for an int8-quantized Linear.

    Builds a single bf16 ``nn.Linear`` on CUDA, quantizes it in place with
    ``Int8DynamicActivationInt8WeightConfig(version=2)``, compiles it, and
    asserts that the generated code contains each of the three expected
    kernels exactly the expected number of times:

    - ``triton_per_fused``: a Triton reduction kernel — presumably the
      dynamic activation quantization step (TODO confirm against inductor
      output).
    - ``extern_kernels._int_mm``: the int8 matmul dispatched to the extern
      kernel.
    - ``triton_poi_fused``: a Triton pointwise kernel — presumably the
      dequantize/epilogue step (TODO confirm).
    """
    # Reset compiler state so kernel inspection is not polluted by code
    # cached from earlier tests in the same process.
    torch.compiler.reset()

    M, K, N = 128, 256, 512
    m = torch.nn.Sequential(
        torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
    )

    # Quantize first, then compile the quantized module.
    config = Int8DynamicActivationInt8WeightConfig(version=2)
    quantize_(m, config)

    m = torch.compile(m)
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

    # Only the generated source is inspected; the numeric output is unused.
    _, code = run_and_get_code(m, x)

    # Check expected kernels are present
    FileCheck().check_count("triton_per_fused", 1).check_count(
        "extern_kernels._int_mm", 1
    ).check_count("triton_poi_fused", 1).run(code[0])
232236

233237

234238
if __name__ == "__main__":

0 commit comments

Comments (0)