Commit d486d29

Add response to prefix cache in nonStreaming mode.

1 parent 646e952 commit d486d29

7 files changed: +717 −26 lines
pkg/epp/handlers/server.go

Lines changed: 1 addition & 0 deletions
@@ -115,6 +115,7 @@ type Request struct {
 }
 type Response struct {
 	Headers map[string]string
+	Body    []byte
 }
 type StreamRequestState int
 

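The new Body field gives the response-complete path access to the raw response bytes, not just the headers. The code that fills it lives in one of the seven changed files not shown in this excerpt; as a rough self-contained sketch (stand-in types, hypothetical chunking, not the committed wiring), the intended buffering in non-streaming mode looks like:

package main

import "fmt"

// Stand-in for handlers.Response; Body is the field added in this commit.
type Response struct {
	Headers map[string]string
	Body    []byte
}

func main() {
	// In non-streaming mode the proxy may still deliver the body in several
	// chunks; appending each one yields a single complete JSON document.
	resp := &Response{Headers: map[string]string{}}
	chunks := [][]byte{
		[]byte(`{"choices":[{"message":{"content":`),
		[]byte(`"Hello!"}}]}`),
	}
	for _, chunk := range chunks {
		resp.Body = append(resp.Body, chunk...)
	}
	fmt.Println(string(resp.Body)) // ready to be parsed as an LLMResponse
}

Only once the last chunk has arrived is Body a complete JSON document that the director below can parse.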
pkg/epp/requestcontrol/director.go

Lines changed: 10 additions & 3 deletions
@@ -306,13 +306,20 @@ func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *hand
 
 // HandleResponseBodyComplete is called when the response body is fully received.
 func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
-	logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
+	requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey]
+	logger := log.FromContext(ctx).WithValues("stage", "bodyChunk", requtil.RequestIdHeaderKey, requestID)
 	logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
+	llmResponse, err := schedulingtypes.NewLLMResponseFromBytes(reqCtx.Response.Body)
+	if err != nil {
+		logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.")
+		return reqCtx, err
+	}
 	response := &Response{
-		RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
+		RequestId: requestID,
 		Headers:   reqCtx.Response.Headers,
+		// Currently use the first choice as the response body to process.
+		Body: llmResponse.GetFirstChoiceContent(),
 	}
-
 	d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
 
 	logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")

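The helpers used above, schedulingtypes.NewLLMResponseFromBytes and GetFirstChoiceContent, are defined in one of the changed files not shown in this excerpt. Judging from the OpenAI-style fixture in director_test.go below (choices[0].message.content is "Hello!", and that string is what the plugin receives as the body), a plausible shape is the following sketch; all type and field names here are inferred, not the committed code:

// Sketch only: inferred from the test fixture, not the committed implementation.
package schedulingtypes

import (
	"encoding/json"
	"errors"
)

// LLMResponse mirrors the subset of an OpenAI-style chat completion
// that HandleResponseBodyComplete needs.
type LLMResponse struct {
	Choices []Choice `json:"choices"`
}

type Choice struct {
	Message      Message `json:"message"`
	FinishReason string  `json:"finish_reason"`
}

type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// NewLLMResponseFromBytes unmarshals a fully buffered (non-streaming) response body.
func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) {
	var resp LLMResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	if len(resp.Choices) == 0 {
		return nil, errors.New("response has no choices")
	}
	return &resp, nil
}

// GetFirstChoiceContent returns the assistant text of the first choice,
// which the director forwards to ResponseComplete plugins as Response.Body.
func (r *LLMResponse) GetFirstChoiceContent() string {
	return r.Choices[0].Message.Content
}

With that shape, GetFirstChoiceContent on the test fixture returns "Hello!", matching the cmp.Diff assertion added to the test below.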
pkg/epp/requestcontrol/director_test.go

Lines changed: 22 additions & 1 deletion
@@ -675,6 +675,23 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
 	mockSched := &mockScheduler{}
 	director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseCompletePlugins(pc1))
 
+	chatCompletionJSON := `{
+		"choices": [
+			{
+				"message": {
+					"role": "assistant",
+					"content": "Hello!"
+				},
+				"finish_reason": "stop"
+			}
+		],
+		"usage": {
+			"prompt_tokens": 1,
+			"completion_tokens": 2,
+			"total_tokens": 3
+		}
+	}`
+
 	reqCtx := &handlers.RequestContext{
 		Request: &handlers.Request{
 			Headers: map[string]string{
@@ -683,6 +700,7 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
 		},
 		Response: &handlers.Response{
 			Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"},
+			Body:    []byte(chatCompletionJSON),
 		},
 		TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
 	}
@@ -696,11 +714,14 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
 		t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff)
 	}
 	if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" {
-		t.Errorf("Scheduler.OnComplete Headers mismatch (-want +got):\n%s", diff)
+		t.Errorf("Scheduler.OnComplete response headers mismatch (-want +got):\n%s", diff)
 	}
 	if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" {
 		t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff)
 	}
+	if diff := cmp.Diff("Hello!", pc1.lastRespOnComplete.Body); diff != "" {
+		t.Errorf("Scheduler.OnComplete response body mismatch (-want +got):\n%s", diff)
+	}
 }
 
 const (

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 66 additions & 22 deletions
@@ -17,6 +17,7 @@ limitations under the License.
 package prefix
 
 import (
+	"bytes"
 	"context"
 	"encoding/binary"
 	"encoding/json"
@@ -28,6 +29,7 @@ import (
 	k8stypes "k8s.io/apimachinery/pkg/types"
 	"sigs.k8s.io/controller-runtime/pkg/log"
 
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
@@ -114,6 +116,10 @@ var _ plugins.StateData = &SchedulingContextState{}
 type SchedulingContextState struct {
 	// PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
 	PrefixHashes []BlockHash
+	// RestBytes holds the trailing prompt bytes that were too short to fill a full block.
+	// If not empty, they are prepended to the response bytes when the completed response
+	// is hashed, which matters especially in multi-turn scenarios.
+	RestBytes []byte
 	// A map of server to its longest prefix cache match length.
 	PrefixCacheServers map[ServerID]int
 }
@@ -190,9 +196,10 @@ func (p *Plugin) WithName(name string) *Plugin {
 // Score returns the scoring result for the given list of pods based on context.
 func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
 	// pre score step, hashing prompt and find longest prefix match.
-	hashes := hashPrompt(ctx, request, p.config.BlockSize, p.config.MaxPrefixBlocksToMatch)
+	hashes, restBytes := hashPrompt(ctx, request, p.config.BlockSize, p.config.MaxPrefixBlocksToMatch)
 	state := &SchedulingContextState{
 		PrefixHashes:       hashes,
+		RestBytes:          restBytes,
 		PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
 	}
 
@@ -223,7 +230,6 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
 	targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
 
 	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
-	p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
 	if err != nil {
 		log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
 		return
@@ -296,47 +302,58 @@ func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle)
 // hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
 // hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache.
 // For block i, hash(i) = hash(block i content, hash(i-1)).
-func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
+// It also returns the leftover bytes that did not fill a full block.
+func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
 	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
 	if request == nil || request.Body == nil {
 		loggerDebug.Info("Request or request data is nil, skipping hashing")
-		return nil
+		return nil, nil
 	}
-
 	userInput, err := getUserInputBytes(request)
 	if err != nil {
 		loggerDebug.Error(err, "Failed to get user input bytes")
-		return nil
+		return nil, nil
 	}
+	prevBlockHash := defaultPrevBlock(request)
+	return hashInputWithPrevBlockHash(ctx, prevBlockHash, 0, userInput, cacheBlockSize, maxPrefixBlocks)
+}
 
-	if len(userInput) < cacheBlockSize {
-		loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
-		return nil
-	}
-	if len(userInput) > cacheBlockSize*maxPrefixBlocks {
-		loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
-		userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
-	}
-	// Split the body into blocks of size cacheBlockSize.
-	// If the last block is smaller than cacheBlockSize, it will be ignored.
-	res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
-	// Add the model to the first block hash so that different models have different hashes even with the same body.
+func defaultPrevBlock(request *types.LLMRequest) BlockHash {
 	h := xxhash.New()
+	// Add the model to the first block hash so that different models have different hashes even with the same body.
 	_, _ = h.Write([]byte(request.TargetModel))
 	if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" {
 		_, _ = h.Write([]byte(cacheSalt))
 	}
 
-	prevBlockHash := BlockHash(h.Sum64())
-	for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
+	return BlockHash(h.Sum64())
+}
+
+func hashInputWithPrevBlockHash(ctx context.Context, prevBlockHash BlockHash, prevBlockLength int, input []byte, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
+	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
+	if len(input)+prevBlockLength < cacheBlockSize {
+		loggerDebug.Info("Request body too small for prefix cache", "size", len(input), "block size", cacheBlockSize)
+		return nil, input
+	}
+	if len(input)+prevBlockLength > cacheBlockSize*maxPrefixBlocks {
+		loggerDebug.Info("Truncating input", "size", len(input), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
+		input = input[:(maxPrefixBlocks*cacheBlockSize - prevBlockLength)]
+	}
+	// Split the body into blocks of size cacheBlockSize.
+	// If the last block is smaller than cacheBlockSize, it will be ignored.
+	res := make([]BlockHash, 0, len(input)/cacheBlockSize)
+	lastOffSet := 0
+	h := xxhash.New()
+	for i := 0; i+cacheBlockSize <= len(input); i += cacheBlockSize {
 		h.Reset()
-		_, _ = h.Write(userInput[i : i+cacheBlockSize])
+		_, _ = h.Write(input[i : i+cacheBlockSize])
 		_, _ = h.Write(toBytes(prevBlockHash))
 		res = append(res, BlockHash(h.Sum64()))
 
 		prevBlockHash = res[len(res)-1]
+		lastOffSet = i + cacheBlockSize
 	}
-	return res
+	return res, input[lastOffSet:]
 }
 
 func toBytes(i BlockHash) []byte {
@@ -353,3 +370,30 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
 	// must be chat-completions request at this point, return bytes of entire messages
 	return json.Marshal(request.Body.ChatCompletions.Messages)
 }
+
+func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) {
+	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
+	if err != nil {
+		log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
+		return
+	}
+	p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it.
+	var input bytes.Buffer
+	input.Write(state.RestBytes)
+	input.Write([]byte(response.Body))
+
+	server := ServerID(targetPod.NamespacedName)
+	prevBlockHash := defaultPrevBlock(request)
+	prevBlockHashLength := 0
+	if len(state.PrefixHashes) > 0 {
+		prevBlockHash = state.PrefixHashes[len(state.PrefixHashes)-1]
+		prevBlockHashLength = len(state.PrefixHashes)
+	}
+	inputBytes := input.Bytes()
+	hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, inputBytes, p.config.BlockSize, p.config.MaxPrefixBlocksToMatch)
+	p.wg.Add(1)
+	go func() {
+		p.indexer.Add(hashBlocks, server)
+		p.wg.Done()
+	}()
+}

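Two things are worth tracing through this diff. First, hashing is now split so the response path can continue the chain: hashPrompt seeds the chain with the model name (defaultPrevBlock) and returns the leftover tail, and ResponseComplete later hashes leftover + response starting from the last prompt-block hash, then indexes the new blocks for the serving pod. Second, the prefix-state deletion moved from PreRequest to ResponseComplete, since the state now has to survive until the response finishes. Below is a standalone sketch of the chained hashing using the same numbers as the test that follows (BlockSize 4, prompt "aaaaaa", response "bb"); the seed and the byte encoding of the previous hash are illustrative, not the plugin's exact ones:

// A minimal sketch of chained block hashing with leftover carry-over.
package main

import (
	"encoding/binary"
	"fmt"

	"github.com/cespare/xxhash/v2"
)

type BlockHash uint64

// toBytes encodes the previous hash for chaining; endianness is illustrative.
func toBytes(h BlockHash) []byte {
	b := make([]byte, 8)
	binary.LittleEndian.PutUint64(b, uint64(h))
	return b
}

// hashBlocks chains hash(i) = hash(block_i || hash(i-1)) and returns the
// hashes plus the tail bytes that did not fill a full block.
func hashBlocks(prev BlockHash, input []byte, blockSize int) ([]BlockHash, []byte) {
	var res []BlockHash
	last := 0
	h := xxhash.New()
	for i := 0; i+blockSize <= len(input); i += blockSize {
		h.Reset()
		_, _ = h.Write(input[i : i+blockSize])
		_, _ = h.Write(toBytes(prev))
		prev = BlockHash(h.Sum64())
		res = append(res, prev)
		last = i + blockSize
	}
	return res, input[last:]
}

func main() {
	// Stands in for defaultPrevBlock (model name folded into the seed).
	seed := BlockHash(xxhash.Sum64String("test-model1"))

	// Request time: prompt "aaaaaa" -> one full block "aaaa", leftover "aa".
	promptHashes, rest := hashBlocks(seed, []byte("aaaaaa"), 4)
	fmt.Printf("prompt blocks=%d leftover=%q\n", len(promptHashes), rest) // 1, "aa"

	// Response time: leftover + response "bb" -> one block "aabb",
	// chained off the last prompt-block hash.
	respHashes, _ := hashBlocks(promptHashes[0], append(rest, []byte("bb")...), 4)
	fmt.Printf("response blocks=%d\n", len(respHashes)) // 1
}

Because the "aabb" block is chained off the "aaaa" block's hash, a follow-up prompt "aaaaaabbcc" re-derives exactly those two hashes, which is why pod1 scores 1.0 in the second half of the test below.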
pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 84 additions & 0 deletions
@@ -30,6 +30,7 @@ import (
 
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 )
 
@@ -199,6 +200,89 @@ func TestPrefixPluginCompletion(t *testing.T) {
 	plugin.wg.Wait()
 }
 
+func TestPrefixPluginCompletionWithResponse(t *testing.T) {
+	config := Config{
+		BlockSize:              4,
+		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
+		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
+	}
+	plugin := New(context.Background(), config)
+
+	pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
+	pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
+	pods := []types.Pod{pod1, pod2}
+
+	// -- First Request --
+	// This initial request will populate the cache.
+	req1 := &types.LLMRequest{
+		RequestId:   uuid.NewString(),
+		TargetModel: "test-model1",
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				Prompt: "aaaaaa",
+			},
+		},
+	}
+	scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
+	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
+	assert.NoError(t, err)
+	t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
+	// Input size is 6, hash block size is 4, so the last 2 characters are ignored.
+	// Total hashes = 1 (for the "aaaa" block; the model prefix is folded into that hash).
+	assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect")
+	assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers yet")
+	assert.Equal(t, float64(0), scores[pod1], "score for pod1 should be 0 on first request")
+	assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0 on first request")
+
+	// Simulate that the scheduler picked pod1 for the first request.
+	schedulingResult := &types.SchedulingResult{
+		PrimaryProfileName: "default",
+		ProfileResults: map[string]*types.ProfileRunResult{
+			"default": {TargetPods: []types.Pod{pod1}},
+		},
+	}
+	plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
+	plugin.wg.Wait()
+
+	// -- Simulate Response Completion --
+	// The ResponseComplete hook is called. The plugin should update pod1's KV cache
+	// with the full context of the completed interaction (prompt + response).
+	// - Initial Prompt: "aaaaaa"
+	// - Response Body: "bb"
+	// - Cached Sequence: "aaaaaabb" (length 8)
+	// This sequence creates two 4-character blocks to be cached: "aaaa" and "aabb".
+	plugin.ResponseComplete(context.Background(), req1, &requestcontrol.Response{Body: "bb"}, pod1.GetPod())
+	plugin.wg.Wait()
+
+	// -- Second Request: Multi-turn Follow-up --
+	// This request simulates a follow-up message in a chat. The prompt contains the
+	// entire conversation history ("aaaaaabb") plus new text ("cc").
+	// The plugin should find that the first two blocks ("aaaa", "aabb") of this new
+	// prompt are already cached on pod1, giving it a perfect match score of 1.0.
+	// Pod2 has no matching cache entries and should score 0.
+	req2 := &types.LLMRequest{
+		RequestId:   uuid.NewString(),
+		TargetModel: "test-model1",
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				Prompt: "aaaaaabbcc",
+			},
+		},
+	}
+	scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
+	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
+	assert.NoError(t, err)
+	t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
+	// Input size is 10, hash block size is 4. The prompt "aaaaaabb" generates 2 hashes.
+	// The last 2 characters ("cc") are ignored.
+	assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
+	// It should find a server (pod1) that has cached the prefixes.
+	assert.Equal(t, 1, len(state.PrefixCacheServers), "a cached server should have been found")
+	// The score for pod1 should be 1.0 because both prompt blocks ("aaaa" and "aabb") were found in its cache.
+	assert.Equal(t, float64(1), scores[pod1], "score for pod1 should be a perfect match")
+	assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0")
+}
+
 func TestPrefixPluginChatCompletions(t *testing.T) {
 	config := Config{
 		BlockSize: 4,