@@ -175,6 +175,7 @@ class LightweightPredictor:
175175
176176 def __init__ (self ):
177177 mt = settings .MODEL_TYPE
178+ self .prefix_buckets = 4
178179
179180 # Add LightGBM fallback logic
180181 if mt == ModelType .XGBOOST and not XGBOOST_AVAILABLE :
@@ -202,6 +203,54 @@ def is_ready(self) -> bool:
202203 else : # XGBoost or LightGBM
203204 return all ([self .ttft_model , self .tpot_model ])
204205
206+ def _prepare_features_with_interaction (self , df : pd .DataFrame , model_type : str ) -> pd .DataFrame :
207+ """
208+ Prepare features with interaction terms to match training server.
209+
210+ Args:
211+ df: DataFrame with raw features
212+ model_type: 'ttft' or 'tpot'
213+
214+ Returns:
215+ DataFrame with engineered features including interactions
216+ """
217+ if model_type == "ttft" :
218+ # Create interaction: prefix score * input length
219+ df ['effective_input_tokens' ] = (1 - df ['prefix_cache_score' ]) * df ['input_token_length' ]
220+ df ['prefill_score_bucket' ] = (
221+ (df ['prefix_cache_score' ].clip (0 , 1 ) * self .prefix_buckets )
222+ .astype (int )
223+ .clip (upper = self .prefix_buckets - 1 )
224+ )
225+
226+ # make it categorical for tree models (safe for LGB, XGB with enable_categorical)
227+ df ['prefill_score_bucket' ] = pd .Categorical (df ['prefill_score_bucket' ], categories = [0 ,1 ,2 ,3 ], ordered = True )
228+
229+
230+ # Return TTFT features with interaction
231+ feature_cols = [
232+ 'kv_cache_percentage' ,
233+ 'input_token_length' ,
234+ 'num_request_waiting' ,
235+ 'num_request_running' ,
236+ 'prefix_cache_score' ,
237+ 'effective_input_tokens' ,
238+ 'prefill_score_bucket'
239+ ]
240+
241+ return df [feature_cols ]
242+
243+ else : # tpot
244+ # TPOT doesn't use prefix_cache_score, so no interaction needed
245+ feature_cols = [
246+ 'kv_cache_percentage' ,
247+ 'input_token_length' ,
248+ 'num_request_waiting' ,
249+ 'num_request_running' ,
250+ 'num_tokens_generated'
251+ ]
252+
253+ return df [feature_cols ]
205254
206255 def load_models (self ) -> bool :
207256 try :
@@ -228,33 +277,52 @@ def load_models(self) -> bool:
228277 logging .error (f"Load error: { e } " )
229278 return False
230279
231- # 4. Update predict method to handle LightGBM
232280 def predict (self , features : dict ) -> Tuple [float , float ]:
233281 """Make quantile predictions using the loaded models."""
234282 try :
235283 with self .lock :
236284 if not self .is_ready :
237285 raise HTTPException (status_code = 503 , detail = "Models not ready" )
238286
239- # Validation remains the same
240- required = ['kv_cache_percentage' , 'input_token_length' , 'num_request_waiting' , 'num_request_running' , 'num_tokens_generated' , 'prefix_cache_score' ]
287+ # Validation
288+ required = ['kv_cache_percentage' , 'input_token_length' , 'num_request_waiting' ,
289+ 'num_request_running' , 'num_tokens_generated' , 'prefix_cache_score' ]
241290 for f in required :
242291 if f not in features :
243292 raise ValueError (f"Missing required feature: { f } " )
244293 if not isinstance (features [f ], (int , float )):
245294 raise ValueError (f"Invalid type for feature { f } : expected number" )
246295
247- # Feature columns remain the same
248- ttft_cols = ['kv_cache_percentage' ,'input_token_length' ,'num_request_waiting' ,'num_request_running' ,'prefix_cache_score' ]
249- tpot_cols = ['kv_cache_percentage' ,'input_token_length' ,'num_request_waiting' ,'num_request_running' ,'num_tokens_generated' ]
296+ # Create raw DataFrames (without interaction)
297+ ttft_raw_data = {
298+ 'kv_cache_percentage' : features ['kv_cache_percentage' ],
299+ 'input_token_length' : features ['input_token_length' ],
300+ 'num_request_waiting' : features ['num_request_waiting' ],
301+ 'num_request_running' : features ['num_request_running' ],
302+ 'prefix_cache_score' : features ['prefix_cache_score' ]
303+ }
304+
305+ tpot_raw_data = {
306+ 'kv_cache_percentage' : features ['kv_cache_percentage' ],
307+ 'input_token_length' : features ['input_token_length' ],
308+ 'num_request_waiting' : features ['num_request_waiting' ],
309+ 'num_request_running' : features ['num_request_running' ],
310+ 'num_tokens_generated' : features ['num_tokens_generated' ]
311+ }
312+
313+ # Prepare features with interactions
314+ df_ttft_raw = pd .DataFrame ([ttft_raw_data ])
315+ df_ttft = self ._prepare_features_with_interaction (df_ttft_raw , "ttft" )
316+
250317
251- # Create DataFrames for predictions
252- df_ttft = pd . DataFrame ([{ col : features [ col ] for col in ttft_cols }] )
253- df_tpot = pd .DataFrame ([{ col : features [ col ] for col in tpot_cols } ])
318+ df_tpot_raw = pd . DataFrame ([ tpot_raw_data ])
319+ df_tpot = self . _prepare_features_with_interaction ( df_tpot_raw , "tpot" )
320+ # df_tpot = pd.DataFrame([tpot_raw_data ])
254321
255322 if self .model_type == ModelType .BAYESIAN_RIDGE :
256- # Bayesian Ridge logic (unchanged)
257- ttft_scaled = self .ttft_scaler .transform (df_ttft )
323+
324+ ttft_for_scale = df_ttft .drop (columns = ['prefill_score_bucket' ], errors = 'ignore' )
325+ ttft_scaled = self .ttft_scaler .transform (ttft_for_scale )
258326 tpot_scaled = self .tpot_scaler .transform (df_tpot )
259327
260328 ttft_pred_mean , ttft_std = self .ttft_model .predict (ttft_scaled , return_std = True )
@@ -267,14 +335,12 @@ def predict(self, features: dict) -> Tuple[float, float]:
267335 return ttft_pred , tpot_pred
268336
269337 elif self .model_type == ModelType .XGBOOST :
270- # XGBoost logic (unchanged)
271338 ttft_pred = self .ttft_model .predict (df_ttft )
272339 tpot_pred = self .tpot_model .predict (df_tpot )
273340
274341 return ttft_pred [0 ], tpot_pred [0 ]
275342
276- else : # LightGBM - NEW
277- # LightGBM quantile regression directly predicts the quantile
343+ else : # LightGBM
278344 ttft_pred = self .ttft_model .predict (df_ttft )
279345 tpot_pred = self .tpot_model .predict (df_tpot )
280346
@@ -289,36 +355,56 @@ def predict(self, features: dict) -> Tuple[float, float]:
289355 logging .error ("Error in predict():" , exc_info = True )
290356 raise HTTPException (status_code = 500 , detail = "Internal error during prediction" )
291357
292- # 5. Update predict_batch method to handle LightGBM
293358 def predict_batch (self , features_list : List [dict ]) -> Tuple [np .ndarray , np .ndarray ]:
294359 """Make batch quantile predictions using the loaded models."""
295360 try :
296361 with self .lock :
297362 if not self .is_ready :
298363 raise HTTPException (status_code = 503 , detail = "Models not ready" )
299364
300- # Validation logic remains the same
301- required = ['kv_cache_percentage' , 'input_token_length' , 'num_request_waiting' , 'num_request_running' , 'num_tokens_generated' , 'prefix_cache_score' ]
365+ # Validation
366+ required = ['kv_cache_percentage' , 'input_token_length' , 'num_request_waiting' ,
367+ 'num_request_running' , 'num_tokens_generated' , 'prefix_cache_score' ]
302368 for i , features in enumerate (features_list ):
303369 for f in required :
304370 if f not in features :
305371 raise ValueError (f"Missing required feature '{ f } ' in request { i } " )
306372 if not isinstance (features [f ], (int , float )):
307373 raise ValueError (f"Invalid type for feature '{ f } ' in request { i } : expected number" )
308374
309- # Feature columns and DataFrame creation remains the same
310- ttft_cols = ['kv_cache_percentage' ,'input_token_length' ,'num_request_waiting' ,'num_request_running' ,'prefix_cache_score' ]
311- tpot_cols = ['kv_cache_percentage' ,'input_token_length' ,'num_request_waiting' ,'num_request_running' ,'num_tokens_generated' ]
375+ # Create raw feature data (without interaction)
376+ ttft_raw_data = []
377+ tpot_raw_data = []
378+
379+ for features in features_list :
380+ ttft_raw_data .append ({
381+ 'kv_cache_percentage' : features ['kv_cache_percentage' ],
382+ 'input_token_length' : features ['input_token_length' ],
383+ 'num_request_waiting' : features ['num_request_waiting' ],
384+ 'num_request_running' : features ['num_request_running' ],
385+ 'prefix_cache_score' : features ['prefix_cache_score' ]
386+ })
387+
388+ tpot_raw_data .append ({
389+ 'kv_cache_percentage' : features ['kv_cache_percentage' ],
390+ 'input_token_length' : features ['input_token_length' ],
391+ 'num_request_waiting' : features ['num_request_waiting' ],
392+ 'num_request_running' : features ['num_request_running' ],
393+ 'num_tokens_generated' : features ['num_tokens_generated' ]
394+ })
312395
313- ttft_data = [{col : features [col ] for col in ttft_cols } for features in features_list ]
314- tpot_data = [{col : features [col ] for col in tpot_cols } for features in features_list ]
396+ # Prepare features with interactions
397+ df_ttft_raw = pd .DataFrame (ttft_raw_data )
398+ df_ttft_batch = self ._prepare_features_with_interaction (df_ttft_raw , "ttft" )
399+ #df_ttft_batch = pd.DataFrame(ttft_raw_data)
315400
316- df_ttft_batch = pd .DataFrame (ttft_data )
317- df_tpot_batch = pd .DataFrame (tpot_data )
401+ df_tpot_raw = pd .DataFrame (tpot_raw_data )
402+ df_tpot_batch = self ._prepare_features_with_interaction (df_tpot_raw , "tpot" )
403+ #df_tpot_batch = pd.DataFrame(tpot_raw_data)
318404
319405 if self .model_type == ModelType .BAYESIAN_RIDGE :
320- # Bayesian Ridge logic (unchanged )
321- ttft_scaled = self .ttft_scaler .transform (df_ttft_batch )
406+ ttft_for_scale = df_ttft_batch . drop ( columns = [ 'prefill_score_bucket' ], errors = 'ignore' )
407+ ttft_scaled = self .ttft_scaler .transform (ttft_for_scale )
322408 tpot_scaled = self .tpot_scaler .transform (df_tpot_batch )
323409
324410 ttft_pred_mean , ttft_std = self .ttft_model .predict (ttft_scaled , return_std = True )
@@ -331,14 +417,12 @@ def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarr
331417 return ttft_pred , tpot_pred
332418
333419 elif self .model_type == ModelType .XGBOOST :
334- # XGBoost logic (unchanged)
335420 ttft_pred = self .ttft_model .predict (df_ttft_batch )
336421 tpot_pred = self .tpot_model .predict (df_tpot_batch )
337422
338423 return ttft_pred , tpot_pred
339424
340- else : # LightGBM - NEW
341- # LightGBM quantile regression directly predicts the quantile
425+ else : # LightGBM
342426 ttft_pred = self .ttft_model .predict (df_ttft_batch )
343427 tpot_pred = self .tpot_model .predict (df_tpot_batch )
344428