Skip to content

Commit 60726b0

Browse files
authored
improve model for prefix cache score (#1770)
1 parent b2ddec6 commit 60726b0

File tree

5 files changed

+647
-182
lines changed

5 files changed

+647
-182
lines changed

latencypredictor-v1/prediction_server.py

Lines changed: 113 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -175,6 +175,7 @@ class LightweightPredictor:
175175

176176
def __init__(self):
177177
mt = settings.MODEL_TYPE
178+
self.prefix_buckets = 4
178179

179180
# Add LightGBM fallback logic
180181
if mt == ModelType.XGBOOST and not XGBOOST_AVAILABLE:
@@ -202,6 +203,54 @@ def is_ready(self) -> bool:
202203
else: # XGBoost or LightGBM
203204
return all([self.ttft_model, self.tpot_model])
204205

206+
def _prepare_features_with_interaction(self, df: pd.DataFrame, model_type: str) -> pd.DataFrame:
    """
    Select the model input columns, adding engineered features for TTFT.

    The engineered columns are intended to mirror the training server's
    feature engineering (that code is not visible from here, so keep the two
    in sync when changing either side).

    Args:
        df: DataFrame with raw features. It is not modified; the TTFT branch
            works on a copy.
        model_type: 'ttft' selects the TTFT feature set. Any other value is
            treated as 'tpot'.

    Returns:
        DataFrame restricted to the feature columns of the requested model.
        For 'ttft' this includes 'effective_input_tokens' and the ordered
        categorical 'prefill_score_bucket'.
    """
    if model_type == "ttft":
        # Copy first so the caller's frame does not silently gain columns.
        df = df.copy()

        # Interaction term: (1 - prefix score) * input length, i.e. the part
        # of the prompt not credited to the prefix cache. The raw (unclipped)
        # score is used here, exactly as before.
        df['effective_input_tokens'] = (1 - df['prefix_cache_score']) * df['input_token_length']

        # Bucket index in [0, prefix_buckets - 1]. A score of exactly 1.0
        # maps to `prefix_buckets`, so the upper clip folds it into the last
        # bucket.
        bucket_index = (
            (df['prefix_cache_score'].clip(0, 1) * self.prefix_buckets)
            .astype(int)
            .clip(upper=self.prefix_buckets - 1)
        )

        # Derive the category list from prefix_buckets instead of hard-coding
        # [0, 1, 2, 3]: values missing from `categories` become NaN, so a
        # fixed list breaks as soon as the bucket count changes.
        df['prefill_score_bucket'] = pd.Categorical(
            bucket_index,
            categories=list(range(self.prefix_buckets)),
            ordered=True,
        )

        # TTFT features, including the interaction term and the bucket.
        feature_cols = [
            'kv_cache_percentage',
            'input_token_length',
            'num_request_waiting',
            'num_request_running',
            'prefix_cache_score',
            'effective_input_tokens',
            'prefill_score_bucket',
        ]
        return df[feature_cols]

    # TPOT does not use prefix_cache_score, so no interaction is needed.
    feature_cols = [
        'kv_cache_percentage',
        'input_token_length',
        'num_request_waiting',
        'num_request_running',
        'num_tokens_generated',
    ]
    return df[feature_cols]
205254

206255
def load_models(self) -> bool:
207256
try:
@@ -228,33 +277,52 @@ def load_models(self) -> bool:
228277
logging.error(f"Load error: {e}")
229278
return False
230279

231-
# 4. Update predict method to handle LightGBM
232280
def predict(self, features: dict) -> Tuple[float, float]:
233281
"""Make quantile predictions using the loaded models."""
234282
try:
235283
with self.lock:
236284
if not self.is_ready:
237285
raise HTTPException(status_code=503, detail="Models not ready")
238286

239-
# Validation remains the same
240-
required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting', 'num_request_running', 'num_tokens_generated', 'prefix_cache_score']
287+
# Validation
288+
required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting',
289+
'num_request_running', 'num_tokens_generated', 'prefix_cache_score']
241290
for f in required:
242291
if f not in features:
243292
raise ValueError(f"Missing required feature: {f}")
244293
if not isinstance(features[f], (int, float)):
245294
raise ValueError(f"Invalid type for feature {f}: expected number")
246295

247-
# Feature columns remain the same
248-
ttft_cols = ['kv_cache_percentage','input_token_length','num_request_waiting','num_request_running','prefix_cache_score']
249-
tpot_cols = ['kv_cache_percentage','input_token_length','num_request_waiting','num_request_running','num_tokens_generated']
296+
# Create raw DataFrames (without interaction)
297+
ttft_raw_data = {
298+
'kv_cache_percentage': features['kv_cache_percentage'],
299+
'input_token_length': features['input_token_length'],
300+
'num_request_waiting': features['num_request_waiting'],
301+
'num_request_running': features['num_request_running'],
302+
'prefix_cache_score': features['prefix_cache_score']
303+
}
304+
305+
tpot_raw_data = {
306+
'kv_cache_percentage': features['kv_cache_percentage'],
307+
'input_token_length': features['input_token_length'],
308+
'num_request_waiting': features['num_request_waiting'],
309+
'num_request_running': features['num_request_running'],
310+
'num_tokens_generated': features['num_tokens_generated']
311+
}
312+
313+
# Prepare features with interactions
314+
df_ttft_raw = pd.DataFrame([ttft_raw_data])
315+
df_ttft = self._prepare_features_with_interaction(df_ttft_raw, "ttft")
316+
250317

251-
# Create DataFrames for predictions
252-
df_ttft = pd.DataFrame([{col: features[col] for col in ttft_cols}])
253-
df_tpot = pd.DataFrame([{col: features[col] for col in tpot_cols}])
318+
df_tpot_raw = pd.DataFrame([tpot_raw_data])
319+
df_tpot = self._prepare_features_with_interaction(df_tpot_raw, "tpot")
320+
#df_tpot = pd.DataFrame([tpot_raw_data])
254321

255322
if self.model_type == ModelType.BAYESIAN_RIDGE:
256-
# Bayesian Ridge logic (unchanged)
257-
ttft_scaled = self.ttft_scaler.transform(df_ttft)
323+
324+
ttft_for_scale = df_ttft.drop(columns=['prefill_score_bucket'], errors='ignore')
325+
ttft_scaled = self.ttft_scaler.transform(ttft_for_scale)
258326
tpot_scaled = self.tpot_scaler.transform(df_tpot)
259327

260328
ttft_pred_mean, ttft_std = self.ttft_model.predict(ttft_scaled, return_std=True)
@@ -267,14 +335,12 @@ def predict(self, features: dict) -> Tuple[float, float]:
267335
return ttft_pred, tpot_pred
268336

269337
elif self.model_type == ModelType.XGBOOST:
270-
# XGBoost logic (unchanged)
271338
ttft_pred = self.ttft_model.predict(df_ttft)
272339
tpot_pred = self.tpot_model.predict(df_tpot)
273340

274341
return ttft_pred[0], tpot_pred[0]
275342

276-
else: # LightGBM - NEW
277-
# LightGBM quantile regression directly predicts the quantile
343+
else: # LightGBM
278344
ttft_pred = self.ttft_model.predict(df_ttft)
279345
tpot_pred = self.tpot_model.predict(df_tpot)
280346

@@ -289,36 +355,56 @@ def predict(self, features: dict) -> Tuple[float, float]:
289355
logging.error("Error in predict():", exc_info=True)
290356
raise HTTPException(status_code=500, detail="Internal error during prediction")
291357

292-
# 5. Update predict_batch method to handle LightGBM
293358
def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarray]:
294359
"""Make batch quantile predictions using the loaded models."""
295360
try:
296361
with self.lock:
297362
if not self.is_ready:
298363
raise HTTPException(status_code=503, detail="Models not ready")
299364

300-
# Validation logic remains the same
301-
required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting', 'num_request_running', 'num_tokens_generated', 'prefix_cache_score']
365+
# Validation
366+
required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting',
367+
'num_request_running', 'num_tokens_generated', 'prefix_cache_score']
302368
for i, features in enumerate(features_list):
303369
for f in required:
304370
if f not in features:
305371
raise ValueError(f"Missing required feature '{f}' in request {i}")
306372
if not isinstance(features[f], (int, float)):
307373
raise ValueError(f"Invalid type for feature '{f}' in request {i}: expected number")
308374

309-
# Feature columns and DataFrame creation remains the same
310-
ttft_cols = ['kv_cache_percentage','input_token_length','num_request_waiting','num_request_running','prefix_cache_score']
311-
tpot_cols = ['kv_cache_percentage','input_token_length','num_request_waiting','num_request_running','num_tokens_generated']
375+
# Create raw feature data (without interaction)
376+
ttft_raw_data = []
377+
tpot_raw_data = []
378+
379+
for features in features_list:
380+
ttft_raw_data.append({
381+
'kv_cache_percentage': features['kv_cache_percentage'],
382+
'input_token_length': features['input_token_length'],
383+
'num_request_waiting': features['num_request_waiting'],
384+
'num_request_running': features['num_request_running'],
385+
'prefix_cache_score': features['prefix_cache_score']
386+
})
387+
388+
tpot_raw_data.append({
389+
'kv_cache_percentage': features['kv_cache_percentage'],
390+
'input_token_length': features['input_token_length'],
391+
'num_request_waiting': features['num_request_waiting'],
392+
'num_request_running': features['num_request_running'],
393+
'num_tokens_generated': features['num_tokens_generated']
394+
})
312395

313-
ttft_data = [{col: features[col] for col in ttft_cols} for features in features_list]
314-
tpot_data = [{col: features[col] for col in tpot_cols} for features in features_list]
396+
# Prepare features with interactions
397+
df_ttft_raw = pd.DataFrame(ttft_raw_data)
398+
df_ttft_batch = self._prepare_features_with_interaction(df_ttft_raw, "ttft")
399+
#df_ttft_batch = pd.DataFrame(ttft_raw_data)
315400

316-
df_ttft_batch = pd.DataFrame(ttft_data)
317-
df_tpot_batch = pd.DataFrame(tpot_data)
401+
df_tpot_raw = pd.DataFrame(tpot_raw_data)
402+
df_tpot_batch = self._prepare_features_with_interaction(df_tpot_raw, "tpot")
403+
#df_tpot_batch = pd.DataFrame(tpot_raw_data)
318404

319405
if self.model_type == ModelType.BAYESIAN_RIDGE:
320-
# Bayesian Ridge logic (unchanged)
321-
ttft_scaled = self.ttft_scaler.transform(df_ttft_batch)
406+
ttft_for_scale = df_ttft_batch.drop(columns=['prefill_score_bucket'], errors='ignore')
407+
ttft_scaled = self.ttft_scaler.transform(ttft_for_scale)
322408
tpot_scaled = self.tpot_scaler.transform(df_tpot_batch)
323409

324410
ttft_pred_mean, ttft_std = self.ttft_model.predict(ttft_scaled, return_std=True)
@@ -331,14 +417,12 @@ def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarr
331417
return ttft_pred, tpot_pred
332418

333419
elif self.model_type == ModelType.XGBOOST:
334-
# XGBoost logic (unchanged)
335420
ttft_pred = self.ttft_model.predict(df_ttft_batch)
336421
tpot_pred = self.tpot_model.predict(df_tpot_batch)
337422

338423
return ttft_pred, tpot_pred
339424

340-
else: # LightGBM - NEW
341-
# LightGBM quantile regression directly predicts the quantile
425+
else: # LightGBM
342426
ttft_pred = self.ttft_model.predict(df_ttft_batch)
343427
tpot_pred = self.tpot_model.predict(df_tpot_batch)
344428

0 commit comments

Comments
 (0)