@@ -212,23 +212,27 @@ def test_dequantization_accuracy(self, config):
212212 f"Dequantization error is too high to get a SQNR of { compute_error (dequantized , weight_fp )} "
213213 )
214214
215- @common_utils .parametrize (
216- "kernel" ,
217- ["triton_per_fused" , "extern_kernels._int_mm" , "triton_poi_fused" ],
218- )
219- def test_available_gpu_kernels (self , kernel ):
220- """Check which GPU kernels are available"""
215+ def test_available_gpu_kernels (self ):
216+ """Check which GPU kernels are used"""
217+ torch .compiler .reset ()
218+
221219 M , K , N = 128 , 256 , 512
222220 m = torch .nn .Sequential (
223221 torch .nn .Linear (K , N , device = "cuda" , dtype = torch .bfloat16 )
224222 )
223+
225224 config = Int8DynamicActivationInt8WeightConfig (version = 2 )
226225 quantize_ (m , config )
226+
227227 m = torch .compile (m )
228228 x = torch .randn (M , K , device = "cuda" , dtype = torch .bfloat16 )
229229
230230 out , code = run_and_get_code (m , x )
231- FileCheck ().check (kernel ).run (code [0 ])
231+
232+ # Check expected kernels are present
233+ FileCheck ().check_count ("triton_per_fused" , 1 ).check_count (
234+ "extern_kernels._int_mm" , 1
235+ ).check_count ("triton_poi_fused" , 1 ).run (code [0 ])
232236
233237
234238if __name__ == "__main__" :
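For context, the new assertions can be exercised outside the test suite. Below is a minimal standalone sketch, not the exact test file: it assumes a CUDA device and a torchao build exposing quantize_ and Int8DynamicActivationInt8WeightConfig(version=2). run_and_get_code returns the model output together with the Inductor-generated source, and FileCheck().check_count(...) asserts how many times each kernel name appears in that source.

    # Minimal standalone sketch of the kernel-check pattern used in the test.
    # Assumes a CUDA device and a torchao build that provides quantize_ and
    # Int8DynamicActivationInt8WeightConfig(version=2).
    import torch
    from torch._inductor.utils import run_and_get_code
    from torch.testing import FileCheck
    from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

    torch.compiler.reset()  # drop cached compile artifacts so codegen is observed fresh

    M, K, N = 128, 256, 512
    m = torch.nn.Sequential(torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
    quantize_(m, Int8DynamicActivationInt8WeightConfig(version=2))

    m = torch.compile(m)
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

    # run_and_get_code returns the model output plus the generated source strings;
    # check_count verifies each expected kernel name occurs in the first one:
    # a Triton reduction kernel, the int8 matmul extern call, and a Triton
    # pointwise kernel (likely activation quantization and dequant epilogue).
    out, code = run_and_get_code(m, x)
    FileCheck().check_count("triton_per_fused", 1).check_count(
        "extern_kernels._int_mm", 1
    ).check_count("triton_poi_fused", 1).run(code[0])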