Skip to content

Commit 7581998

Browse files
ethcheEthan Che
andauthored
Add LFBO Pattern Search (#1115)
Co-authored-by: Ethan Che <eche@fb.com>
1 parent 7acbc82 commit 7581998

File tree

11 files changed

+624
-38
lines changed

11 files changed

+624
-38
lines changed

.github/workflows/benchmark.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ jobs:
100100
- name: Install Helion
101101
run: |
102102
source .venv/bin/activate
103-
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate]'
103+
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,surrogate]'
104104
python -c "import helion; print(helion.__name__)"
105105
106106
- name: Install Benchmark Requirements

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
run: |
4242
source .venv/bin/activate
4343
uv pip install pyrefly
44-
uv pip install .'[dev]'
44+
uv pip install .'[dev,surrogate]'
4545
4646
- name: Run pre-commit
4747
run: |

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ jobs:
146146
run: |
147147
source .venv/bin/activate
148148
uv pip install setuptools ninja
149-
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate]'
149+
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,surrogate]'
150150
python -c "import helion; print(helion.__name__)"
151151
152152
- name: Run Tests

helion/autotuner/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache
2020
from .pattern_search import PatternSearch as PatternSearch
2121
from .random_search import RandomSearch as RandomSearch
22+
from .surrogate_pattern_search import LFBOPatternSearch
2223

2324
search_algorithms = {
2425
"DESurrogateHybrid": DESurrogateHybrid,
26+
"LFBOPatternSearch": LFBOPatternSearch,
2527
"DifferentialEvolutionSearch": DifferentialEvolutionSearch,
2628
"FiniteSearch": FiniteSearch,
2729
"PatternSearch": PatternSearch,

helion/autotuner/config_fragment.py

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ def differential_mutation(self, a: object, b: object, c: object) -> object:
5151
def is_block_size(self) -> bool:
5252
return False
5353

54-
def get_minimum(self) -> int:
54+
def dim(self) -> int:
5555
"""
56-
Return the minimum allowed value for this fragment.
56+
Returns the dimension of the output of encode
5757
"""
5858
raise NotImplementedError
5959

60-
def encode_scalar(self, value: object) -> float:
60+
def encode(self, value: object) -> list[float]:
6161
"""
62-
Encode a configuration value into a float for ML models.
62+
Encode a configuration value into a list of floats for ML models.
6363
6464
This is used by surrogate-assisted algorithms to convert configurations
6565
into numerical vectors for prediction models.
@@ -68,14 +68,15 @@ def encode_scalar(self, value: object) -> float:
6868
value: The configuration value to encode.
6969
7070
Returns:
71-
A float representing the encoded value.
71+
A list of floats representing the encoded value.
7272
"""
73-
# Default: convert to float if possible
74-
if not isinstance(value, (int, float, bool)):
75-
raise TypeError(
76-
f"Cannot encode {type(value).__name__} value {value!r} for ML"
77-
)
78-
return float(value)
73+
raise NotImplementedError
74+
75+
def get_minimum(self) -> int:
76+
"""
77+
Return the minimum allowed value for this fragment.
78+
"""
79+
raise NotImplementedError
7980

8081

8182
@dataclasses.dataclass
@@ -106,6 +107,17 @@ def pattern_neighbors(self, current: object) -> list[object]:
106107
neighbors.append(swapped)
107108
return neighbors
108109

110+
def dim(self) -> int:
111+
return self.length
112+
113+
def encode(self, value: object) -> list[float]:
114+
assert isinstance(value, list)
115+
encoded = []
116+
for val in value:
117+
assert isinstance(val, int)
118+
encoded.append(float(val))
119+
return value
120+
109121

110122
@dataclasses.dataclass
111123
class BaseIntegerFragment(ConfigSpecFragment):
@@ -129,6 +141,9 @@ def clamp(self, val: int) -> int:
129141
def get_minimum(self) -> int:
130142
return self.low
131143

144+
def dim(self) -> int:
145+
return 1
146+
132147
def pattern_neighbors(self, current: object) -> list[object]:
133148
if type(current) is not int: # bool is not allowed
134149
raise TypeError(f"Expected int, got {type(current).__name__}")
@@ -141,13 +156,9 @@ def pattern_neighbors(self, current: object) -> list[object]:
141156
neighbors.append(upper)
142157
return neighbors
143158

144-
def encode_scalar(self, value: object) -> float:
145-
"""Encode integer values directly as floats."""
146-
if not isinstance(value, (int, float)):
147-
raise TypeError(
148-
f"Expected int/float for BaseIntegerFragment, got {type(value).__name__}: {value!r}"
149-
)
150-
return float(value)
159+
def encode(self, value: object) -> list[float]:
160+
assert isinstance(value, int)
161+
return [float(value)]
151162

152163

153164
class PowerOfTwoFragment(BaseIntegerFragment):
@@ -180,7 +191,7 @@ def differential_mutation(self, a: object, b: object, c: object) -> int:
180191
return self.clamp(ai * 2)
181192
return ai
182193

183-
def encode_scalar(self, value: object) -> float:
194+
def encode(self, value: object) -> list[float]:
184195
"""Encode power-of-2 values using log2 transformation."""
185196
import math
186197

@@ -192,7 +203,7 @@ def encode_scalar(self, value: object) -> float:
192203
raise ValueError(
193204
f"Expected positive value for PowerOfTwoFragment, got {value}"
194205
)
195-
return math.log2(float(value))
206+
return [math.log2(float(value))]
196207

197208

198209
class IntegerFragment(BaseIntegerFragment):
@@ -235,7 +246,10 @@ def differential_mutation(self, a: object, b: object, c: object) -> object:
235246
choices.remove(a)
236247
return random.choice(choices)
237248

238-
def encode_scalar(self, value: object) -> float:
249+
def dim(self) -> int:
250+
return len(self.choices)
251+
252+
def encode(self, value: object) -> list[float]:
239253
"""Encode enum values as their index."""
240254
try:
241255
choice_idx = self.choices.index(value)
@@ -244,7 +258,7 @@ def encode_scalar(self, value: object) -> float:
244258
f"Invalid enum value {value!r} for EnumFragment. "
245259
f"Valid choices: {self.choices}"
246260
) from None
247-
return float(choice_idx)
261+
return [1.0 if i == choice_idx else 0.0 for i in range(len(self.choices))]
248262

249263

250264
class BooleanFragment(ConfigSpecFragment):
@@ -265,6 +279,14 @@ def differential_mutation(self, a: object, b: object, c: object) -> bool:
265279
return a
266280
return not a
267281

282+
def dim(self) -> int:
283+
return 1
284+
285+
def encode(self, value: object) -> list[float]:
286+
"""Encode enum values as their index."""
287+
assert isinstance(value, bool)
288+
return [1.0] if value else [0.0]
289+
268290

269291
class BlockSizeFragment(PowerOfTwoFragment):
270292
def category(self) -> Category:
@@ -320,3 +342,13 @@ def differential_mutation(self, a: object, b: object, c: object) -> list[object]
320342
self.inner.differential_mutation(a[i], b[i], c[i])
321343
for i in range(self.length)
322344
]
345+
346+
def dim(self) -> int:
347+
return self.length * self.inner.dim()
348+
349+
def encode(self, value: object) -> list[float]:
350+
assert isinstance(value, list)
351+
encoded = []
352+
for v in value:
353+
encoded.extend(self.inner.encode(v))
354+
return encoded

helion/autotuner/config_generation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ def encode_config(self, flat_config: FlatConfig) -> list[float]:
199199

200200
for flat_idx, spec in enumerate(self.flat_spec):
201201
value = flat_config[flat_idx]
202-
encoded.append(spec.encode_scalar(value))
202+
encoded_value = spec.encode(value)
203+
assert len(encoded_value) == spec.dim()
204+
encoded.extend(encoded_value)
203205

204206
return encoded

helion/autotuner/de_surrogate_hybrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(
9191
if not HAS_ML_DEPS:
9292
raise ImportError(
9393
"DESurrogateHybrid requires numpy and scikit-learn. "
94-
"Install them with: pip install helion[de-surrogate]"
94+
"Install them with: pip install helion[surrogate]"
9595
)
9696

9797
# Initialize parent with early stopping parameters

helion/autotuner/pattern_search.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -134,20 +134,35 @@ def _pattern_search_from(
134134
if len(candidates) <= 1:
135135
return # no new candidates, stop searching
136136
yield candidates # yield new population to benchmark in parallel
137+
# update search copy and check early stopping criteria
137138
best = min(candidates, key=performance)
138-
if best is current:
139-
return # no improvement, stop searching
140-
# Stop if the relative improvement is smaller than a user-specified delta
141-
if (
142-
self.min_improvement_delta > 0.0
143-
and math.isfinite(best.perf)
144-
and math.isfinite(current.perf)
145-
and current.perf != 0.0
146-
and abs(best.perf / current.perf - 1.0) < self.min_improvement_delta
147-
):
139+
if self._check_early_stopping(best, current):
148140
return
149141
current = best
150142

143+
def _check_early_stopping(
144+
self, best: PopulationMember, current: PopulationMember
145+
) -> bool:
146+
"""
147+
Check if early stopping criteria are met for the search copy
148+
149+
Early stops if either the best config has not changed or if
150+
the relative improvement is smaller than a user-specified delta
151+
152+
Returns:
153+
True the search copy is terminated, False otherwise.
154+
"""
155+
if best is current:
156+
return True # no improvement, stop searching
157+
# Stop if the relative improvement is smaller than a user-specified delta
158+
return bool(
159+
self.min_improvement_delta > 0.0
160+
and math.isfinite(best.perf)
161+
and math.isfinite(current.perf)
162+
and current.perf != 0.0
163+
and abs(best.perf / current.perf - 1.0) < self.min_improvement_delta
164+
)
165+
151166
def _generate_neighbors(self, base: FlatConfig) -> list[FlatConfig]:
152167
"""
153168
Generate neighboring configurations by changing one or two parameters at a time.

0 commit comments

Comments
 (0)