Skip to content

Commit 5ccf6f4

Browse files
authored
[Autotune] Filter bad config with accuracy check (#655)
1 parent 1701f8d commit 5ccf6f4

File tree

2 files changed

+164
-2
lines changed

2 files changed

+164
-2
lines changed

helion/autotuner/base_search.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
if TYPE_CHECKING:
2222
from triton.runtime.jit import JITFunction
2323

24+
import torch
2425
import torch.multiprocessing as mp
26+
from torch.utils._pytree import tree_flatten
27+
from torch.utils._pytree import tree_map
2528
from triton.testing import do_bench
2629

2730
from .. import exc
@@ -82,10 +85,63 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
8285
self.kernel = kernel
8386
self.settings: Settings = kernel.settings
8487
self.config_spec: ConfigSpec = kernel.config_spec
85-
self.args = args
88+
self.args: Sequence[object] = args
8689
self.counters: collections.Counter[str] = collections.Counter()
8790
self.log = LambdaLogger(self.settings.autotune_log_level)
8891
random.seed(self.settings.autotune_random_seed)
92+
self._original_args: Sequence[object] = self._clone_args(self.args)
93+
(
94+
self._baseline_output,
95+
self._kernel_mutates_args,
96+
self._baseline_post_args,
97+
) = self._compute_baseline()
98+
99+
def _clone_args(self, args: Sequence[object]) -> Sequence[object]:
100+
def _clone_leaf(leaf: object) -> object:
101+
if isinstance(leaf, torch.Tensor):
102+
clone = leaf.detach().clone()
103+
clone.requires_grad_(leaf.requires_grad)
104+
return clone
105+
return leaf
106+
107+
return tree_map(_clone_leaf, args)
108+
109+
def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
110+
"""
111+
Return output and post-run input arguments of the default-config kernel.
112+
Also detect if the kernel mutates any of its input arguments.
113+
"""
114+
new_args = self._clone_args(self._original_args)
115+
baseline_config = self.config_spec.default_config()
116+
baseline_output = self.kernel.compile_config(
117+
baseline_config, allow_print=False
118+
)(*new_args)
119+
original_args_flat, _ = tree_flatten(self._original_args)
120+
new_args_flat, _ = tree_flatten(new_args)
121+
mutated = False
122+
for old, new in zip(original_args_flat, new_args_flat, strict=False):
123+
if (
124+
isinstance(old, torch.Tensor)
125+
and isinstance(new, torch.Tensor)
126+
and (not torch.equal(new, old))
127+
):
128+
mutated = True
129+
break
130+
baseline_post_args = self._clone_args(new_args)
131+
return baseline_output, mutated, baseline_post_args
132+
133+
def _validate_against_baseline(
134+
self, config: Config, output: object, args: Sequence[object]
135+
) -> bool:
136+
try:
137+
torch.testing.assert_close(output, self._baseline_output)
138+
if self._kernel_mutates_args:
139+
torch.testing.assert_close(args, self._baseline_post_args)
140+
except AssertionError as e:
141+
self.counters["accuracy_mismatch"] += 1
142+
self.log.warning(f"Accuracy mismatch for {config!r}: {e!s}")
143+
return False
144+
return True
89145

90146
def benchmark(self, config: Config) -> float:
91147
"""
@@ -121,7 +177,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
121177
try:
122178
# TODO(jansel): early exit with fewer trials if early runs are slow
123179
t0 = time.perf_counter()
124-
fn(*self.args) # make sure the kernel is compiled
180+
if self._kernel_mutates_args:
181+
self.args = self._clone_args(self._original_args)
182+
output = fn(*self.args) # make sure the kernel is compiled
183+
if not self._validate_against_baseline(config, output, self.args):
184+
# Accuracy check failed; reject this config
185+
return inf
125186
t1 = time.perf_counter()
126187
res = do_bench(
127188
functools.partial(fn, *self.args),

test/test_autotuner.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from contextlib import contextmanager
4+
import math
45
import os
56
from pathlib import Path
67
import random
@@ -20,6 +21,7 @@
2021
from helion._testing import skipIfRocm
2122
from helion.autotuner import DifferentialEvolutionSearch
2223
from helion.autotuner.config_generation import ConfigGeneration
24+
from helion.autotuner.finite_search import FiniteSearch
2325
from helion.autotuner.random_search import RandomSearch
2426
import helion.language as hl
2527
from helion.language import loops
@@ -172,6 +174,105 @@ def test_differential_evolution_search(self):
172174
fn = bound_kernel.compile_config(best)
173175
torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1)
174176

177+
def test_accuracy_check_filters_bad_config_wrong_output(self) -> None:
178+
bad_config = helion.Config(block_sizes=[1], num_warps=8)
179+
good_config = helion.Config(block_sizes=[1], num_warps=4)
180+
181+
@helion.kernel(configs=[bad_config, good_config], autotune_log_level=0)
182+
def add_inplace(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
183+
for tile in hl.tile(b.size()):
184+
b[tile] = a[tile] + b[tile]
185+
return b
186+
187+
a = torch.randn([32], device=DEVICE)
188+
b = torch.randn([32], device=DEVICE)
189+
bound_kernel = add_inplace.bind((a, b))
190+
191+
original_compile = bound_kernel.compile_config
192+
193+
def make_bad_config_produce_wrong_output(
194+
config: helion.Config, *, allow_print: bool = True
195+
):
196+
fn = original_compile(config, allow_print=allow_print)
197+
if config == bad_config:
198+
return lambda *fn_args, **fn_kwargs: fn(*fn_args, **fn_kwargs) + 1
199+
return fn
200+
201+
with patch.object(
202+
bound_kernel,
203+
"compile_config",
204+
side_effect=make_bad_config_produce_wrong_output,
205+
):
206+
search = FiniteSearch(
207+
bound_kernel, (a, b), configs=[bad_config, good_config]
208+
)
209+
bad_time = search.benchmark(bad_config)
210+
assert math.isinf(bad_time)
211+
self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1)
212+
search.counters["accuracy_mismatch"] = 0 # reset counter
213+
214+
good_time = search.benchmark(good_config)
215+
assert not math.isinf(good_time)
216+
self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0)
217+
search.counters["accuracy_mismatch"] = 0 # reset counter
218+
219+
best = search._autotune()
220+
self.assertEqual(best, good_config)
221+
self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1)
222+
223+
def test_accuracy_check_filters_bad_config_wrong_arg_mutation(self) -> None:
224+
bad_config = helion.Config(block_sizes=[1], num_warps=8)
225+
good_config = helion.Config(block_sizes=[1], num_warps=4)
226+
227+
@helion.kernel(configs=[bad_config, good_config], autotune_log_level=0)
228+
def add_inplace(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
229+
for tile in hl.tile(b.size()):
230+
b[tile] = a[tile] + b[tile]
231+
return b
232+
233+
a = torch.randn([32], device=DEVICE)
234+
b = torch.randn([32], device=DEVICE)
235+
bound_kernel = add_inplace.bind((a, b))
236+
237+
original_compile = bound_kernel.compile_config
238+
239+
def make_bad_config_produce_wrong_input_arg_mutation(
240+
config: helion.Config, *, allow_print: bool = True
241+
):
242+
fn = original_compile(config, allow_print=allow_print)
243+
if config == bad_config:
244+
245+
def wrong_fn(*fn_args, **fn_kwargs):
246+
result = fn(*fn_args, **fn_kwargs)
247+
# Introduce an extra mutation so inputs differ from baseline
248+
fn_args[1].add_(1)
249+
return result
250+
251+
return wrong_fn
252+
return fn
253+
254+
with patch.object(
255+
bound_kernel,
256+
"compile_config",
257+
side_effect=make_bad_config_produce_wrong_input_arg_mutation,
258+
):
259+
search = FiniteSearch(
260+
bound_kernel, (a, b), configs=[bad_config, good_config]
261+
)
262+
bad_time = search.benchmark(bad_config)
263+
assert math.isinf(bad_time)
264+
self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1)
265+
search.counters["accuracy_mismatch"] = 0 # reset counter
266+
267+
good_time = search.benchmark(good_config)
268+
assert not math.isinf(good_time)
269+
self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0)
270+
search.counters["accuracy_mismatch"] = 0 # reset counter
271+
272+
best = search._autotune()
273+
self.assertEqual(best, good_config)
274+
self.assertGreaterEqual(search.counters.get("accuracy_mismatch", 0), 1)
275+
175276
def test_use_default_config(self):
176277
@helion.kernel(use_default_config=True)
177278
def add(a, b):

0 commit comments

Comments
 (0)