@@ -57,6 +57,8 @@ const (
5757 DefaultLRUCapacityPerServer = 31250
5858
5959 PrefixCachePluginType = "prefix-cache-scorer"
60+
61+ PrefixCacheMatchKey = "PrefixCacheMatchKey"
6062)
6163
6264const (
@@ -195,17 +197,45 @@ func (p *Plugin) WithName(name string) *Plugin {
195197 return p
196198}
197199
198- // Score returns the scoring result for the given list of pods based on context.
199- func (p * Plugin ) Score (ctx context.Context , cycleState * types.CycleState , request * types.LLMRequest , pods []types.Pod ) map [types.Pod ]float64 {
200+ func (p * Plugin ) Consumes () map [string ]any {
201+ return map [string ]any {}
202+ }
203+
204+ func (p * Plugin ) Produces () map [string ]any {
205+ return map [string ]any {
206+ PrefixCacheMatchKey : & SchedulingContextState {},
207+ }
208+ }
209+
210+ func (p * Plugin ) PrepareRequestData (ctx context.Context , request * types.LLMRequest , pods []types.Pod ) {
200211 // pre score step, hashing prompt and find longest prefix match.
201212 hashes := hashPrompt (ctx , request , getBlockSize (pods , p .config .DefaultBlockSize ), p .config .MaxPrefixBlocksToMatch )
202213 state := & SchedulingContextState {
203214 PrefixHashes : hashes ,
204215 PrefixCacheServers : p .matchLongestPrefix (ctx , hashes ),
205216 }
206217
207- cycleState . Write ( plugins . StateKey ( p . TypedName (). String ()), state )
218+ // TODO: Instead store this in the pods attribute map to avoid global state in the plugin.
208219 p .pluginState .Write (request .RequestId , plugins .StateKey (p .TypedName ().String ()), state )
220+ }
221+
222+ // Score returns the scoring result for the given list of pods based on context.
223+ func (p * Plugin ) Score (ctx context.Context , cycleState * types.CycleState , request * types.LLMRequest , pods []types.Pod ) map [types.Pod ]float64 {
224+ // TODO(rahulgurnani): Remove duplication with PrepareRequestData after testing.
225+ state , err := plugins .ReadPluginStateKey [* SchedulingContextState ](p .pluginState , request .RequestId , plugins .StateKey (p .TypedName ().String ()))
226+ if err != nil {
227+ // This should not happen, but in case it does, we recalculate the state.
228+ log .FromContext (ctx ).Error (err , "failed to read prefix plugin state, recalculating" )
229+ hashes := hashPrompt (ctx , request , getBlockSize (pods , p .config .DefaultBlockSize ), p .config .MaxPrefixBlocksToMatch )
230+ state = & SchedulingContextState {
231+ PrefixHashes : hashes ,
232+ PrefixCacheServers : p .matchLongestPrefix (ctx , hashes ),
233+ }
234+ p .pluginState .Write (request .RequestId , plugins .StateKey (p .TypedName ().String ()), state )
235+ }
236+ // TODO(rahulgurnani): cleanup the cycleState after all the changes are done. Seems llm-d-scheduler relies on this state.
237+ cycleState .Write (plugins .StateKey (p .TypedName ().String ()), state )
238+
209239 log .FromContext (ctx ).V (logutil .TRACE ).Info ("prefix cached state" , "cached-servers" , state .PrefixCacheServers , "hashes" , state .PrefixHashes )
210240 // calculate the scores of pods
211241 scores := make (map [types.Pod ]float64 , len (pods ))
0 commit comments