From 3941a7fdc116345b0ed2bfd08d34d8fbf837c8f3 Mon Sep 17 00:00:00 2001
From: Jerry Hou
Date: Mon, 1 Dec 2025 17:35:40 -0800
Subject: [PATCH] Entropy warmup enhancements (#651)

Summary:
This diff fixes an implementation error that was affecting the slope and
r^2 computations. It also includes optimizations to the entropy warmup
logic that should improve its performance with smaller/faster kernels.
Further numerical stability optimizations were brought in from the
cross-PR on nvbench: https://github.com/NVIDIA/nvbench/pull/286
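For reference, a minimal self-contained sketch of the 1/n-scaled
regression math the criterion now uses (slope_r2 and the sample data are
illustrative only, not part of this diff):

    def slope_r2(xs, ys):
        """Least-squares slope and r^2 from sums, each moment scaled by 1/n."""
        n = len(xs)
        sum_x, sum_y = sum(xs), sum(ys)
        sum_xy = sum(x * y for x, y in zip(xs, ys))
        sum_x2 = sum(x * x for x in xs)
        sum_y2 = sum(y * y for y in ys)

        mean_x, mean_y = sum_x / n, sum_y / n
        # 1/n scaling keeps the intermediate terms small over long windows.
        numerator = sum_xy / n - mean_x * mean_y
        denominator = sum_x2 / n - mean_x * mean_x
        if abs(denominator) < 1e-12:
            return None  # degenerate: all x values identical
        slope = numerator / denominator
        intercept = mean_y - slope * mean_x

        ss_tot = sum_y2 / n - mean_y * mean_y
        # Expanded form of sum((y - (slope*x + intercept))^2) / n.
        ss_res = (
            sum_y2 / n
            - 2 * slope * (sum_xy / n)
            - 2 * intercept * (sum_y / n)
            + slope * slope * (sum_x2 / n)
            + 2 * slope * intercept * (sum_x / n)
            + intercept * intercept
        )
        r2 = 1.0 if abs(ss_tot) < 1e-12 else max(0.0, min(1.0, 1 - ss_res / ss_tot))
        return slope, r2

    print(slope_r2(list(range(100)), [0.5 * x + 3.0 for x in range(100)]))  # ~ (0.5, 1.0)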
Since entropy tracks the frequency of appearance of each unique time
value, the rounding precision applied to individual latency samples can
affect the characteristics of entropy convergence. This diff introduces
logic to dynamically increase the rounding precision to maintain a
balance between entropy sensitivity and trend detection.
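A small sketch of why the rounding precision matters (the latency values
below are made up for illustration): rounding to 3 decimals can collapse
nearby samples into a single bucket and drive the entropy to zero, while
4 decimals keeps the distribution visible to the trend detector.

    import math
    from collections import Counter

    def entropy_bits(samples):
        # H = log2(n) - (1/n) * sum(c * log2(c)) over the value counts.
        n = len(samples)
        return math.log2(n) - sum(
            c * math.log2(c) for c in Counter(samples).values()
        ) / n

    samples = [0.0511, 0.0512, 0.0509, 0.0510, 0.0513, 0.0511]
    for digits in (3, 4):
        rounded = [round(s, digits) for s in samples]
        print(digits, len(set(rounded)), round(entropy_bits(rounded), 3))
    # 3 decimals -> 1 unique value, entropy 0.0; 4 decimals -> 5 unique values.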
Reviewed By: xuzhao9

Differential Revision: D87379814
---
 .../do_bench/entropy/entropy_criterion.py | 61 +++++++++++--------
 tritonbench/components/do_bench/run.py    | 42 ++++++++++---
 2 files changed, 67 insertions(+), 36 deletions(-)

diff --git a/tritonbench/components/do_bench/entropy/entropy_criterion.py b/tritonbench/components/do_bench/entropy/entropy_criterion.py
index d2872b1aa..2df20f9b4 100644
--- a/tritonbench/components/do_bench/entropy/entropy_criterion.py
+++ b/tritonbench/components/do_bench/entropy/entropy_criterion.py
@@ -99,12 +99,17 @@ def _update_entropy_sum(self, old_count: int, new_count: int) -> None:
             new_count: New count (0 if removing unique value)
         """
         # Remove old contribution: S -= old_count * log2(old_count)
-        if old_count > 0:
-            self._sum_count_log_count -= old_count * math.log2(old_count)
-
-        # Add new contribution: S += new_count * log2(new_count)
-        if new_count > 0:
-            self._sum_count_log_count += new_count * math.log2(new_count)
+        # Optimization: n*log(n) - o*log(o) = n*log(1 + (n - o)/o) + (n - o)*log(o)
+        if old_count > 0 and new_count > 0:
+            delta = new_count - old_count
+            self._sum_count_log_count += (
+                new_count * math.log2(1 + delta / old_count) + delta * math.log2(old_count)
+            )
+        else:
+            if old_count > 0:
+                self._sum_count_log_count -= old_count * math.log2(old_count)
+            if new_count > 0:
+                self._sum_count_log_count += new_count * math.log2(new_count)
 
     def _compute_entropy(self) -> float:
         """
@@ -121,7 +126,7 @@ def _compute_entropy(self) -> float:
 
         # Entropy formula: H = log2(n) - S/n
         entropy = math.log2(n) - (self._sum_count_log_count / n)
-        return entropy
+        return max(0.0, entropy)
 
     def add_measurement(self, measurement: float) -> None:
         """
@@ -157,24 +162,20 @@ def add_measurement(self, measurement: float) -> None:
 
         # Update running statistics for linear regression
         # If entropy_tracker is full, remove oldest component from running stats
+        # removal index in the sliding window = 0
         if len(self.entropy_tracker) == self.window_size:
             old_entropy = self.entropy_tracker[0]
-            old_x = 0  # Oldest position in the sliding window
-            old_sum_x = self._sum_x
 
             # Remove old values from running sums
-            self._sum_x -= old_x
             self._sum_y -= old_entropy
-            self._sum_xy -= old_x * old_entropy
-            self._sum_x2 -= old_x * old_x
             self._sum_y2 -= old_entropy * old_entropy
-            self._n -= 1
 
             # Remove element's effect from sum of squares
-            n = self._n
-            self._sum_x2 -= 2 * old_sum_x + n  # Use saved old_sum_x
+            n = self._n - 1
             self._sum_x -= n
+            self._sum_x2 -= 2 * self._sum_x + n
             self._sum_xy -= self._sum_y
+            self._n -= 1
 
         # Add new entropy value to running stats
         x = self._n
@@ -214,10 +215,11 @@ def is_finished(self) -> bool:
         mean_y = self._sum_y / n
 
         # Compute slope using cached statistics
-        numerator = self._sum_xy - n * mean_x * mean_y
-        denominator = self._sum_x2 - n * mean_x * mean_x
+        # scaled down by 1/n to avoid overflow
+        numerator = self._sum_xy / n - mean_x * mean_y
+        denominator = self._sum_x2 / n - mean_x * mean_x
 
-        if denominator == 0:
+        if abs(denominator) < 1e-12:
             return False
 
         slope = numerator / denominator
@@ -227,21 +229,22 @@ def is_finished(self) -> bool:
         slope_degrees = math.degrees(math.atan(slope))
 
         # Compute total sum of squares (TSS)
-        ss_tot = self._sum_y2 - n * mean_y * mean_y
+        # ss_tot and ss_res are scaled by 1/n to avoid overflow
+        ss_tot = (self._sum_y2 / n) - mean_y * mean_y
 
         # Calculate residual sum of squares (RSS) using the cached value
         # ss_res = Σ(y - (slope*x + intercept))² expanded
         ss_res = (
-            self._sum_y2
-            - 2 * slope * self._sum_xy
-            - 2 * intercept * self._sum_y
-            + slope * slope * self._sum_x2
-            + 2 * slope * intercept * self._sum_x
-            + n * intercept * intercept
+            (self._sum_y2 / n)
+            - 2 * slope * (self._sum_xy / n)
+            - 2 * intercept * (self._sum_y / n)
+            + slope * slope * (self._sum_x2 / n)
+            + 2 * slope * intercept * (self._sum_x / n)
+            + intercept * intercept
         )
 
-        # If ss_tot == 0, entropy values are identical => perfect stability
-        if ss_tot == 0:
+        # If ss_tot < epsilon, entropy values are identical => perfect stability
+        if abs(ss_tot) < 1e-12:
             r2 = 1.0
         else:
             r2 = max(0.0, min(1.0, 1 - (ss_res / ss_tot)))
@@ -263,6 +266,10 @@ def is_finished(self) -> bool:
 
         return True
 
+    def get_convergence_info(self) -> dict:
+        """Get the last convergence check information."""
+        return getattr(self, "_last_convergence_check", {})
+
     def get_stats(self) -> dict:
         """
         Get current statistics for debugging/monitoring.
diff --git a/tritonbench/components/do_bench/run.py b/tritonbench/components/do_bench/run.py
index 9f0d4e8dc..91f4c5b29 100644
--- a/tritonbench/components/do_bench/run.py
+++ b/tritonbench/components/do_bench/run.py
@@ -505,22 +505,25 @@ def _do_bench_entropy(
     assert return_mode in ["min", "max", "mean", "median", "all"]
 
     # ENTROPY-BASED WARMUP
-    criterion = EntropyCriterion(
+    entropy_criterion = EntropyCriterion(
         max_angle=max_angle,
         min_r2=min_r2,
         window_size=window_size,
         min_warmup_samples=min_warmup_samples,
     )
-    criterion.reset()
-    BATCH_SIZE = 20
+    entropy_criterion.reset()
+
+    rounding_factor = 3
+    BATCH_SIZE = 50
     last_batch = [-1.00] * BATCH_SIZE
     counter = 0
     converged = False
+    precision_increase = False
     cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
 
     # Adaptive warmup loop with batched synchronization
-    while not criterion.is_finished():
+    while True:
         remaining = max_samples - counter
         batch_size = min(BATCH_SIZE, remaining) if remaining > 0 else BATCH_SIZE
@@ -543,20 +546,41 @@ def _do_bench_entropy(
         torch.cuda.synchronize()
 
         for i in range(batch_size):
-            v = round(batch_start_events[i].elapsed_time(batch_end_events[i]), 3)
-            criterion.add_measurement(v)
+            v = round(batch_start_events[i].elapsed_time(batch_end_events[i]), rounding_factor)
             last_batch[i] = v
+
+            entropy_criterion.add_measurement(v)
+
+            if entropy_criterion.is_finished():
+                converged = True
+                break
+
         counter += batch_size
 
+        if converged:
+            break
+
         if counter >= max_samples:
             break
-    else:
-        converged = True
+
+        if counter >= 200 and not precision_increase:
+            stats = entropy_criterion.get_stats()
+            unique_count = stats.get('unique_measurements', 0)
+
+            # If we see < 20 unique values, this indicates quantization; increase rounding precision
+            if unique_count < 20:
+                rounding_factor = 4
+
+                entropy_criterion.reset()
+                entropy_criterion.entropy_window_size = 1000
+
+                logger.info(f"Quantization detected: only {unique_count} unique measurements.")
+                precision_increase = True
 
     # Log if warmup didn't converge
     if not converged:
         logger.warning(
-            f"Entropy warmup did not converge after {counter} samples "
+            f"Warmup did not converge after {counter} samples "
             f"(max_samples={max_samples})"
         )
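
Note on the _update_entropy_sum change above: a standalone sanity check
of the log2 identity it relies on (delta_s and the counts below are
arbitrary, for illustration only):

    import math

    def delta_s(old_count, new_count):
        """Change to S = sum(c * log2(c)) when one value's count changes.

        Uses n*log2(n) - o*log2(o) = n*log2(1 + (n - o)/o) + (n - o)*log2(o),
        which avoids cancellation between two large log terms when n is
        close to o.
        """
        if old_count > 0 and new_count > 0:
            delta = new_count - old_count
            return new_count * math.log2(1 + delta / old_count) + delta * math.log2(old_count)
        s = 0.0
        if old_count > 0:
            s -= old_count * math.log2(old_count)
        if new_count > 0:
            s += new_count * math.log2(new_count)
        return s

    # Compare against direct recomputation for a large count.
    o, n = 10_000_000, 10_000_001
    direct = n * math.log2(n) - o * math.log2(o)
    print(abs(delta_s(o, n) - direct) < 1e-6)  # True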