Skip to content

Commit e52f121

Browse files
committed
Update prefix match plugin to implement PrepareData plugin
1 parent 13b8c32 commit e52f121

File tree

1 file changed

+33
-3
lines changed
  • pkg/epp/scheduling/framework/plugins/multi/prefix

1 file changed

+33
-3
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ const (
5757
DefaultLRUCapacityPerServer = 31250
5858

5959
PrefixCachePluginType = "prefix-cache-scorer"
60+
61+
PrefixCacheMatchKey = "PrefixCacheMatchKey"
6062
)
6163

6264
const (
@@ -195,17 +197,45 @@ func (p *Plugin) WithName(name string) *Plugin {
195197
return p
196198
}
197199

198-
// Score returns the scoring result for the given list of pods based on context.
199-
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
200+
// Consumes returns an empty, non-nil map: this plugin declares no entries here.
// NOTE(review): presumably the map lists the data keys this plugin reads from
// other plugins — confirm against the PrepareData plugin interface.
func (p *Plugin) Consumes() map[string]any {
201+
return map[string]any{}
202+
}
203+
204+
// Produces returns a single-entry map keyed by PrefixCacheMatchKey whose value
// is a pointer to an empty SchedulingContextState.
// NOTE(review): the value looks like a type exemplar for the data published
// under that key — confirm against the PrepareData plugin interface. Also note
// that PrepareRequestData stores its state under
// plugins.StateKey(p.TypedName().String()), not under PrefixCacheMatchKey.
func (p *Plugin) Produces() map[string]any {
205+
return map[string]any{
206+
PrefixCacheMatchKey: &SchedulingContextState{},
207+
}
208+
}
209+
210+
// PrepareRequestData hashes the request prompt, looks up the longest prefix
// match for those hashes, and stores the resulting SchedulingContextState in
// p.pluginState under the request ID and this plugin's typed-name state key.
//
// The block size passed to hashPrompt comes from getBlockSize(pods,
// p.config.DefaultBlockSize); p.config.MaxPrefixBlocksToMatch is passed as the
// final argument. The function returns nothing; its only visible effect is the
// write to p.pluginState.
func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) {
200211
// Pre-score step: hash the prompt and find the longest prefix match.
201212
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
202213
state := &SchedulingContextState{
203214
PrefixHashes: hashes,
204215
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
205216
}
206217

207-
cycleState.Write(plugins.StateKey(p.TypedName().String()), state)
218+
// TODO: Instead store this in the pods attribute map to avoid global state in the plugin.
// Score reads this entry back using the same request ID and state key.
208219
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state)
220+
}
221+
222+
// Score returns the scoring result for the given list of pods based on context.
223+
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
224+
// TODO(rahulgurnani): Remove duplication with PrepareRequestData after testing.
225+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
226+
if err != nil {
227+
// This should not happen, but in case it does, we recalculate the state.
228+
log.FromContext(ctx).Error(err, "failed to read prefix plugin state, recalculating")
229+
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
230+
state = &SchedulingContextState{
231+
PrefixHashes: hashes,
232+
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
233+
}
234+
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state)
235+
}
236+
// TODO(rahulgurnani): cleanup the cycleState after all the changes are done. Seems llm-d-scheduler relies on this state.
237+
cycleState.Write(plugins.StateKey(p.TypedName().String()), state)
238+
209239
log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "cached-servers", state.PrefixCacheServers, "hashes", state.PrefixHashes)
210240
// calculate the scores of pods
211241
scores := make(map[types.Pod]float64, len(pods))

0 commit comments

Comments
 (0)