Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,10 @@ def parallel_benchmark(
self.start_precompile_and_check_for_hangs,
zip(configs, fns, strict=True),
)
]
],
desc=f"{desc} precompiling"
if self.settings.autotune_progress_bar
else None,
)
else:
is_workings = [True] * len(configs)
Expand All @@ -336,7 +339,7 @@ def parallel_benchmark(
iterator = iter_with_progress(
zip(configs, fns, is_workings, strict=True),
total=len(configs),
description=desc,
description=f"{desc}: exploring neighbors",
enabled=self.settings.autotune_progress_bar,
)
for config, fn, is_working in iterator:
Expand Down Expand Up @@ -725,6 +728,7 @@ def __call__(self) -> bool:
@staticmethod
def wait_for_all(
futures: list[PrecompileFuture],
desc: str | None = None,
) -> list[bool]:
"""
Wait for all precompile futures to complete.
Expand All @@ -735,10 +739,21 @@ def wait_for_all(
Returns:
A list of boolean values indicating completion status.
"""
progress = iter_with_progress(
range(len(futures)),
total=len(futures),
description=desc,
enabled=desc is not None,
)
next(progress, None) # display the progress bar immediately
progress_left = len(futures)
remaining = [f for f in futures if f.ok is None]
try:
while remaining:
remaining = PrecompileFuture._wait_for_all_step(remaining)
while progress_left > len(remaining):
next(progress, None)
progress_left -= 1
except Exception:
for f in remaining:
if (p := f.process) is not None:
Expand Down
4 changes: 2 additions & 2 deletions helion/autotuner/pattern_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ def _autotune(self) -> Config:
unbenchmarked = [m for m in self.population if len(m.perfs) == 0]
if unbenchmarked:
self.parallel_benchmark_population(
unbenchmarked, desc=f"Generation {generation}: Exploring neighbors"
unbenchmarked, desc=f"Generation {generation}:"
)
# higher-accuracy rebenchmark
self.rebenchmark_population(
self.population, desc=f"Generation {generation}: Verifying top configs"
self.population, desc=f"Generation {generation}: verifying top configs"
)
# Log final statistics for this generation
self.log(f"Generation {generation} complete:", self.statistics)
Expand Down
Loading