|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | from contextlib import contextmanager |
| 4 | +import math |
4 | 5 | import os |
5 | 6 | from pathlib import Path |
6 | 7 | import random |
|
20 | 21 | from helion._testing import skipIfRocm |
21 | 22 | from helion.autotuner import DifferentialEvolutionSearch |
22 | 23 | from helion.autotuner.config_generation import ConfigGeneration |
| 24 | +from helion.autotuner.finite_search import FiniteSearch |
23 | 25 | from helion.autotuner.random_search import RandomSearch |
24 | 26 | import helion.language as hl |
25 | 27 | from helion.language import loops |
@@ -172,6 +174,105 @@ def test_differential_evolution_search(self): |
172 | 174 | fn = bound_kernel.compile_config(best) |
173 | 175 | torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1) |
174 | 176 |
|
| 177 | + def test_accuracy_check_filters_bad_config_wrong_output(self) -> None: |
| 178 | + bad_config = helion.Config(block_sizes=[1], num_warps=8) |
| 179 | + good_config = helion.Config(block_sizes=[1], num_warps=4) |
| 180 | + |
| 181 | + @helion.kernel(configs=[bad_config, good_config], autotune_log_level=0) |
| 182 | + def add_inplace(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: |
| 183 | + for tile in hl.tile(b.size()): |
| 184 | + b[tile] = a[tile] + b[tile] |
| 185 | + return b |
| 186 | + |
| 187 | + a = torch.randn([32], device=DEVICE) |
| 188 | + b = torch.randn([32], device=DEVICE) |
| 189 | + bound_kernel = add_inplace.bind((a, b)) |
| 190 | + |
| 191 | + original_compile = bound_kernel.compile_config |
| 192 | + |
| 193 | + def make_bad_config_produce_wrong_output( |
| 194 | + config: helion.Config, *, allow_print: bool = True |
| 195 | + ): |
| 196 | + fn = original_compile(config, allow_print=allow_print) |
| 197 | + if config == bad_config: |
| 198 | + return lambda *fn_args, **fn_kwargs: fn(*fn_args, **fn_kwargs) + 1 |
| 199 | + return fn |
| 200 | + |
| 201 | + with patch.object( |
| 202 | + bound_kernel, |
| 203 | + "compile_config", |
| 204 | + side_effect=make_bad_config_produce_wrong_output, |
| 205 | + ): |
| 206 | + search = FiniteSearch( |
| 207 | + bound_kernel, (a, b), configs=[bad_config, good_config] |
| 208 | + ) |
| 209 | + bad_time = search.benchmark(bad_config) |
| 210 | + assert math.isinf(bad_time) |
| 211 | + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1) |
| 212 | + search.counters["accuracy_mismatch"] = 0 # reset counter |
| 213 | + |
| 214 | + good_time = search.benchmark(good_config) |
| 215 | + assert not math.isinf(good_time) |
| 216 | + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0) |
| 217 | + search.counters["accuracy_mismatch"] = 0 # reset counter |
| 218 | + |
| 219 | + best = search._autotune() |
| 220 | + self.assertEqual(best, good_config) |
| 221 | + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1) |
| 222 | + |
| 223 | + def test_accuracy_check_filters_bad_config_wrong_arg_mutation(self) -> None: |
| 224 | + bad_config = helion.Config(block_sizes=[1], num_warps=8) |
| 225 | + good_config = helion.Config(block_sizes=[1], num_warps=4) |
| 226 | + |
| 227 | + @helion.kernel(configs=[bad_config, good_config], autotune_log_level=0) |
| 228 | + def add_inplace(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: |
| 229 | + for tile in hl.tile(b.size()): |
| 230 | + b[tile] = a[tile] + b[tile] |
| 231 | + return b |
| 232 | + |
| 233 | + a = torch.randn([32], device=DEVICE) |
| 234 | + b = torch.randn([32], device=DEVICE) |
| 235 | + bound_kernel = add_inplace.bind((a, b)) |
| 236 | + |
| 237 | + original_compile = bound_kernel.compile_config |
| 238 | + |
| 239 | + def make_bad_config_produce_wrong_input_arg_mutation( |
| 240 | + config: helion.Config, *, allow_print: bool = True |
| 241 | + ): |
| 242 | + fn = original_compile(config, allow_print=allow_print) |
| 243 | + if config == bad_config: |
| 244 | + |
| 245 | + def wrong_fn(*fn_args, **fn_kwargs): |
| 246 | + result = fn(*fn_args, **fn_kwargs) |
| 247 | + # Introduce an extra mutation so inputs differ from baseline |
| 248 | + fn_args[1].add_(1) |
| 249 | + return result |
| 250 | + |
| 251 | + return wrong_fn |
| 252 | + return fn |
| 253 | + |
| 254 | + with patch.object( |
| 255 | + bound_kernel, |
| 256 | + "compile_config", |
| 257 | + side_effect=make_bad_config_produce_wrong_input_arg_mutation, |
| 258 | + ): |
| 259 | + search = FiniteSearch( |
| 260 | + bound_kernel, (a, b), configs=[bad_config, good_config] |
| 261 | + ) |
| 262 | + bad_time = search.benchmark(bad_config) |
| 263 | + assert math.isinf(bad_time) |
| 264 | + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1) |
| 265 | + search.counters["accuracy_mismatch"] = 0 # reset counter |
| 266 | + |
| 267 | + good_time = search.benchmark(good_config) |
| 268 | + assert not math.isinf(good_time) |
| 269 | + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0) |
| 270 | + search.counters["accuracy_mismatch"] = 0 # reset counter |
| 271 | + |
| 272 | + best = search._autotune() |
| 273 | + self.assertEqual(best, good_config) |
| 274 | + self.assertGreaterEqual(search.counters.get("accuracy_mismatch", 0), 1) |
| 275 | + |
175 | 276 | def test_use_default_config(self): |
176 | 277 | @helion.kernel(use_default_config=True) |
177 | 278 | def add(a, b): |
|
0 commit comments