Skip to content

Commit 7ae9f7e

Browse files
committed
Update prefix match plugin to implement PrepareData plugin
1 parent 13b8c32 commit 7ae9f7e

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 36 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,8 @@ const (
5757
DefaultLRUCapacityPerServer = 31250
5858

5959
PrefixCachePluginType = "prefix-cache-scorer"
60+
61+
PrefixCacheMatchKey = "PrefixCacheMatchKey"
6062
)
6163

6264
const (
@@ -195,17 +197,48 @@ func (p *Plugin) WithName(name string) *Plugin {
195197
return p
196198
}
197199

198-
// Score returns the scoring result for the given list of pods based on context.
199-
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
200+
func (p *Plugin) Consumes() map[string]any {
201+
return map[string]any{}
202+
}
203+
204+
func (p *Plugin) Produces() map[string]any {
205+
return map[string]any{
206+
PrefixCacheMatchKey: &SchedulingContextState{},
207+
}
208+
}
209+
210+
func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) {
200211
// Pre-score step: hash the prompt and find the longest prefix match.
201212
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
202213
state := &SchedulingContextState{
203214
PrefixHashes: hashes,
204215
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
205216
}
206217

207-
cycleState.Write(plugins.StateKey(p.TypedName().String()), state)
218+
// TODO: Instead store this in the pods attribute map to avoid global state in the plugin.
208219
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state)
220+
}
221+
222+
// Score returns the scoring result for the given list of pods based on context.
223+
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
224+
// TODO(rahulgurnani): Remove duplication with PrepareRequestData after testing.
225+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
226+
if err != nil {
227+
// This should not happen, but in case it does, we recalculate the state.
228+
// In unit tests, this doesn't happen as PrepareRequestData is always called before Score.
229+
// TODO: When the prefix plugin is split into separate score plugin and pre-request plugin,
230+
// remove this recalculation.
231+
log.FromContext(ctx).Error(err, "failed to read prefix plugin state, recalculating")
232+
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
233+
state = &SchedulingContextState{
234+
PrefixHashes: hashes,
235+
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
236+
}
237+
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state)
238+
}
239+
// TODO(rahulgurnani): clean up the cycleState after all the changes are done. It seems llm-d-scheduler relies on cycleState presently.
240+
cycleState.Write(plugins.StateKey(p.TypedName().String()), state)
241+
209242
log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "cached-servers", state.PrefixCacheServers, "hashes", state.PrefixHashes)
210243
// calculate the scores of pods
211244
scores := make(map[types.Pod]float64, len(pods))

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
5555
},
5656
},
5757
}
58+
plugin.PrepareRequestData(context.Background(), req1, pods)
5859
scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
5960
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
6061
assert.NoError(t, err)
@@ -87,6 +88,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
8788
},
8889
},
8990
}
91+
plugin.PrepareRequestData(context.Background(), req2, pods)
9092
scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
9193
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
9294
assert.NoError(t, err)
@@ -118,6 +120,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
118120
},
119121
},
120122
}
123+
plugin.PrepareRequestData(context.Background(), req3, pods)
121124
scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods)
122125
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String()))
123126
assert.NoError(t, err)
@@ -148,6 +151,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
148151
},
149152
},
150153
}
154+
plugin.PrepareRequestData(context.Background(), req4, pods)
151155
scores = plugin.Score(context.Background(), types.NewCycleState(), req4, pods)
152156
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, plugins.StateKey(plugin.TypedName().String()))
153157
assert.NoError(t, err)
@@ -178,6 +182,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
178182
},
179183
},
180184
}
185+
plugin.PrepareRequestData(context.Background(), req5, pods)
181186
scores = plugin.Score(context.Background(), types.NewCycleState(), req5, pods)
182187
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, plugins.StateKey(plugin.TypedName().String()))
183188
assert.NoError(t, err)
@@ -223,6 +228,7 @@ func TestPrefixPluginChatCompletions(t *testing.T) {
223228
},
224229
},
225230
}
231+
plugin.PrepareRequestData(context.Background(), req1, pods)
226232
scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
227233
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
228234
assert.NoError(t, err)
@@ -258,6 +264,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
258264
},
259265
},
260266
}
267+
plugin.PrepareRequestData(context.Background(), req1, pods)
261268
scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
262269
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
263270
assert.NoError(t, err)
@@ -293,6 +300,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
293300
},
294301
},
295302
}
303+
plugin.PrepareRequestData(context.Background(), req2, pods)
296304
scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
297305
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
298306
assert.NoError(t, err)
@@ -328,6 +336,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
328336
},
329337
},
330338
}
339+
plugin.PrepareRequestData(context.Background(), req3, pods)
331340
scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods)
332341
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String()))
333342
assert.NoError(t, err)
@@ -387,6 +396,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
387396
}
388397

389398
b.ResetTimer()
399+
plugin.PrepareRequestData(context.Background(), req, pods)
390400
// Benchmark the scoring operation
391401
scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
392402
_ = scores // Use the result to prevent optimization
@@ -468,8 +478,9 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
468478
}
469479

470480
b.ResetTimer()
471-
for i := 0; i < b.N; i++ {
481+
for b.Loop() {
472482
// Benchmark the scoring operation
483+
plugin.PrepareRequestData(context.Background(), req, pods)
473484
scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
474485
_ = scores // Use the result to prevent optimization
475486

0 commit comments

Comments
 (0)