
Commit e612b6f

[Autotuner] Use cudagraph for time measurement on Nvidia hardware
1 parent 4db264a commit e612b6f

4 files changed: +128 −11 lines changed

helion/_testing.py

Lines changed: 4 additions & 1 deletion
@@ -17,7 +17,6 @@
 from typing import Generator
 import unittest
 
-import pytest
 import torch
 from torch.utils._pytree import tree_map
 import triton
@@ -267,6 +266,8 @@ def setUp(self) -> None:
         if not self._in_ref_eager_mode:
             return
 
+        import pytest
+
         # Reset assert_close counter for this test
         RefEagerTestBase._assert_close_count = 0
         # Reset assertRaises counter for this test
@@ -361,6 +362,8 @@ def tearDown(self) -> None:
            super().tearDown()  # type: ignore[misc]
            return
 
+        import pytest
+
        try:
            # Exit the run_ref tracker
            self._run_ref_tracker.__exit__(None, None, None)

helion/autotuner/base_search.py

Lines changed: 14 additions & 6 deletions
@@ -40,9 +40,11 @@
 from triton.testing import do_bench
 
 from .. import exc
+from .._testing import is_cuda
 from ..runtime.kernel import BoundKernel
 from ..runtime.precompile_shim import already_compiled
 from ..runtime.precompile_shim import make_precompiler
+from .bench_utils import do_bench_cudagraph_with_cache_clear
 from .benchmarking import interleaved_bench
 from .config_generation import ConfigGeneration
 from .config_generation import FlatConfig
@@ -325,12 +327,18 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
                # Accuracy check failed; reject this config
                return inf
            t1 = time.perf_counter()
-           res = do_bench(
-               functools.partial(fn, *self.args),
-               return_mode="median",
-               warmup=1,  # we are already warmed up above
-               rep=50,
-           )
+           kwargs = {
+               "fn": functools.partial(fn, *self.args),
+               "rep": 50,
+           }
+           if is_cuda():
+               res = do_bench_cudagraph_with_cache_clear(**kwargs)
+           else:
+               res = do_bench(
+                   **kwargs,
+                   warmup=1,  # we are already warmed up above
+                   return_mode="median",
+               )
            t2 = time.perf_counter()
            assert isinstance(res, float)
            self.log.debug(
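The core of the change is the branch above: on CUDA hardware the autotuner now times each candidate config with the new CUDA-graph helper, and everywhere else it keeps triton's do_bench. A minimal standalone sketch of the same dispatch, assuming a pre-compiled kernel; the time_config, compiled_kernel, and args names are illustrative placeholders, not part of the commit:

import functools

import torch
from triton.testing import do_bench

from helion.autotuner.bench_utils import do_bench_cudagraph_with_cache_clear


def time_config(compiled_kernel, args) -> float:
    # Bind the inputs so the benchmark helpers see a zero-argument callable.
    bench_fn = functools.partial(compiled_kernel, *args)
    if torch.cuda.is_available():  # stand-in for helion._testing.is_cuda()
        # CUDA graphs replay the launches without per-launch CPU overhead.
        return do_bench_cudagraph_with_cache_clear(fn=bench_fn, rep=50)
    # Fallback path: triton's wall-clock benchmark with a median reduction.
    return do_bench(bench_fn, warmup=1, rep=50, return_mode="median")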

helion/autotuner/bench_utils.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+from typing import Callable
+from typing import Sequence
+
+import torch
+import triton
+
+
+def do_bench_cudagraph_with_cache_clear(
+    fn: Callable[[], object],
+    rep: int = 20,
+    grad_to_none: Sequence[torch.Tensor] | None = None,
+) -> float:
+    """
+    Clone of triton.testing.do_bench_cudagraph with explicit L2 cache clearing.
+    Only supports calculating mean execution time.
+
+    Args:
+        fn: Function to benchmark
+        rep: Target total measurement time in milliseconds
+        grad_to_none: Tensors whose gradients should be cleared before each measurement
+
+    Returns:
+        Mean execution time in milliseconds
+    """
+    # Get a cache tensor and function to zero it for L2 cache clearing
+    cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()  # type: ignore[attr-defined]
+    clear_cache_fn = cache.zero_
+
+    # Use a separate CUDA stream for all benchmark operations
+    with torch.cuda.stream(torch.cuda.Stream()):
+        # Warmup: clear cache and run function once to ensure it's compiled
+        clear_cache_fn()
+        fn()
+
+        # Reset gradients if needed (for autograd-enabled benchmarks)
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.detach_()
+                x.requires_grad_(True)
+                x.grad = None
+
+        # Estimate execution time
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            clear_cache_fn()
+            fn()
+        end_event.record()
+        torch.cuda.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+
+        # Calculate number of repetitions needed to reach target measurement time (rep)
+        n_repeat = 1000 if estimate_ms == 0 else max(1, int(rep / estimate_ms))
+
+        # Create a CUDA graph for the actual kernel execution + cache clearing
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for _ in range(n_repeat):
+                if grad_to_none is not None:
+                    for x in grad_to_none:
+                        x.grad = None
+                clear_cache_fn()
+                fn()
+        torch.cuda.synchronize()
+
+        # Create a separate CUDA graph for just cache clearing
+        cache_clear_graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(cache_clear_graph):
+            for _ in range(n_repeat):
+                clear_cache_fn()
+        torch.cuda.synchronize()
+
+        # Measure time for cache clearing only
+        cache_clear_start_event = torch.cuda.Event(enable_timing=True)
+        cache_clear_end_event = torch.cuda.Event(enable_timing=True)
+        cache_clear_start_event.record()
+        cache_clear_graph.replay()
+        cache_clear_end_event.record()
+        torch.cuda.synchronize()
+        cache_clear_time = (
+            cache_clear_start_event.elapsed_time(cache_clear_end_event) / n_repeat
+        )
+
+        # Measure total time (cache clearing + kernel execution)
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        g.replay()
+        end_event.record()
+        torch.cuda.synchronize()
+        total_time = start_event.elapsed_time(end_event) / n_repeat
+
+        # Subtract cache clearing overhead to get pure kernel execution time
+        return total_time - cache_clear_time
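
The helper takes a zero-argument callable and handles the stream, graph capture, and repetition count internally, returning a mean time with the L2-cache-clear overhead subtracted out. A hedged usage sketch on a plain matmul; the shapes, dtype, and device setup here are illustrative, not from the commit:

import torch

from helion.autotuner.bench_utils import do_bench_cudagraph_with_cache_clear

if torch.cuda.is_available():
    a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

    # The benchmarked callable must take no arguments, so bind inputs up front.
    ms = do_bench_cudagraph_with_cache_clear(fn=lambda: a @ b, rep=50)
    print(f"mean matmul time: {ms:.3f} ms")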

test/test_debug_utils.py

Lines changed: 13 additions & 4 deletions
@@ -13,6 +13,7 @@
 from helion._testing import DEVICE
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
+from helion._testing import is_cuda
 from helion._testing import skipIfCpu
 import helion.language as hl
 
@@ -142,20 +143,28 @@ def test_print_repro_on_autotune_error(self):
        torch.manual_seed(0)
        x = torch.randn([128], dtype=torch.float32, device=DEVICE)
 
-       # Mock do_bench to fail on the second config with PTXASError (warn level)
+       # Mock benchmark helper to fail on the second config with PTXASError (warn level)
        from torch._inductor.runtime.triton_compat import PTXASError
-       from triton.testing import do_bench as original_do_bench
+
+       from helion.autotuner import base_search
 
        call_count = [0]
 
+       bench_attr = (
+           "do_bench_cudagraph_with_cache_clear" if is_cuda() else "do_bench"
+       )
+
+       original_bench = getattr(base_search, bench_attr)
+       bench_target = f"helion.autotuner.base_search.{bench_attr}"
+
        def mock_do_bench(*args, **kwargs):
            call_count[0] += 1
            if call_count[0] == 2:  # Fail on second config
                raise PTXASError("Mocked PTXAS error")
-           return original_do_bench(*args, **kwargs)
+           return original_bench(*args, **kwargs)
 
        with self.capture_output() as output_capture:
-           with mock.patch("helion.autotuner.base_search.do_bench", mock_do_bench):
+           with mock.patch(bench_target, mock_do_bench):
                # Autotune will try both configs, second one will fail and print repro
                kernel.autotune([x], force=False)
 
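Because the patched symbol now depends on the device, the test builds the mock.patch target string at runtime instead of hard-coding it. The same pattern in isolation, as a hedged sketch; fake_bench and its fixed return value are illustrative only:

from unittest import mock

import torch

bench_attr = (
    "do_bench_cudagraph_with_cache_clear" if torch.cuda.is_available() else "do_bench"
)


def fake_bench(*args, **kwargs):
    # Pretend every config took exactly 1 ms instead of running it.
    return 1.0


with mock.patch(f"helion.autotuner.base_search.{bench_attr}", fake_bench):
    # Any autotuning performed here resolves the benchmark call to fake_bench.
    pass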

0 commit comments
