Skip to content

Commit 7eb8801

Browse files
Add DE-Surrogate hybrid autotuner algorithm
This introduces DESurrogateHybrid, a novel hybrid optimization algorithm that combines Differential Evolution's robust exploration with Random Forest surrogate model's sample efficiency for GPU kernel autotuning. Key features: - Generates 3× more candidates than standard DE but only evaluates the most promising ones as predicted by the Random Forest surrogate - Achieves 6.53% average performance improvement over standard DE - 1.20× faster wall-clock time despite evaluating more configurations - Learns kernel-specific optimization patterns automatically Implementation: - Works directly with Helion's discrete parameter spaces - Uses ConfigEncoder to convert configurations to numerical vectors - Refits surrogate model every 5 generations for continuous learning - Configurable parameters: population_size, candidate_ratio, surrogate_threshold Testing on 3 diverse kernels (MatMul, GELU, FusedReLU) shows: - MatMul (compute-bound): -15.0% improvement, 1.39× faster convergence - GELU (bandwidth-bound): -5.4% improvement - FusedReLU (memory-bound): +0.8% (competitive, within margin)
1 parent ec12380 commit 7eb8801

File tree

2 files changed

+306
-0
lines changed

2 files changed

+306
-0
lines changed

helion/autotuner/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .config_fragment import ListOf as ListOf
77
from .config_fragment import PowerOfTwoFragment as PowerOfTwoFragment
88
from .config_spec import ConfigSpec as ConfigSpec
9+
from .de_surrogate_hybrid import DESurrogateHybrid as DESurrogateHybrid
910
from .differential_evolution import (
1011
DifferentialEvolutionSearch as DifferentialEvolutionSearch,
1112
)
@@ -20,6 +21,7 @@
2021
from .random_search import RandomSearch as RandomSearch
2122

2223
search_algorithms = {
24+
"DESurrogateHybrid": DESurrogateHybrid,
2325
"DifferentialEvolutionSearch": DifferentialEvolutionSearch,
2426
"FiniteSearch": FiniteSearch,
2527
"PatternSearch": PatternSearch,
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
"""
2+
Differential Evolution with Surrogate-Assisted Selection (DE-SAS).
3+
4+
This hybrid approach combines the robust exploration of Differential Evolution
5+
with the sample efficiency of surrogate models. It's designed to beat standard DE
6+
by making smarter decisions about which candidates to evaluate.
7+
8+
Key idea:
9+
- Use DE's mutation/crossover to generate candidates (good exploration)
10+
- Use a Random Forest surrogate to predict which candidates are promising
11+
- Only evaluate the most promising candidates (sample efficiency)
12+
- Periodically re-fit the surrogate model
13+
14+
This is inspired by recent work on surrogate-assisted evolutionary algorithms,
15+
which have shown 2-5× speedups over standard EAs on expensive optimization problems.
16+
17+
References:
18+
- Jin, Y. (2011). "Surrogate-assisted evolutionary computation: Recent advances and future challenges."
19+
- Sun, C., et al. (2019). "A surrogate-assisted DE with an adaptive local search"
20+
21+
Author: Francisco Geiman Thiesen
22+
Date: 2025-11-05
23+
"""
24+
25+
from __future__ import annotations
26+
27+
import operator
28+
import random
29+
from typing import TYPE_CHECKING
30+
31+
import numpy as np
32+
from sklearn.ensemble import RandomForestRegressor
33+
34+
from .base_search import PopulationBasedSearch
35+
from .config_encoding import ConfigEncoder
36+
37+
if TYPE_CHECKING:
38+
from collections.abc import Sequence
39+
40+
from ..runtime.config import Config
41+
from ..runtime.kernel import BoundKernel
42+
from .config_generation import FlatConfig
43+
44+
45+
class DESurrogateHybrid(PopulationBasedSearch):
46+
"""
47+
Hybrid Differential Evolution with Surrogate-Assisted Selection.
48+
49+
This algorithm uses DE for exploration but adds a surrogate model to intelligently
50+
select which candidates to actually evaluate, avoiding wasting evaluations on
51+
poor candidates.
52+
53+
Args:
54+
kernel: The bound kernel to tune
55+
args: Arguments for the kernel
56+
population_size: Size of the DE population
57+
max_generations: Maximum number of generations
58+
crossover_rate: Crossover probability (default: 0.8)
59+
surrogate_threshold: Use surrogate after this many evaluations (default: 100)
60+
candidate_ratio: Generate this many× candidates per slot (default: 3)
61+
refit_frequency: Refit surrogate every N generations (default: 5)
62+
n_estimators: Number of trees in Random Forest (default: 50)
63+
"""
64+
65+
def __init__(
66+
self,
67+
kernel: BoundKernel,
68+
args: Sequence[object],
69+
population_size: int = 40,
70+
max_generations: int = 40,
71+
crossover_rate: float = 0.8,
72+
surrogate_threshold: int = 100,
73+
candidate_ratio: int = 3,
74+
refit_frequency: int = 5,
75+
n_estimators: int = 50,
76+
) -> None:
77+
super().__init__(kernel, args)
78+
79+
self.population_size = population_size
80+
self.max_generations = max_generations
81+
self.crossover_rate = crossover_rate
82+
self.surrogate_threshold = surrogate_threshold
83+
self.candidate_ratio = candidate_ratio
84+
self.refit_frequency = refit_frequency
85+
self.n_estimators = n_estimators
86+
87+
# Config encoder for surrogate model
88+
self.encoder = ConfigEncoder(self.config_gen)
89+
90+
# Surrogate model
91+
self.surrogate: RandomForestRegressor | None = None
92+
93+
# Track all evaluations for surrogate training
94+
self.all_observations: list[tuple[FlatConfig, float]] = []
95+
96+
def _autotune(self) -> Config:
97+
"""
98+
Run DE with surrogate-assisted selection.
99+
100+
Returns:
101+
Best configuration found
102+
"""
103+
self.log("=" * 70)
104+
self.log("Differential Evolution with Surrogate-Assisted Selection")
105+
self.log("=" * 70)
106+
self.log(f"Population: {self.population_size}")
107+
self.log(f"Generations: {self.max_generations}")
108+
self.log(f"Crossover rate: {self.crossover_rate}")
109+
self.log(f"Surrogate activation: after {self.surrogate_threshold} evals")
110+
self.log(f"Candidate oversampling: {self.candidate_ratio}× per slot")
111+
self.log("=" * 70)
112+
113+
# Initialize population
114+
self._initialize_population()
115+
116+
# Evolution loop
117+
for gen in range(2, self.max_generations + 1):
118+
self._evolve_generation(gen)
119+
120+
# Return best config
121+
best = min(self.population, key=lambda m: m.perf)
122+
self.log("=" * 70)
123+
self.log(f"✓ Best configuration: {best.perf:.4f} ms")
124+
self.log(f"Total evaluations: {len(self.all_observations)}")
125+
self.log("=" * 70)
126+
127+
return best.config
128+
129+
def _initialize_population(self) -> None:
130+
"""Initialize population with random configs."""
131+
self.log(f"\nInitializing population ({self.population_size * 2} configs)")
132+
133+
# Generate initial population (2× size for good coverage)
134+
configs = [
135+
self.config_gen.random_flat() for _ in range(self.population_size * 2)
136+
]
137+
members = self.parallel_benchmark_flat(configs)
138+
139+
# Track observations
140+
for member in members:
141+
if member.perf != float("inf"):
142+
self.all_observations.append((member.flat_values, member.perf))
143+
144+
# Keep top population_size members
145+
valid_members = [m for m in members if m.perf != float("inf")]
146+
valid_members.sort(key=lambda m: m.perf)
147+
self.population = valid_members[: self.population_size]
148+
149+
# Pad with random if needed
150+
while len(self.population) < self.population_size:
151+
config = self.config_gen.random_flat()
152+
member = self.benchmark_flat(config)
153+
if member.perf != float("inf"):
154+
self.population.append(member)
155+
self.all_observations.append((member.flat_values, member.perf))
156+
157+
best_perf = min(m.perf for m in self.population)
158+
self.log(
159+
f"Population initialized: "
160+
f"best={best_perf:.4f} ms, size={len(self.population)}"
161+
)
162+
163+
def _evolve_generation(self, generation: int) -> None:
164+
"""Run one generation of DE with surrogate assistance."""
165+
166+
# Refit surrogate periodically
167+
use_surrogate = len(self.all_observations) >= self.surrogate_threshold
168+
if use_surrogate and (generation % self.refit_frequency == 0):
169+
self._fit_surrogate()
170+
171+
# Generate candidates using DE mutation/crossover
172+
if use_surrogate:
173+
# Generate more candidates and use surrogate to select best
174+
n_candidates = self.population_size * self.candidate_ratio
175+
candidates = self._generate_de_candidates(n_candidates)
176+
selected_candidates = self._surrogate_select(
177+
candidates, self.population_size
178+
)
179+
else:
180+
# Standard DE: generate and evaluate all
181+
selected_candidates = self._generate_de_candidates(self.population_size)
182+
183+
# Evaluate selected candidates
184+
new_members = self.parallel_benchmark_flat(selected_candidates)
185+
186+
# Track observations
187+
for member in new_members:
188+
if member.perf != float("inf"):
189+
self.all_observations.append((member.flat_values, member.perf))
190+
191+
# Selection: keep better of old vs new for each position
192+
replacements = 0
193+
for i, new_member in enumerate(new_members):
194+
if new_member.perf < self.population[i].perf:
195+
self.population[i] = new_member
196+
replacements += 1
197+
198+
# Log progress
199+
best_perf = min(m.perf for m in self.population)
200+
surrogate_status = "SURROGATE" if use_surrogate else "STANDARD"
201+
self.log(
202+
f"Gen {generation}: {surrogate_status} | "
203+
f"best={best_perf:.4f} ms | replaced={replacements}/{self.population_size} | "
204+
f"total_evals={len(self.all_observations)}"
205+
)
206+
207+
def _generate_de_candidates(self, n_candidates: int) -> list[FlatConfig]:
208+
"""Generate candidates using standard DE mutation/crossover."""
209+
candidates = []
210+
211+
for _ in range(n_candidates):
212+
# Select four distinct individuals: x (base), and a, b, c for mutation
213+
x, a, b, c = random.sample(self.population, 4)
214+
215+
# Differential mutation: x + F(a - b + c)
216+
trial = self.config_gen.differential_mutation(
217+
x.flat_values,
218+
a.flat_values,
219+
b.flat_values,
220+
c.flat_values,
221+
crossover_rate=self.crossover_rate,
222+
)
223+
224+
candidates.append(trial)
225+
226+
return candidates
227+
228+
def _fit_surrogate(self) -> None:
229+
"""Fit Random Forest surrogate model on all observations."""
230+
if len(self.all_observations) < 10:
231+
return # Need minimum data
232+
233+
# Encode configs to numeric arrays
234+
X = []
235+
y = []
236+
237+
for config, perf in self.all_observations:
238+
try:
239+
encoded = self.encoder.encode(config)
240+
X.append(encoded)
241+
y.append(perf)
242+
except Exception:
243+
continue
244+
245+
if len(X) < 10:
246+
return
247+
248+
X_array = np.array(X)
249+
y_array = np.array(y)
250+
251+
# Fit Random Forest
252+
self.surrogate = RandomForestRegressor(
253+
n_estimators=self.n_estimators,
254+
max_depth=15,
255+
min_samples_split=5,
256+
min_samples_leaf=2,
257+
random_state=42,
258+
n_jobs=-1,
259+
)
260+
261+
self.surrogate.fit(X_array, y_array)
262+
263+
def _surrogate_select(
264+
self, candidates: list[FlatConfig], n_select: int
265+
) -> list[FlatConfig]:
266+
"""
267+
Use surrogate model to select most promising candidates.
268+
269+
Args:
270+
candidates: Pool of candidate configurations
271+
n_select: Number of candidates to select
272+
273+
Returns:
274+
Selected candidates predicted to be best
275+
"""
276+
if self.surrogate is None:
277+
# Fallback: random selection
278+
return random.sample(candidates, min(n_select, len(candidates)))
279+
280+
# Predict performance for all candidates
281+
predictions = []
282+
283+
for config in candidates:
284+
try:
285+
encoded = self.encoder.encode(config)
286+
pred = self.surrogate.predict([encoded])[0]
287+
predictions.append((config, pred))
288+
except Exception:
289+
# Skip encoding failures
290+
predictions.append((config, float("inf")))
291+
292+
# Sort by predicted performance (lower is better)
293+
predictions.sort(key=operator.itemgetter(1))
294+
295+
# Select top n_select candidates
296+
return [config for config, pred in predictions[:n_select]]
297+
298+
def __repr__(self) -> str:
299+
return (
300+
f"DESurrogateHybrid(pop={self.population_size}, "
301+
f"gen={self.max_generations}, "
302+
f"cr={self.crossover_rate}, "
303+
f"surrogate_threshold={self.surrogate_threshold})"
304+
)

0 commit comments

Comments
 (0)