
Commit 06cea84

Add autotuning log (#1095)
1 parent 8f4fb7b

File tree

9 files changed: 538 additions, 99 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -95,3 +95,4 @@ uv.lock
 docs/examples/
 docs/sg_execution_times.rst
 AGENTS.md
+*.csv

docs/api/settings.md

Lines changed: 6 additions & 0 deletions
@@ -112,6 +112,11 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
 
     You can also use ``0`` to completely disable all autotuning output. Controlled by ``HELION_AUTOTUNE_LOG_LEVEL``.
 
+.. autoattribute:: Settings.autotune_log
+
+    When set, Helion writes per-config autotuning telemetry (config index, generation, status, perf, compile time, timestamp, config JSON) to ``<value>.csv`` and mirrors the autotune log output to ``<value>.log`` for population-based autotuners (currently ``PatternSearch`` and ``DifferentialEvolution``).
+    Controlled by ``HELION_AUTOTUNE_LOG``.
+
 .. autoattribute:: Settings.autotune_compile_timeout
 
     Timeout in seconds for Triton compilation during autotuning. Default is ``60``. Controlled by ``HELION_AUTOTUNE_COMPILE_TIMEOUT``.

@@ -250,6 +255,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
 | ``HELION_DISALLOW_AUTOTUNING`` | ``check_autotuning_disabled`` | Hard-disable autotuning; kernels must supply explicit configs when this is ``1``. |
 | ``HELION_AUTOTUNE_COMPILE_TIMEOUT`` | ``autotune_compile_timeout`` | Maximum seconds to wait for Triton compilation during autotuning. |
 | ``HELION_AUTOTUNE_LOG_LEVEL`` | ``autotune_log_level`` | Adjust logging verbosity; accepts names like ``INFO`` or numeric levels. |
+| ``HELION_AUTOTUNE_LOG`` | ``autotune_log`` | Base filename for per-config CSV telemetry and mirrored autotune logs. |
 | ``HELION_AUTOTUNE_PRECOMPILE`` | ``autotune_precompile`` | Select the autotuner precompile mode (``"fork"`` (default), ``"spawn"``, or disable when empty). |
 | ``HELION_AUTOTUNE_PRECOMPILE_JOBS`` | ``autotune_precompile_jobs`` | Cap the number of concurrent Triton precompile subprocesses. |
 | ``HELION_AUTOTUNE_RANDOM_SEED`` | ``autotune_random_seed`` | Seed used for randomized autotuning searches. |
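
A brief usage sketch of the new setting. The ``<value>.csv``/``<value>.log`` naming follows the description above; the CSV column names are not spelled out in this diff, so the keys below ("status", "perf", "config") are assumptions rather than a documented schema:

    import csv
    import os

    # With a base filename set, a run of a population-based autotuner
    # (PatternSearch or DifferentialEvolution) should leave behind
    # autotune_run.csv (per-config telemetry) and autotune_run.log
    # (a mirror of the autotune log output).
    os.environ["HELION_AUTOTUNE_LOG"] = "autotune_run"

    # Hypothetical post-processing of the telemetry after an autotuning run:
    # keep successful configs and report the fastest one.
    with open("autotune_run.csv") as f:
        ok_rows = [row for row in csv.DictReader(f) if row["status"] == "ok"]
    best = min(ok_rows, key=lambda row: float(row["perf"]))
    print(best["config"])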

helion/autotuner/base_search.py

Lines changed: 100 additions & 49 deletions
@@ -28,7 +28,9 @@
 from typing import Callable
 from typing import Iterable
 from typing import Literal
+from typing import NamedTuple
 from typing import NoReturn
+from typing import Sequence
 from typing import cast
 from unittest.mock import patch
 import uuid

@@ -47,7 +49,8 @@
 from .config_generation import ConfigGeneration
 from .config_generation import FlatConfig
 from .logger import SUPPRESSED_TRITON_CODE_MSG
-from .logger import LambdaLogger
+from .logger import AutotuneLogEntry
+from .logger import AutotuningLogger
 from .logger import classify_triton_exception
 from .logger import format_triton_compile_failure
 from .logger import log_generated_triton_code_debug

@@ -63,8 +66,6 @@
 from ..runtime.settings import Settings
 from . import ConfigSpec
 
-log = logging.getLogger(__name__)
-
 
 class BaseAutotuner(abc.ABC):
     """

@@ -76,6 +77,16 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
         raise NotImplementedError
 
 
+class BenchmarkResult(NamedTuple):
+    """Result tuple returned by parallel_benchmark."""
+
+    config: Config
+    fn: Callable[..., object]
+    perf: float
+    status: Literal["ok", "error", "timeout"]
+    compile_time: float | None
+
+
 class BaseSearch(BaseAutotuner):
     """
     Base class for search algorithms. This class defines the interface and utilities for all

@@ -109,7 +120,7 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         self.config_spec: ConfigSpec = kernel.config_spec
         self.args: Sequence[object] = args
         self.counters: collections.Counter[str] = collections.Counter()
-        self.log = LambdaLogger(self.settings.autotune_log_level)
+        self.log = AutotuningLogger(self.settings)
         self.best_perf_so_far = inf
         seed = self.settings.autotune_random_seed
         random.seed(seed)

@@ -439,7 +450,7 @@ def start_precompile_and_check_for_hangs(
             process.daemon = True
         else:
             precompiler = _prepare_precompiler_for_fork(
-                fn, device_args, config, self.kernel, decorator
+                fn, device_args, config, self.kernel, decorator, self.log
             )
             if precompiler is None:
                 return PrecompileFuture.skip(self, config, True)

@@ -463,14 +474,7 @@
 
     def parallel_benchmark(
         self, configs: list[Config], *, desc: str = "Benchmarking"
-    ) -> list[
-        tuple[
-            Config,
-            Callable[..., object],
-            float,
-            Literal["ok", "error", "timeout"],
-        ]
-    ]:
+    ) -> list[BenchmarkResult]:
         """
         Benchmark multiple configurations in parallel.
 

@@ -479,24 +483,26 @@ def parallel_benchmark(
             desc: Description for the progress bar.
 
         Returns:
-            A list of tuples containing configurations and their performance.
+            A list of BenchmarkResult entries containing the configuration, compiled
+            callable, measured performance, status, and compilation time.
         """
-        fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
-        precompile_status: list[Literal["ok", "error", "timeout"]]
+        fns: list[Callable[..., object]] = []
+        futures: list[PrecompileFuture] | None = None
+        for config in configs:
+            fn = self.kernel.compile_config(config, allow_print=False)
+            fns.append(fn)
         if self.settings.autotune_precompile:
-            futures = [
-                *starmap(
+            futures = list(
+                starmap(
                     self.start_precompile_and_check_for_hangs,
                     zip(configs, fns, strict=True),
                 )
-            ]
-            is_workings = PrecompileFuture.wait_for_all(
-                futures,
-                desc=f"{desc} precompiling"
-                if self.settings.autotune_progress_bar
-                else None,
             )
-            precompile_status = []
+            precompile_desc = (
+                f"{desc} precompiling" if self.settings.autotune_progress_bar else None
+            )
+            is_workings = PrecompileFuture.wait_for_all(futures, desc=precompile_desc)
+            precompile_status: list[Literal["ok", "error", "timeout"]] = []
             for future, ok in zip(futures, is_workings, strict=True):
                 reason = future.failure_reason
                 if ok:

@@ -508,29 +514,52 @@
         else:
            is_workings = [True] * len(configs)
            precompile_status = ["ok"] * len(configs)
-        results: list[
-            tuple[
-                Config, Callable[..., object], float, Literal["ok", "error", "timeout"]
-            ]
-        ] = []
+
+        results: list[BenchmarkResult] = []
 
         # Render a progress bar only when the user requested it.
         iterator = iter_with_progress(
-            zip(configs, fns, is_workings, precompile_status, strict=True),
+            enumerate(zip(fns, is_workings, precompile_status, strict=True)),
            total=len(configs),
            description=f"{desc} exploring neighbors",
            enabled=self.settings.autotune_progress_bar,
         )
-        for config, fn, is_working, reason in iterator:
+        for index, (fn, is_working, reason) in iterator:
+            config = configs[index]
+            if futures is not None:
+                future = futures[index]
+                compile_time = (
+                    future.elapsed
+                    if future.process is not None and future.started
+                    else None
+                )
+            else:
+                compile_time = None
             status: Literal["ok", "error", "timeout"]
             if is_working:
                 # benchmark one-by-one to avoid noisy results
                 perf = self.benchmark_function(config, fn)
                 status = "ok" if math.isfinite(perf) else "error"
-                results.append((config, fn, perf, status))
+                results.append(
+                    BenchmarkResult(
+                        config=config,
+                        fn=fn,
+                        perf=perf,
+                        status=status,
+                        compile_time=compile_time,
+                    )
+                )
             else:
                 status = "timeout" if reason == "timeout" else "error"
-                results.append((config, fn, inf, status))
+                results.append(
+                    BenchmarkResult(
+                        config=config,
+                        fn=fn,
+                        perf=inf,
+                        status=status,
+                        compile_time=compile_time,
+                    )
+                )
         return results
 
     def autotune(self, *, skip_cache: bool = False) -> Config:

@@ -543,9 +572,11 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
             The best configuration found during autotuning.
         """
         start = time.perf_counter()
-        self.log.reset()
         exit_stack = contextlib.ExitStack()
         with exit_stack:
+            if self.settings.autotune_log and isinstance(self, PopulationBasedSearch):
+                exit_stack.enter_context(self.log.autotune_logging())
+            self.log.reset()
             # Autotuner triggers bugs in remote triton compile service
             exit_stack.enter_context(
                 patch.dict(os.environ, {"TRITON_LOCAL_BUILD": "1"}, clear=False)

@@ -600,6 +631,7 @@ class PopulationMember:
     flat_values: FlatConfig
     config: Config
     status: Literal["ok", "error", "timeout", "unknown"] = "unknown"
+    compile_time: float | None = None
 
     @property
     def perf(self) -> float:

@@ -667,6 +699,7 @@ def __init__(
         """
         super().__init__(kernel, args)
         self.population: list[PopulationMember] = []
+        self._current_generation: int = 0
         overrides = self.settings.autotune_config_overrides or None
         self.config_gen: ConfigGeneration = ConfigGeneration(
             self.config_spec,

@@ -683,6 +716,9 @@ def best(self) -> PopulationMember:
         """
         return min(self.population, key=performance)
 
+    def set_generation(self, generation: int) -> None:
+        self._current_generation = generation
+
     def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
         """
         Benchmark a flat configuration.

@@ -694,9 +730,9 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
             A population member with the benchmark results.
         """
         config = self.config_gen.unflatten(flat_values)
-        fn, perf = self.benchmark(config)
-        status: Literal["ok", "error"] = "ok" if math.isfinite(perf) else "error"
-        return PopulationMember(fn, [perf], flat_values, config, status=status)
+        member = PopulationMember(_unset_fn, [], flat_values, config)
+        self.parallel_benchmark_population([member], desc="Benchmarking")
+        return member
 
     def parallel_benchmark_flat(
         self, to_check: list[FlatConfig]

@@ -737,17 +773,31 @@ def parallel_benchmark_population(
             members: The list of population members to benchmark.
             desc: Description for the progress bar.
         """
-        for member, (config_out, fn, perf, status) in zip(
-            members,
-            self.parallel_benchmark([m.config for m in members], desc=desc),
-            strict=True,
-        ):
-            assert config_out is member.config
-            member.perfs.append(perf)
-            member.fn = fn
-            member.status = status
+        results = self.parallel_benchmark([m.config for m in members], desc=desc)
+        for member, result in zip(members, results, strict=True):
+            assert result.config is member.config
+            member.perfs.append(result.perf)
+            member.fn = result.fn
+            member.status = result.status
+            member.compile_time = result.compile_time
+        self._log_population_results(members)
         return members
 
+    def _log_population_results(self, members: Sequence[PopulationMember]) -> None:
+        for member in members:
+            perf_value = member.perf if member.perfs else None
+            if perf_value is not None and not math.isfinite(perf_value):
+                perf_value = None
+            self.log.record_autotune_entry(
+                AutotuneLogEntry(
+                    generation=self._current_generation,
+                    status=member.status,
+                    perf_ms=perf_value,
+                    compile_time=member.compile_time,
+                    config=member.config,
+                )
+            )
+
     def compare(self, a: PopulationMember, b: PopulationMember) -> int:
         """
         Compare two population members based on their performance, possibly with re-benchmarking.

@@ -1320,6 +1370,7 @@ def _prepare_precompiler_for_fork(
     config: Config,
     kernel: BoundKernel,
     decorator: str,
+    logger: AutotuningLogger,
 ) -> Callable[[], None] | None:
     def extract_launcher(
         triton_kernel: object,

@@ -1344,12 +1395,12 @@ def extract_launcher(
         return precompiler
     except Exception:
         log_generated_triton_code_debug(
-            log,
+            logger,
             kernel,
             config,
             prefix=f"Generated Triton code for {decorator}:",
         )
-        log.warning(
+        logger.warning(
             "Helion autotuner precompile error for %s. %s",
             decorator,
             SUPPRESSED_TRITON_CODE_MSG,
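
For orientation, a standalone replica of the ``BenchmarkResult`` shape introduced above, with a plain dict standing in for helion's ``Config``. It shows what the named-tuple return type buys call sites (field access by name instead of positional 4-tuple unpacking, as in the ``finite_search.py`` change below):

    from typing import Callable, Literal, NamedTuple

    class BenchmarkResult(NamedTuple):
        config: dict[str, int]  # stand-in for helion's Config
        fn: Callable[..., object]
        perf: float
        status: Literal["ok", "error", "timeout"]
        compile_time: float | None

    # Per the diff, failed or timed-out entries carry a non-finite perf.
    results = [
        BenchmarkResult({"block": 64}, print, 0.42, "ok", 1.3),
        BenchmarkResult({"block": 128}, print, float("inf"), "timeout", None),
    ]
    best = min((r for r in results if r.status == "ok"), key=lambda r: r.perf)
    print(best.config, best.perf)  # fields by name, no unpacking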

helion/autotuner/differential_evolution.py

Lines changed: 17 additions & 6 deletions
@@ -56,6 +56,7 @@ def mutate(self, x_index: int) -> FlatConfig:
 
     def initial_two_generations(self) -> None:
         # The initial population is 2x larger so we can throw out the slowest half and give the tuning process a head start
+        self.set_generation(0)
         oversized_population = sorted(
             self.parallel_benchmark_flat(
                 self.config_gen.random_population_flat(self.population_size * 2),

@@ -68,16 +69,25 @@
             )
         )
         self.population = oversized_population[: self.population_size]
+
+    def _benchmark_mutation_batch(
+        self, indices: Sequence[int]
+    ) -> list[PopulationMember]:
+        if not indices:
+            return []
+        flat_configs = [self.mutate(i) for i in indices]
+        return self.parallel_benchmark_flat(flat_configs)
 
     def iter_candidates(self) -> Iterator[tuple[int, PopulationMember]]:
         if self.immediate_update:
             for i in range(len(self.population)):
-                yield i, self.benchmark_flat(self.mutate(i))
+                candidates = self._benchmark_mutation_batch([i])
+                if not candidates:
+                    continue
+                yield i, candidates[0]
         else:
-            yield from enumerate(
-                self.parallel_benchmark_flat(
-                    [self.mutate(i) for i in range(len(self.population))]
-                )
-            )
+            indices = list(range(len(self.population)))
+            candidates = self._benchmark_mutation_batch(indices)
+            yield from zip(indices, candidates, strict=True)
 
     def evolve_population(self) -> int:
         replaced = 0

@@ -96,6 +106,7 @@ def _autotune(self) -> Config:
         )
         self.initial_two_generations()
         for i in range(2, self.max_generations):
+            self.set_generation(i)
             self.log(f"Generation {i} starting")
             replaced = self.evolve_population()
             self.log(f"Generation {i} complete: replaced={replaced}", self.statistics)

helion/autotuner/finite_search.py

Lines changed: 4 additions & 6 deletions
@@ -35,11 +35,9 @@ def __init__(
     def _autotune(self) -> Config:
         best_config = None
         best_time = float("inf")
-        for config, _fn, time, _status in self.parallel_benchmark(
-            self.configs, desc="Benchmarking"
-        ):
-            if time < best_time:
-                best_time = time
-                best_config = config
+        for result in self.parallel_benchmark(self.configs, desc="Benchmarking"):
+            if result.perf < best_time:
+                best_time = result.perf
+                best_config = result.config
         assert best_config is not None
         return best_config
