2828from typing import Callable
2929from typing import Iterable
3030from typing import Literal
31+ from typing import NamedTuple
3132from typing import NoReturn
33+ from typing import Sequence
3234from typing import cast
3335from unittest .mock import patch
3436import uuid
4749from .config_generation import ConfigGeneration
4850from .config_generation import FlatConfig
4951from .logger import SUPPRESSED_TRITON_CODE_MSG
50- from .logger import LambdaLogger
52+ from .logger import AutotuneLogEntry
53+ from .logger import AutotuningLogger
5154from .logger import classify_triton_exception
5255from .logger import format_triton_compile_failure
5356from .logger import log_generated_triton_code_debug
6366 from ..runtime .settings import Settings
6467 from . import ConfigSpec
6568
66- log = logging .getLogger (__name__ )
67-
6869
6970class BaseAutotuner (abc .ABC ):
7071 """
@@ -76,6 +77,16 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
7677 raise NotImplementedError
7778
7879
class BenchmarkResult(NamedTuple):
    """One per-configuration entry of the list returned by ``parallel_benchmark``."""

    # Configuration this entry describes.
    config: Config
    # Callable compiled for ``config``.
    fn: Callable[..., object]
    # Measured performance; ``inf`` when the configuration was not benchmarked.
    perf: float
    # Outcome of the attempt.
    status: Literal["ok", "error", "timeout"]
    # Elapsed precompile time when one was recorded, otherwise ``None``.
    compile_time: float | None
89+
7990class BaseSearch (BaseAutotuner ):
8091 """
8192 Base class for search algorithms. This class defines the interface and utilities for all
@@ -109,7 +120,7 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
109120 self .config_spec : ConfigSpec = kernel .config_spec
110121 self .args : Sequence [object ] = args
111122 self .counters : collections .Counter [str ] = collections .Counter ()
112- self .log = LambdaLogger (self .settings . autotune_log_level )
123+ self .log = AutotuningLogger (self .settings )
113124 self .best_perf_so_far = inf
114125 seed = self .settings .autotune_random_seed
115126 random .seed (seed )
@@ -439,7 +450,7 @@ def start_precompile_and_check_for_hangs(
439450 process .daemon = True
440451 else :
441452 precompiler = _prepare_precompiler_for_fork (
442- fn , device_args , config , self .kernel , decorator
453+ fn , device_args , config , self .kernel , decorator , self . log
443454 )
444455 if precompiler is None :
445456 return PrecompileFuture .skip (self , config , True )
@@ -463,14 +474,7 @@ def start_precompile_and_check_for_hangs(
463474
464475 def parallel_benchmark (
465476 self , configs : list [Config ], * , desc : str = "Benchmarking"
466- ) -> list [
467- tuple [
468- Config ,
469- Callable [..., object ],
470- float ,
471- Literal ["ok" , "error" , "timeout" ],
472- ]
473- ]:
477+ ) -> list [BenchmarkResult ]:
474478 """
475479 Benchmark multiple configurations in parallel.
476480
@@ -479,24 +483,26 @@ def parallel_benchmark(
479483 desc: Description for the progress bar.
480484
481485 Returns:
482- A list of tuples containing configurations and their performance.
486+ A list of BenchmarkResult entries containing the configuration, compiled
487+ callable, measured performance, status, and compilation time.
483488 """
484- fns = [self .kernel .compile_config (c , allow_print = False ) for c in configs ]
485- precompile_status : list [Literal ["ok" , "error" , "timeout" ]]
489+ fns : list [Callable [..., object ]] = []
490+ futures : list [PrecompileFuture ] | None = None
491+ for config in configs :
492+ fn = self .kernel .compile_config (config , allow_print = False )
493+ fns .append (fn )
486494 if self .settings .autotune_precompile :
487- futures = [
488- * starmap (
495+ futures = list (
496+ starmap (
489497 self .start_precompile_and_check_for_hangs ,
490498 zip (configs , fns , strict = True ),
491499 )
492- ]
493- is_workings = PrecompileFuture .wait_for_all (
494- futures ,
495- desc = f"{ desc } precompiling"
496- if self .settings .autotune_progress_bar
497- else None ,
498500 )
499- precompile_status = []
501+ precompile_desc = (
502+ f"{ desc } precompiling" if self .settings .autotune_progress_bar else None
503+ )
504+ is_workings = PrecompileFuture .wait_for_all (futures , desc = precompile_desc )
505+ precompile_status : list [Literal ["ok" , "error" , "timeout" ]] = []
500506 for future , ok in zip (futures , is_workings , strict = True ):
501507 reason = future .failure_reason
502508 if ok :
@@ -508,29 +514,52 @@ def parallel_benchmark(
508514 else :
509515 is_workings = [True ] * len (configs )
510516 precompile_status = ["ok" ] * len (configs )
511- results : list [
512- tuple [
513- Config , Callable [..., object ], float , Literal ["ok" , "error" , "timeout" ]
514- ]
515- ] = []
517+
518+ results : list [BenchmarkResult ] = []
516519
517520 # Render a progress bar only when the user requested it.
518521 iterator = iter_with_progress (
519- zip (configs , fns , is_workings , precompile_status , strict = True ),
522+ enumerate ( zip (fns , is_workings , precompile_status , strict = True ) ),
520523 total = len (configs ),
521524 description = f"{ desc } exploring neighbors" ,
522525 enabled = self .settings .autotune_progress_bar ,
523526 )
524- for config , fn , is_working , reason in iterator :
527+ for index , (fn , is_working , reason ) in iterator :
528+ config = configs [index ]
529+ if futures is not None :
530+ future = futures [index ]
531+ compile_time = (
532+ future .elapsed
533+ if future .process is not None and future .started
534+ else None
535+ )
536+ else :
537+ compile_time = None
525538 status : Literal ["ok" , "error" , "timeout" ]
526539 if is_working :
527540 # benchmark one-by-one to avoid noisy results
528541 perf = self .benchmark_function (config , fn )
529542 status = "ok" if math .isfinite (perf ) else "error"
530- results .append ((config , fn , perf , status ))
543+ results .append (
544+ BenchmarkResult (
545+ config = config ,
546+ fn = fn ,
547+ perf = perf ,
548+ status = status ,
549+ compile_time = compile_time ,
550+ )
551+ )
531552 else :
532553 status = "timeout" if reason == "timeout" else "error"
533- results .append ((config , fn , inf , status ))
554+ results .append (
555+ BenchmarkResult (
556+ config = config ,
557+ fn = fn ,
558+ perf = inf ,
559+ status = status ,
560+ compile_time = compile_time ,
561+ )
562+ )
534563 return results
535564
536565 def autotune (self , * , skip_cache : bool = False ) -> Config :
@@ -543,9 +572,11 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
543572 The best configuration found during autotuning.
544573 """
545574 start = time .perf_counter ()
546- self .log .reset ()
547575 exit_stack = contextlib .ExitStack ()
548576 with exit_stack :
577+ if self .settings .autotune_log and isinstance (self , PopulationBasedSearch ):
578+ exit_stack .enter_context (self .log .autotune_logging ())
579+ self .log .reset ()
549580 # Autotuner triggers bugs in remote triton compile service
550581 exit_stack .enter_context (
551582 patch .dict (os .environ , {"TRITON_LOCAL_BUILD" : "1" }, clear = False )
@@ -600,6 +631,7 @@ class PopulationMember:
600631 flat_values : FlatConfig
601632 config : Config
602633 status : Literal ["ok" , "error" , "timeout" , "unknown" ] = "unknown"
634+ compile_time : float | None = None
603635
604636 @property
605637 def perf (self ) -> float :
@@ -667,6 +699,7 @@ def __init__(
667699 """
668700 super ().__init__ (kernel , args )
669701 self .population : list [PopulationMember ] = []
702+ self ._current_generation : int = 0
670703 overrides = self .settings .autotune_config_overrides or None
671704 self .config_gen : ConfigGeneration = ConfigGeneration (
672705 self .config_spec ,
@@ -683,6 +716,9 @@ def best(self) -> PopulationMember:
683716 """
684717 return min (self .population , key = performance )
685718
def set_generation(self, generation: int) -> None:
    """
    Record the generation number used to tag subsequently logged results.

    Args:
        generation: Value stored as the current generation.
    """
    self._current_generation = generation
721+
686722 def benchmark_flat (self , flat_values : FlatConfig ) -> PopulationMember :
687723 """
688724 Benchmark a flat configuration.
@@ -694,9 +730,9 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
694730 A population member with the benchmark results.
695731 """
696732 config = self .config_gen .unflatten (flat_values )
697- fn , perf = self . benchmark ( config )
698- status : Literal [ "ok" , "error" ] = "ok" if math . isfinite ( perf ) else "error"
699- return PopulationMember ( fn , [ perf ], flat_values , config , status = status )
733+ member = PopulationMember ( _unset_fn , [], flat_values , config )
734+ self . parallel_benchmark_population ([ member ], desc = "Benchmarking" )
735+ return member
700736
701737 def parallel_benchmark_flat (
702738 self , to_check : list [FlatConfig ]
@@ -737,17 +773,31 @@ def parallel_benchmark_population(
737773 members: The list of population members to benchmark.
738774 desc: Description for the progress bar.
739775 """
740- for member , (config_out , fn , perf , status ) in zip (
741- members ,
742- self .parallel_benchmark ([m .config for m in members ], desc = desc ),
743- strict = True ,
744- ):
745- assert config_out is member .config
746- member .perfs .append (perf )
747- member .fn = fn
748- member .status = status
776+ results = self .parallel_benchmark ([m .config for m in members ], desc = desc )
777+ for member , result in zip (members , results , strict = True ):
778+ assert result .config is member .config
779+ member .perfs .append (result .perf )
780+ member .fn = result .fn
781+ member .status = result .status
782+ member .compile_time = result .compile_time
783+ self ._log_population_results (members )
749784 return members
750785
def _log_population_results(self, members: Sequence[PopulationMember]) -> None:
    """
    Record one autotune log entry for each given population member.

    A member with no recorded measurements, or whose performance is not a
    finite number, is logged with ``perf_ms=None``.

    Args:
        members: The population members whose results should be logged.
    """
    for candidate in members:
        if candidate.perfs:
            measured = candidate.perf
            # Failed or timed-out runs carry a non-finite perf; log them as missing.
            if measured is not None and not math.isfinite(measured):
                measured = None
        else:
            measured = None
        entry = AutotuneLogEntry(
            generation=self._current_generation,
            status=candidate.status,
            perf_ms=measured,
            compile_time=candidate.compile_time,
            config=candidate.config,
        )
        self.log.record_autotune_entry(entry)
800+
751801 def compare (self , a : PopulationMember , b : PopulationMember ) -> int :
752802 """
753803 Compare two population members based on their performance, possibly with re-benchmarking.
@@ -1320,6 +1370,7 @@ def _prepare_precompiler_for_fork(
13201370 config : Config ,
13211371 kernel : BoundKernel ,
13221372 decorator : str ,
1373+ logger : AutotuningLogger ,
13231374) -> Callable [[], None ] | None :
13241375 def extract_launcher (
13251376 triton_kernel : object ,
@@ -1344,12 +1395,12 @@ def extract_launcher(
13441395 return precompiler
13451396 except Exception :
13461397 log_generated_triton_code_debug (
1347- log ,
1398+ logger ,
13481399 kernel ,
13491400 config ,
13501401 prefix = f"Generated Triton code for { decorator } :" ,
13511402 )
1352- log .warning (
1403+ logger .warning (
13531404 "Helion autotuner precompile error for %s. %s" ,
13541405 decorator ,
13551406 SUPPRESSED_TRITON_CODE_MSG ,
0 commit comments