
Commit 5001a53

[Autotuner] Use cudagraph for time measurement on Nvidia hardware
1 parent 4db264a commit 5001a53
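The gist of the change: on NVIDIA hardware the autotuner now times candidate configs by capturing repeated kernel launches in a CUDA graph and replaying it, rather than calling triton.testing.do_bench directly. Below is a minimal standalone sketch of that timing technique only — not the commit's exact implementation, which additionally clears and subtracts L2-cache overhead (see helion/autotuner/benchmarking.py). `work` is a placeholder for the kernel being measured.

import torch


def graph_timed_ms(work, n_repeat=100):
    """Time `work` by capturing n_repeat launches in a CUDA graph and replaying it."""
    # Warm up on a side stream so lazy initialization happens outside graph capture.
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        work()
    torch.cuda.current_stream().wait_stream(side_stream)

    # Capture n_repeat launches into a single graph; replay has no per-launch Python overhead.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(n_repeat):
            work()

    # Time one replay of the whole graph with CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    graph.replay()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_repeat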

File tree

4 files changed: +123 -11 lines changed

helion/_testing.py
helion/autotuner/base_search.py
helion/autotuner/benchmarking.py
test/test_debug_utils.py


helion/_testing.py

Lines changed: 4 additions & 1 deletion
@@ -17,7 +17,6 @@
 from typing import Generator
 import unittest

-import pytest
 import torch
 from torch.utils._pytree import tree_map
 import triton
@@ -267,6 +266,8 @@ def setUp(self) -> None:
         if not self._in_ref_eager_mode:
             return

+        import pytest
+
         # Reset assert_close counter for this test
         RefEagerTestBase._assert_close_count = 0
         # Reset assertRaises counter for this test
@@ -361,6 +362,8 @@ def tearDown(self) -> None:
             super().tearDown()  # type: ignore[misc]
             return

+        import pytest
+
         try:
             # Exit the run_ref tracker
             self._run_ref_tracker.__exit__(None, None, None)

helion/autotuner/base_search.py

Lines changed: 14 additions & 6 deletions
@@ -40,9 +40,11 @@
 from triton.testing import do_bench

 from .. import exc
+from .._testing import is_cuda
 from ..runtime.kernel import BoundKernel
 from ..runtime.precompile_shim import already_compiled
 from ..runtime.precompile_shim import make_precompiler
+from .benchmarking import do_bench_cudagraph_with_cache_clear
 from .benchmarking import interleaved_bench
 from .config_generation import ConfigGeneration
 from .config_generation import FlatConfig
@@ -325,12 +327,18 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
                 # Accuracy check failed; reject this config
                 return inf
             t1 = time.perf_counter()
-            res = do_bench(
-                functools.partial(fn, *self.args),
-                return_mode="median",
-                warmup=1,  # we are already warmed up above
-                rep=50,
-            )
+            kwargs = {
+                "fn": functools.partial(fn, *self.args),
+                "rep": 50,
+            }
+            if is_cuda():
+                res = do_bench_cudagraph_with_cache_clear(**kwargs)
+            else:
+                res = do_bench(
+                    **kwargs,
+                    return_mode="median",
+                    warmup=1,  # we are already warmed up above
+                )
             t2 = time.perf_counter()
             assert isinstance(res, float)
             self.log.debug(
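Read on its own, the new timing path in benchmark_function amounts to the following dispatch. This is a sketch for orientation only: `compiled_fn` and `args` stand in for the autotuner's compiled config and its kernel arguments, which the real method already holds.

import functools

from triton.testing import do_bench

from helion._testing import is_cuda
from helion.autotuner.benchmarking import do_bench_cudagraph_with_cache_clear


def time_candidate(compiled_fn, args, rep=50):
    # Bind the candidate config's compiled kernel to its inputs.
    bench_fn = functools.partial(compiled_fn, *args)
    if is_cuda():
        # NVIDIA: replay the kernel inside a CUDA graph, clearing L2 between runs.
        return do_bench_cudagraph_with_cache_clear(fn=bench_fn, rep=rep)
    # Other backends: fall back to Triton's event-based do_bench.
    return do_bench(fn=bench_fn, rep=rep, return_mode="median", warmup=1)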

helion/autotuner/benchmarking.py

Lines changed: 92 additions & 0 deletions
@@ -4,12 +4,104 @@
 import math
 import statistics
 from typing import Callable
+from typing import Sequence

+import torch
+import triton
 from triton import runtime

 from .progress_bar import iter_with_progress


+def do_bench_cudagraph_with_cache_clear(
+    fn: Callable[[], object],
+    rep: int = 20,
+    grad_to_none: Sequence[torch.Tensor] | None = None,
+) -> float:
+    """
+    Clone of triton.testing.do_bench_cudagraph with explicit L2 cache clearing.
+    Only supports calculating mean execution time.
+
+    Args:
+        fn: Function to benchmark
+        rep: Target total measurement time in milliseconds
+        grad_to_none: Tensors whose gradients should be cleared before each measurement
+
+    Returns:
+        Mean execution time in milliseconds
+    """
+    # Get a cache tensor and a function to zero it for L2 cache clearing
+    cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()  # type: ignore[attr-defined]
+    clear_cache_fn = cache.zero_
+
+    with torch.cuda.stream(torch.cuda.Stream()):
+        # Warmup: clear cache and run function once to ensure it's compiled
+        clear_cache_fn()
+        fn()
+
+        # Reset gradients if needed
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.detach_()
+                x.requires_grad_(True)
+                x.grad = None
+
+        # Estimate execution time
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            clear_cache_fn()
+            fn()
+        end_event.record()
+        torch.cuda.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+
+        # Calculate number of repetitions needed to reach target measurement time (rep)
+        n_repeat = 1000 if estimate_ms == 0 else max(1, int(rep / estimate_ms))
+
+        # Create a CUDA graph for measuring total time (cache clearing + kernel execution)
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for _ in range(n_repeat):
+                if grad_to_none is not None:
+                    for x in grad_to_none:
+                        x.grad = None
+                clear_cache_fn()
+                fn()
+        torch.cuda.synchronize()
+
+        # Create a separate CUDA graph for measuring cache clearing time only
+        cache_clear_graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(cache_clear_graph):
+            for _ in range(n_repeat):
+                clear_cache_fn()
+        torch.cuda.synchronize()
+
+        # Measure time for cache clearing only
+        cache_clear_start_event = torch.cuda.Event(enable_timing=True)
+        cache_clear_end_event = torch.cuda.Event(enable_timing=True)
+        cache_clear_start_event.record()
+        cache_clear_graph.replay()
+        cache_clear_end_event.record()
+        torch.cuda.synchronize()
+        cache_clear_time = (
+            cache_clear_start_event.elapsed_time(cache_clear_end_event) / n_repeat
+        )
+
+        # Measure total time (cache clearing + kernel execution)
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        g.replay()
+        end_event.record()
+        torch.cuda.synchronize()
+        total_time = start_event.elapsed_time(end_event) / n_repeat
+
+        # Subtract cache clearing overhead to get pure kernel execution time
+        return total_time - cache_clear_time
+
+
 def compute_repeat(
     fn: Callable[[], object],
     *,
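A minimal usage sketch of the new helper (assumes a CUDA device is available; the matmul is only an illustrative workload and is not part of the commit):

import torch

from helion.autotuner.benchmarking import do_bench_cudagraph_with_cache_clear

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

# Mean kernel time in milliseconds, with the L2-cache-clear overhead subtracted.
ms = do_bench_cudagraph_with_cache_clear(fn=lambda: a @ b, rep=50)
print(f"matmul mean time: {ms:.4f} ms")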

test/test_debug_utils.py

Lines changed: 13 additions & 4 deletions
@@ -13,6 +13,7 @@
 from helion._testing import DEVICE
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
+from helion._testing import is_cuda
 from helion._testing import skipIfCpu
 import helion.language as hl

@@ -142,20 +143,28 @@ def test_print_repro_on_autotune_error(self):
         torch.manual_seed(0)
         x = torch.randn([128], dtype=torch.float32, device=DEVICE)

-        # Mock do_bench to fail on the second config with PTXASError (warn level)
+        # Mock benchmark helper to fail on the second config with PTXASError (warn level)
         from torch._inductor.runtime.triton_compat import PTXASError
-        from triton.testing import do_bench as original_do_bench
+
+        from helion.autotuner import base_search

         call_count = [0]

+        bench_attr = (
+            "do_bench_cudagraph_with_cache_clear" if is_cuda() else "do_bench"
+        )
+
+        original_bench = getattr(base_search, bench_attr)
+        bench_target = f"helion.autotuner.base_search.{bench_attr}"
+
         def mock_do_bench(*args, **kwargs):
             call_count[0] += 1
             if call_count[0] == 2:  # Fail on second config
                 raise PTXASError("Mocked PTXAS error")
-            return original_do_bench(*args, **kwargs)
+            return original_bench(*args, **kwargs)

         with self.capture_output() as output_capture:
-            with mock.patch("helion.autotuner.base_search.do_bench", mock_do_bench):
+            with mock.patch(bench_target, mock_do_bench):
                 # Autotune will try both configs, second one will fail and print repro
                 kernel.autotune([x], force=False)

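Why the test computes the patch target as a string: mock.patch must replace the name that base_search itself imported at module level, and that name now differs by backend. A standalone sketch of the same pattern, assuming the names from the diff above (the 0.123 return value is arbitrary, for illustration only):

from unittest import mock

from helion._testing import is_cuda

bench_attr = "do_bench_cudagraph_with_cache_clear" if is_cuda() else "do_bench"

# Patch the name as looked up from base_search, not from its defining module,
# so benchmark_function sees the replacement at call time.
with mock.patch(f"helion.autotuner.base_search.{bench_attr}", return_value=0.123):
    ...  # any autotuner timing call inside this block now returns 0.123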