From 63d9a7150ef55f780fea6f59d8a72a8200466a07 Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Fri, 3 Oct 2025 23:22:30 +0000 Subject: [PATCH 1/5] Fix function comment and pass existing logger into HandleResponseBodyStreaming --- pkg/epp/handlers/response.go | 2 +- pkg/epp/handlers/response_test.go | 3 ++- pkg/epp/handlers/server.go | 2 +- pkg/epp/requestcontrol/director.go | 4 ++-- pkg/epp/requestcontrol/director_test.go | 4 +++- pkg/epp/server/server_test.go | 3 ++- 6 files changed, 11 insertions(+), 7 deletions(-) diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index 1cbacbae3..5760cbfc6 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -68,7 +68,7 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques // The function is to handle streaming response if the modelServer is streaming. func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) { logger := log.FromContext(ctx) - _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx) + _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx, logger) if err != nil { logger.Error(err, "error in HandleResponseBodyStreaming") } diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go index 63b2de0da..290161167 100644 --- a/pkg/epp/handlers/response_test.go +++ b/pkg/epp/handlers/response_test.go @@ -21,6 +21,7 @@ import ( "encoding/json" "testing" + "github.com/go-logr/logr" "github.com/google/go-cmp/cmp" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" @@ -62,7 +63,7 @@ data: [DONE] type mockDirector struct{} -func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { +func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext, logger logr.Logger) (*RequestContext, error) { return reqCtx, nil } func (m *mockDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 0d5305574..ccc6e5b9a 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -55,7 +55,7 @@ func NewStreamingServer(datastore Datastore, director Director) *StreamingServer type Director interface { HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) - HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) + HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext, logger logr.Logger) (*RequestContext, error) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) GetRandomPod() *backend.Pod } diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index f6f7deebe..e19c1fa0f 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -26,6 +26,7 @@ import ( "strings" "time" + "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" @@ -265,8 +266,7 @@ func (d *Director) HandleResponseReceived(ctx context.Context, reqCtx *handlers. } // HandleResponseBodyStreaming is called every time a chunk of the response body is received. 
-func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { - logger := log.FromContext(ctx).WithValues("stage", "bodyChunk") +func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext, logger logr.Logger) (*handlers.RequestContext, error) { logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk") response := &Response{ RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index 8cb9c91a5..1cf26971a 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -32,6 +32,7 @@ import ( "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/log" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" @@ -659,6 +660,7 @@ func TestDirector_HandleResponseStreaming(t *testing.T) { ds := datastore.NewDatastore(t.Context(), nil, 0) mockSched := &mockScheduler{} director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseStreamingPlugins(ps1)) + logger := log.FromContext(ctx) reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ @@ -672,7 +674,7 @@ func TestDirector_HandleResponseStreaming(t *testing.T) { TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, } - _, err := director.HandleResponseBodyStreaming(ctx, reqCtx) + _, err := director.HandleResponseBodyStreaming(ctx, reqCtx, logger) if err != nil { t.Fatalf("HandleResponseBodyStreaming() returned unexpected error: %v", err) } diff --git a/pkg/epp/server/server_test.go b/pkg/epp/server/server_test.go index 9032e2c9f..e42dba259 100644 --- a/pkg/epp/server/server_test.go +++ b/pkg/epp/server/server_test.go @@ -22,6 +22,7 @@ import ( "testing" pb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/go-logr/logr" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -184,7 +185,7 @@ func (ts *testDirector) HandleResponseReceived(ctx context.Context, reqCtx *hand return reqCtx, nil } -func (ts *testDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { +func (ts *testDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext, logger logr.Logger) (*handlers.RequestContext, error) { return reqCtx, nil } From 1a7793acd95a62e19e47ff0ad39328e18b5972de Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Wed, 15 Oct 2025 21:15:40 +0000 Subject: [PATCH 2/5] Revert logging parameter addition, keeping consistent with existing format for plugins --- pkg/epp/handlers/response.go | 2 +- pkg/epp/handlers/response_test.go | 3 +-- pkg/epp/handlers/server.go | 2 +- pkg/epp/requestcontrol/director.go | 4 ++-- pkg/epp/requestcontrol/director_test.go | 4 +--- pkg/epp/server/server_test.go | 3 +-- 6 files changed, 7 insertions(+), 11 deletions(-) diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index 5760cbfc6..1cbacbae3 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -68,7 +68,7 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques // The function is to handle streaming response if the modelServer is 
streaming. func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) { logger := log.FromContext(ctx) - _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx, logger) + _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx) if err != nil { logger.Error(err, "error in HandleResponseBodyStreaming") } diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go index 290161167..63b2de0da 100644 --- a/pkg/epp/handlers/response_test.go +++ b/pkg/epp/handlers/response_test.go @@ -21,7 +21,6 @@ import ( "encoding/json" "testing" - "github.com/go-logr/logr" "github.com/google/go-cmp/cmp" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" @@ -63,7 +62,7 @@ data: [DONE] type mockDirector struct{} -func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext, logger logr.Logger) (*RequestContext, error) { +func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { return reqCtx, nil } func (m *mockDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index ccc6e5b9a..0d5305574 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -55,7 +55,7 @@ func NewStreamingServer(datastore Datastore, director Director) *StreamingServer type Director interface { HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) - HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext, logger logr.Logger) (*RequestContext, error) + HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) GetRandomPod() *backend.Pod } diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index e19c1fa0f..f6f7deebe 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -26,7 +26,6 @@ import ( "strings" "time" - "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" @@ -266,7 +265,8 @@ func (d *Director) HandleResponseReceived(ctx context.Context, reqCtx *handlers. } // HandleResponseBodyStreaming is called every time a chunk of the response body is received. 
-func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext, logger logr.Logger) (*handlers.RequestContext, error) {
+func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
+	logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
 	logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk")
 	response := &Response{
 		RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go
index 1cf26971a..8cb9c91a5 100644
--- a/pkg/epp/requestcontrol/director_test.go
+++ b/pkg/epp/requestcontrol/director_test.go
@@ -32,7 +32,6 @@ import (
 	"k8s.io/apimachinery/pkg/types"
 	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
 	"sigs.k8s.io/controller-runtime/pkg/client/fake"
-	"sigs.k8s.io/controller-runtime/pkg/log"
 
 	v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
 	"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
@@ -660,7 +659,6 @@ func TestDirector_HandleResponseStreaming(t *testing.T) {
 	ds := datastore.NewDatastore(t.Context(), nil, 0)
 	mockSched := &mockScheduler{}
 	director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseStreamingPlugins(ps1))
-	logger := log.FromContext(ctx)
 
 	reqCtx := &handlers.RequestContext{
 		Request: &handlers.Request{
@@ -674,7 +672,7 @@ func TestDirector_HandleResponseStreaming(t *testing.T) {
 		TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
 	}
 
-	_, err := director.HandleResponseBodyStreaming(ctx, reqCtx, logger)
+	_, err := director.HandleResponseBodyStreaming(ctx, reqCtx)
 	if err != nil {
 		t.Fatalf("HandleResponseBodyStreaming() returned unexpected error: %v", err)
 	}
diff --git a/pkg/epp/server/server_test.go b/pkg/epp/server/server_test.go
index e42dba259..9032e2c9f 100644
--- a/pkg/epp/server/server_test.go
+++ b/pkg/epp/server/server_test.go
@@ -22,7 +22,6 @@ import (
 	"testing"
 
 	pb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
-	"github.com/go-logr/logr"
 	v1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 
@@ -185,7 +184,7 @@ func (ts *testDirector) HandleResponseReceived(ctx context.Context, reqCtx *hand
 	return reqCtx, nil
 }
 
-func (ts *testDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext, logger logr.Logger) (*handlers.RequestContext, error) {
+func (ts *testDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
 	return reqCtx, nil
 }

From 028974c565d0db745a1a474d2fa10564bbec3f0e Mon Sep 17 00:00:00 2001
From: bobzetian
Date: Tue, 14 Oct 2025 20:35:16 +0000
Subject: [PATCH 3/5] Add response to prefix cache in non-streaming mode.
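For reviewers, a minimal standalone sketch of the chained block hashing
this patch extends to responses. It is illustrative only, not the
plugin's exact code: the real logic lives in hashInputWithPrevBlockHash
below, which seeds the chain with a hash of the model name and cache
salt rather than 0, and it assumes the cespare/xxhash/v2 library the
plugin already uses.

package main

import (
	"encoding/binary"
	"fmt"

	"github.com/cespare/xxhash/v2"
)

// hashBlocks chains each full block's hash into the next, mirroring
// hash(i) = hash(block i content, hash(i-1)); trailing bytes that do
// not fill a block are returned so a later response can extend the
// chain from where the prompt left off.
func hashBlocks(input []byte, blockSize int, prev uint64) ([]uint64, []byte) {
	var hashes []uint64
	buf := make([]byte, 8)
	last := 0
	for i := 0; i+blockSize <= len(input); i += blockSize {
		h := xxhash.New()
		_, _ = h.Write(input[i : i+blockSize])
		binary.LittleEndian.PutUint64(buf, prev)
		_, _ = h.Write(buf)
		prev = h.Sum64()
		hashes = append(hashes, prev)
		last = i + blockSize
	}
	return hashes, input[last:]
}

func main() {
	// Prompt "aaaaaa" with block size 4 hashes one block ("aaaa") and
	// leaves "aa"; appending the response "bb" completes "aabb".
	hashes, rest := hashBlocks([]byte("aaaaaa"), 4, 0)
	fmt.Println(len(hashes), string(rest)) // 1 aa
	more, rest2 := hashBlocks(append(rest, 'b', 'b'), 4, hashes[len(hashes)-1])
	fmt.Println(len(more), len(rest2)) // 1 0
}

The leftover bytes are what the new SchedulingContextState.RestBytes
field stores, so ResponseComplete can resume hashing the completed
response exactly where the prompt's last full block ended.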
--- pkg/epp/handlers/server.go | 1 + pkg/epp/requestcontrol/director.go | 13 +- pkg/epp/requestcontrol/director_test.go | 23 +- .../framework/plugins/multi/prefix/plugin.go | 86 +++- .../plugins/multi/prefix/plugin_test.go | 84 ++++ pkg/epp/scheduling/types/llmresponse.go | 135 ++++++ pkg/epp/scheduling/types/llmresponse_test.go | 399 ++++++++++++++++++ 7 files changed, 717 insertions(+), 24 deletions(-) create mode 100644 pkg/epp/scheduling/types/llmresponse.go create mode 100644 pkg/epp/scheduling/types/llmresponse_test.go diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 0d5305574..6b6999dd0 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -115,6 +115,7 @@ type Request struct { } type Response struct { Headers map[string]string + Body []byte } type StreamRequestState int diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index f6f7deebe..3b5c8ac8d 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -280,13 +280,20 @@ func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *hand // HandleResponseBodyComplete is called when the response body is fully received. func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { - logger := log.FromContext(ctx).WithValues("stage", "bodyChunk") + requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey] + logger := log.FromContext(ctx).WithValues("stage", "bodyChunk", requtil.RequestIdHeaderKey, requestID) logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete") + llmResponse, err := schedulingtypes.NewLLMResponseFromBytes(reqCtx.Response.Body) + if err != nil { + logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.") + return reqCtx, err + } response := &Response{ - RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], + RequestId: requestID, Headers: reqCtx.Response.Headers, + // Currently use the first choice as the response body to process. + Body: llmResponse.GetFirstChoiceContent(), } - d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod) logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete") diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index 8cb9c91a5..cfc00ab3c 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -696,6 +696,23 @@ func TestDirector_HandleResponseComplete(t *testing.T) { mockSched := &mockScheduler{} director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseCompletePlugins(pc1)) + chatCompletionJSON := `{ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello!" 
+ }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3 + } + }` + reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ Headers: map[string]string{ @@ -704,6 +721,7 @@ func TestDirector_HandleResponseComplete(t *testing.T) { }, Response: &handlers.Response{ Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"}, + Body: []byte(chatCompletionJSON), }, TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, } @@ -717,11 +735,14 @@ func TestDirector_HandleResponseComplete(t *testing.T) { t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff) } if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" { - t.Errorf("Scheduler.OnComplete Headers mismatch (-want +got):\n%s", diff) + t.Errorf("Scheduler.OnComplete response headers mismatch (-want +got):\n%s", diff) } if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" { t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff("Hello!", pc1.lastRespOnComplete.Body); diff != "" { + t.Errorf("Scheduler.OnComplete response body mismatch (-want +got):\n%s", diff) + } } const ( diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index c58c16791..40e333699 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -17,6 +17,7 @@ limitations under the License. package prefix import ( + "bytes" "context" "encoding/binary" "encoding/json" @@ -28,6 +29,7 @@ import ( k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" @@ -117,6 +119,10 @@ var _ plugins.StateData = &SchedulingContextState{} type SchedulingContextState struct { // PrefixHashes is a list of prefix hashes of the request prompt broken into blocks. PrefixHashes []BlockHash + // RestBytes is the trailing bytes that not able to fill in a full block and left over. + // If not empty, this will be used as the starting block for the following response that will + // be added to the response as well. This happens especially at the multi-turn scenario. + RestBytes []byte // A map of server to its longest prefix cache match length. PrefixCacheServers map[ServerID]int } @@ -193,9 +199,10 @@ func (p *Plugin) WithName(name string) *Plugin { // Score returns the scoring result for the given list of pods based on context. func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { // pre score step, hashing prompt and find longest prefix match. 
-	hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
+	hashes, restBytes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
 	state := &SchedulingContextState{
 		PrefixHashes:       hashes,
+		RestBytes:          restBytes,
 		PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
 	}
 
@@ -301,47 +308,59 @@ func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle)
 // hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
 // hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache.
 // For block i, hash(i) = hash(block i content, hash(i-1)).
-func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
+// It also returns the trailing bytes that do not fill a complete block.
+func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
 	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
 	if request == nil || request.Body == nil {
 		loggerDebug.Info("Request or request data is nil, skipping hashing")
-		return nil
+		return nil, nil
 	}
 	userInput, err := getUserInputBytes(request)
 	if err != nil {
 		loggerDebug.Error(err, "Failed to get user input bytes")
-		return nil
+		return nil, nil
 	}
+	prevBlockHash := defaultPrevBlock(request)
+	return hashInputWithPrevBlockHash(ctx, prevBlockHash, 0, userInput, cacheBlockSize, maxPrefixBlocks)
+}
 
-	if len(userInput) < cacheBlockSize {
-		loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
-		return nil
-	}
-	if len(userInput) > cacheBlockSize*maxPrefixBlocks {
-		loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
-		userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
-	}
-	// Split the body into blocks of size cacheBlockSize.
-	// If the last block is smaller than cacheBlockSize, it will be ignored.
-	res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
-	// Add the model to the first block hash so that different models have different hashes even with the same body.
+func defaultPrevBlock(request *types.LLMRequest) BlockHash {
 	h := xxhash.New()
+	// Add the model to the first block hash so that different models have different hashes even with the same body.
 	_, _ = h.Write([]byte(request.TargetModel))
 	if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" {
 		_, _ = h.Write([]byte(cacheSalt))
 	}
-	prevBlockHash := BlockHash(h.Sum64())
-	for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
+	return BlockHash(h.Sum64())
+}
+
+func hashInputWithPrevBlockHash(ctx context.Context, prevBlockHash BlockHash, prevBlockLength int, input []byte, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
+	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
+	if len(input)+prevBlockLength < cacheBlockSize {
+		loggerDebug.Info("Request body too small for prefix cache", "size", len(input), "block size", cacheBlockSize)
+		return nil, input
+	}
+	if len(input)+prevBlockLength > cacheBlockSize*maxPrefixBlocks {
+		loggerDebug.Info("Truncating input", "size", len(input), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
+		input = input[:(maxPrefixBlocks*cacheBlockSize - prevBlockLength)]
+	}
+	// Split the body into blocks of size cacheBlockSize.
+ // If the last block is smaller than cacheBlockSize, it will be ignored. + res := make([]BlockHash, 0, len(input)/cacheBlockSize) + lastOffSet := 0 + h := xxhash.New() + for i := 0; i+cacheBlockSize <= len(input); i += cacheBlockSize { h.Reset() - _, _ = h.Write(userInput[i : i+cacheBlockSize]) + _, _ = h.Write(input[i : i+cacheBlockSize]) _, _ = h.Write(toBytes(prevBlockHash)) res = append(res, BlockHash(h.Sum64())) prevBlockHash = res[len(res)-1] + lastOffSet = i + cacheBlockSize } - return res + return res, input[lastOffSet:] } func toBytes(i BlockHash) []byte { @@ -359,6 +378,33 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) { return json.Marshal(request.Body.ChatCompletions.Messages) } +func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { + state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String())) + if err != nil { + log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId) + return + } + p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it. + var input bytes.Buffer + input.Write(state.RestBytes) + input.Write([]byte(response.Body)) + + server := ServerID(targetPod.NamespacedName) + prevBlockHash := defaultPrevBlock(request) + prevBlockHashLength := 0 + if len(state.PrefixHashes) > 0 { + prevBlockHash = state.PrefixHashes[len(state.PrefixHashes)-1] + prevBlockHashLength = len(state.PrefixHashes) + } + inputBytes := input.Bytes() + hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, inputBytes, p.config.DefaultBlockSize, p.config.MaxPrefixBlocksToMatch) + p.wg.Add(1) + go func() { + p.indexer.Add(hashBlocks, server) + p.wg.Done() + }() +} + func getBlockSize(pods []types.Pod, defaultBlockSize int) int { if len(pods) == 0 { return defaultBlockSize diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index 59a09db52..83371d65e 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -30,6 +30,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -199,6 +200,89 @@ func TestPrefixPluginCompletion(t *testing.T) { plugin.wg.Wait() } +func TestPrefixPluginCompletionWithResponse(t *testing.T) { + config := Config{ + DefaultBlockSize: 4, + MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, + LRUCapacityPerServer: DefaultLRUCapacityPerServer, + } + plugin := New(context.Background(), config) + + pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}} + pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}} + pods := []types.Pod{pod1, pod2} + + // -- First Request -- + // This initial request will populate the cache. 
+ req1 := &types.LLMRequest{ + RequestId: uuid.NewString(), + TargetModel: "test-model1", + Body: &types.LLMRequestBody{ + Completions: &types.CompletionsRequest{ + Prompt: "aaaaaa", + }, + }, + } + scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods) + state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String())) + assert.NoError(t, err) + t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) + // Input size is 6, hash block size is 4, so the last 2 characters are ignored. + // Total hashes = 1 (for the "aaaa" block) + 1 (for the model prefix). + assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers yet") + assert.Equal(t, float64(0), scores[pod1], "score for pod1 should be 0 on first request") + assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0 on first request") + + // Simulate that the scheduler picked pod1 for the first request. + schedulingResult := &types.SchedulingResult{ + PrimaryProfileName: "default", + ProfileResults: map[string]*types.ProfileRunResult{ + "default": {TargetPods: []types.Pod{pod1}}, + }, + } + plugin.PreRequest(context.Background(), req1, schedulingResult, 0) + plugin.wg.Wait() + + // -- Simulate Response Completion -- + // The ResponseComplete hook is called. The plugin should update pod1's KV cache + // with the full context of the completed interaction (prompt + response). + // - Initial Prompt: "aaaaaa" + // - Response Body: "bb" + // - Cached Sequence: "aaaaaabb" (length 8) + // This sequence creates two 4-character blocks to be cached: "aaaa" and "aabb". + plugin.ResponseComplete(context.Background(), req1, &requestcontrol.Response{Body: "bb"}, pod1.GetPod()) + plugin.wg.Wait() + + // -- Second Request: Multi-turn Follow-up -- + // This request simulates a follow-up message in a chat. The prompt contains the + // entire conversation history ("aaaaaabb") plus new text ("cc"). + // The plugin should find that the first two blocks ("aaaa", "aabb") of this new + // prompt are already cached on pod1, giving it a perfect match score of 1.0. + // Pod2 has no matching cache entries and should score 0. + req2 := &types.LLMRequest{ + RequestId: uuid.NewString(), + TargetModel: "test-model1", + Body: &types.LLMRequestBody{ + Completions: &types.CompletionsRequest{ + Prompt: "aaaaaabbcc", + }, + }, + } + scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods) + state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String())) + assert.NoError(t, err) + t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) + // Input size is 10, hash block size is 4. The prompt "aaaaaabb" generates 2 hashes. + // The last 2 characters ("cc") are ignored. + assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect") + // It should find a server (pod1) that has cached the prefixes. + assert.Equal(t, 1, len(state.PrefixCacheServers), "a cached server should have been found") + // The score for pod1 should be 1.0 because both prompt blocks ("aaaa" and "aabb") were found in its cache. 
+ assert.Equal(t, float64(1), scores[pod1], "score for pod1 should be a perfect match") + assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0") +} + func TestPrefixPluginChatCompletions(t *testing.T) { config := Config{ DefaultBlockSize: 4, diff --git a/pkg/epp/scheduling/types/llmresponse.go b/pkg/epp/scheduling/types/llmresponse.go new file mode 100644 index 000000000..1061f3ccc --- /dev/null +++ b/pkg/epp/scheduling/types/llmresponse.go @@ -0,0 +1,135 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "encoding/json" + "fmt" +) + +// LLMResponse is a structured representation of a parsed LLM response body. +// An LLMResponse must contain exactly one of ChatCompletion or LegacyCompletion. +type LLMResponse struct { + // ChatCompletion is the representation of the OpenAI /v1/chat/completions response body. + ChatCompletion *ChatCompletionResponse `json:"chat_completion,omitempty"` + // LegacyCompletion is the representation of the OpenAI /v1/completions response body. + LegacyCompletion *LegacyCompletionResponse `json:"legacy_completion,omitempty"` +} + +// GetFirstChoiceContent extracts the primary text content from the first choice +// in either a ChatCompletion or a LegacyCompletion response. +func (res *LLMResponse) GetFirstChoiceContent() string { + if res.ChatCompletion != nil && len(res.ChatCompletion.Choices) > 0 { + return res.ChatCompletion.Choices[0].Message.Content + } else if res.LegacyCompletion != nil && len(res.LegacyCompletion.Choices) > 0 { + return res.LegacyCompletion.Choices[0].Text + } + return "" +} + +// ChatCompletionResponse represents the full response body for the chat completions API. +type ChatCompletionResponse struct { + Choices []ChatChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +func (r *ChatCompletionResponse) String() string { + if r == nil { + return nilString + } + contentLen := 0 + if len(r.Choices) > 0 { + contentLen = len(r.Choices[0].Message.Content) + } + return fmt.Sprintf("{ContentLength: %d, Usage: %s}", contentLen, r.Usage) +} + +// ChatChoice represents a single choice in the chat completion response. +type ChatChoice struct { + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// ChatMessage represents the message object within a choice. +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// LegacyCompletionResponse represents the full response body for the legacy completions API. +type LegacyCompletionResponse struct { + Choices []LegacyChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +func (r *LegacyCompletionResponse) String() string { + if r == nil { + return nilString + } + textLen := 0 + if len(r.Choices) > 0 { + textLen = len(r.Choices[0].Text) + } + return fmt.Sprintf("{TextLength: %d, Usage: %v}", textLen, r.Usage) +} + +// LegacyChoice represents a single choice in the legacy completion response. 
+type LegacyChoice struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` +} + +// Usage represents the token usage data common to all response formats. +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +func (u *Usage) String() string { + if u == nil { + return nilString + } + return fmt.Sprintf("{Prompt: %d, Completion: %d, Total: %d}", u.PromptTokens, u.CompletionTokens, u.TotalTokens) +} + +// NewLLMResponseFromBytes initializes an LLMResponse by trying to parse the data +// as a chat completion and then as a legacy completion response. +func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) { + if len(body) == 0 { + return nil, fmt.Errorf("input bytes are empty") + } + + // Attempt to unmarshal as a ChatCompletionResponse first. + var chatResp ChatCompletionResponse + if err := json.Unmarshal(body, &chatResp); err == nil { + // Check if the role is set to distinguish ChatCompletion and LegacyCompletion. + if len(chatResp.Choices) > 0 && chatResp.Choices[0].Message.Role != "" { + return &LLMResponse{ChatCompletion: &chatResp}, nil + } + } + + // Try to unmarshal as a LegacyCompletionResponse. + var legacyResp LegacyCompletionResponse + if err := json.Unmarshal(body, &legacyResp); err == nil { + if len(legacyResp.Choices) > 0 { + return &LLMResponse{LegacyCompletion: &legacyResp}, nil + } + } + + return nil, fmt.Errorf("failed to unmarshal body into any known LLM response format") +} diff --git a/pkg/epp/scheduling/types/llmresponse_test.go b/pkg/epp/scheduling/types/llmresponse_test.go new file mode 100644 index 000000000..8904062a3 --- /dev/null +++ b/pkg/epp/scheduling/types/llmresponse_test.go @@ -0,0 +1,399 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestNewLLMResponseFromBytes(t *testing.T) { + chatCompletionJSON := `{ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello!" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3 + } + }` + + legacyCompletionJSON := `{ + "choices": [ + { + "text": "Hello there!", + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 4, + "completion_tokens": 5, + "total_tokens": 9 + } + }` + + chatCompletionEmptyChoicesJSON := `{ + "choices": [], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3 + } + }` + + legacyCompletionEmptyChoicesJSON := `{ + "choices": [], + "usage": { + "prompt_tokens": 4, + "completion_tokens": 5, + "total_tokens": 9 + } + }` + + chatCompletionEmptyUsageJSON := `{ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello!" 
+ }, + "finish_reason": "stop" + } + ] + }` + + legacyCompletionEmptyUsageJSON := `{ + "choices": [ + { + "text": "Hello there!", + "finish_reason": "stop" + } + ] + }` + + invalidJSON := `{"invalid": json}` + unstructuredJSON := `{"foo": "bar"}` + + testCases := []struct { + name string + input []byte + want *LLMResponse + wantError bool + }{ + { + name: "valid chat completion response", + input: []byte(chatCompletionJSON), + want: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{ + { + Message: ChatMessage{ + Role: "assistant", + Content: "Hello!", + }, + FinishReason: "stop", + }, + }, + Usage: &Usage{ + PromptTokens: 1, + CompletionTokens: 2, + TotalTokens: 3, + }, + }, + }, + wantError: false, + }, + { + name: "valid legacy completion response", + input: []byte(legacyCompletionJSON), + want: &LLMResponse{ + LegacyCompletion: &LegacyCompletionResponse{ + Choices: []LegacyChoice{ + { + Text: "Hello there!", + FinishReason: "stop", + }, + }, + Usage: &Usage{ + PromptTokens: 4, + CompletionTokens: 5, + TotalTokens: 9, + }, + }, + }, + wantError: false, + }, + { + name: "invalid json", + input: []byte(invalidJSON), + want: nil, + wantError: true, + }, + { + name: "empty input", + input: []byte{}, + want: nil, + wantError: true, + }, + { + name: "unstructured json", + input: []byte(unstructuredJSON), + want: nil, + wantError: true, + }, + { + name: "chat completion with empty choices", + input: []byte(chatCompletionEmptyChoicesJSON), + want: nil, + wantError: true, + }, + { + name: "legacy completion with empty choices", + input: []byte(legacyCompletionEmptyChoicesJSON), + want: nil, + wantError: true, + }, + { + name: "chat completion with empty usage", + input: []byte(chatCompletionEmptyUsageJSON), + want: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{ + { + Message: ChatMessage{ + Role: "assistant", + Content: "Hello!", + }, + FinishReason: "stop", + }, + }, + }, + }, + wantError: false, + }, + { + name: "legacy completion with empty usage", + input: []byte(legacyCompletionEmptyUsageJSON), + want: &LLMResponse{ + LegacyCompletion: &LegacyCompletionResponse{ + Choices: []LegacyChoice{ + { + Text: "Hello there!", + FinishReason: "stop", + }, + }, + }, + }, + wantError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := NewLLMResponseFromBytes(tc.input) + + if (err != nil) != tc.wantError { + t.Errorf("NewLLMResponseFromBytes() error = %v, wantError %v", err, tc.wantError) + return + } + + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("NewLLMResponseFromBytes() (-want +got): %v", diff) + } + }) + } +} + +func TestGetFirstChoiceContent(t *testing.T) { + testCases := []struct { + name string + res *LLMResponse + want string + }{ + { + name: "chatCompletion with choice", + res: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{ + {Message: ChatMessage{Content: "Hello from Chat"}}, + }, + }, + }, + want: "Hello from Chat", + }, + { + name: "legacyCompletion with choice", + res: &LLMResponse{ + LegacyCompletion: &LegacyCompletionResponse{ + Choices: []LegacyChoice{ + {Text: "Hello from Legacy"}, + }, + }, + }, + want: "Hello from Legacy", + }, + { + name: "chatCompletion with no choices", + res: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{}, + }, + }, + want: "", + }, + { + name: "legacyCompletion with no choices", + res: &LLMResponse{ + LegacyCompletion: &LegacyCompletionResponse{ + Choices: 
[]LegacyChoice{},
+				},
+			},
+			want: "",
+		},
+		{
+			name: "LLMResponse with all fields nil",
+			res: &LLMResponse{
+				ChatCompletion:   nil,
+				LegacyCompletion: nil,
+			},
+			want: "",
+		},
+		{
+			name: "Empty LLMResponse struct",
+			res:  &LLMResponse{},
+			want: "",
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			got := tc.res.GetFirstChoiceContent()
+			if diff := cmp.Diff(tc.want, got); diff != "" {
+				t.Errorf("GetFirstChoiceContent() mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestUsage_String(t *testing.T) {
+	var nilUsage *Usage
+	tests := []struct {
+		name string
+		u    *Usage
+		want string
+	}{
+		{
+			name: "nil usage",
+			u:    nilUsage,
+			want: nilString,
+		},
+		{
+			name: "non-nil usage",
+			u:    &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3},
+			want: "{Prompt: 1, Completion: 2, Total: 3}",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := tt.u.String(); got != tt.want {
+				t.Errorf("Usage.String() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestChatCompletionResponse_String(t *testing.T) {
+	var nilResp *ChatCompletionResponse
+	tests := []struct {
+		name string
+		r    *ChatCompletionResponse
+		want string
+	}{
+		{
+			name: "nil response",
+			r:    nilResp,
+			want: nilString,
+		},
+		{
+			name: "response with no choices",
+			r:    &ChatCompletionResponse{Choices: []ChatChoice{}, Usage: &Usage{}},
+			want: "{ContentLength: 0, Usage: {Prompt: 0, Completion: 0, Total: 0}}",
+		},
+		{
+			name: "response with choices",
+			r: &ChatCompletionResponse{
+				Choices: []ChatChoice{
+					{Message: ChatMessage{Content: "hello"}},
+				},
+				Usage: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3},
+			},
+			want: "{ContentLength: 5, Usage: {Prompt: 1, Completion: 2, Total: 3}}",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := tt.r.String(); got != tt.want {
+				t.Errorf("ChatCompletionResponse.String() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestLegacyCompletionResponse_String(t *testing.T) {
+	var nilResp *LegacyCompletionResponse
+	tests := []struct {
+		name string
+		r    *LegacyCompletionResponse
+		want string
+	}{
+		{
+			name: "nil response",
+			r:    nilResp,
+			want: nilString,
+		},
+		{
+			name: "response with no choices",
+			r:    &LegacyCompletionResponse{Choices: []LegacyChoice{}, Usage: &Usage{}},
+			want: "{TextLength: 0, Usage: {Prompt: 0, Completion: 0, Total: 0}}",
+		},
+		{
+			name: "response with choices",
+			r: &LegacyCompletionResponse{
+				Choices: []LegacyChoice{
+					{Text: "hello world"},
+				},
+				Usage: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3},
+			},
+			want: "{TextLength: 11, Usage: {Prompt: 1, Completion: 2, Total: 3}}",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := tt.r.String(); got != tt.want {
+				t.Errorf("LegacyCompletionResponse.String() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}

From 39ae663f041657c6661c397f41d454faf9999c43 Mon Sep 17 00:00:00 2001
From: bobzetian
Date: Wed, 15 Oct 2025 04:58:11 +0000
Subject: [PATCH 4/5] Make ResponseComplete accept LLMResponse and update the
 encoding method of Messages in ChatCompletions.
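A minimal usage sketch of the API after this change, for reviewers. It
assumes the types package as modified by this patch (the renamed
Completion types, NewLLMResponseFromBytes, and the new
FirstChoiceContent that returns bytes); error handling is collapsed to
panics for brevity.

package main

import (
	"fmt"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

func main() {
	// A completed (non-streaming) chat response body as received by the EPP.
	body := []byte(`{"choices":[{"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}]}`)
	resp, err := types.NewLLMResponseFromBytes(body)
	if err != nil {
		panic(err)
	}
	// Chat choices are re-encoded with MarshalMessagesToJSON so the bytes
	// hash the same way as the messages of a follow-up request; legacy
	// completions return the raw choice text instead.
	content, err := resp.FirstChoiceContent()
	if err != nil {
		panic(err)
	}
	fmt.Println(string(content))
}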
--- pkg/epp/handlers/server.go | 1 + pkg/epp/requestcontrol/director.go | 10 +-- pkg/epp/requestcontrol/director_test.go | 18 ++-- pkg/epp/requestcontrol/plugins.go | 2 +- .../framework/plugins/multi/prefix/plugin.go | 26 +++--- .../plugins/multi/prefix/plugin_test.go | 52 ++++++++++-- pkg/epp/scheduling/types/llmresponse.go | 47 +++++----- pkg/epp/scheduling/types/llmresponse_test.go | 85 +++++++++++-------- pkg/epp/scheduling/types/types.go | 25 +++++- pkg/epp/scheduling/types/types_test.go | 69 +++++++++++++++ 10 files changed, 240 insertions(+), 95 deletions(-) create mode 100644 pkg/epp/scheduling/types/types_test.go diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 6b6999dd0..b589d66ec 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -304,6 +304,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) break } + reqCtx.Response.Body = body reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody) if responseErr != nil { if logger.V(logutil.DEBUG).Enabled() { diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index 3b5c8ac8d..0c88a0d0f 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -288,13 +288,7 @@ func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handl logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.") return reqCtx, err } - response := &Response{ - RequestId: requestID, - Headers: reqCtx.Response.Headers, - // Currently use the first choice as the response body to process. - Body: llmResponse.GetFirstChoiceContent(), - } - d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod) + d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, llmResponse, reqCtx.TargetPod) logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete") return reqCtx, nil @@ -344,7 +338,7 @@ func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *sch } } -func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { +func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) for _, plugin := range d.requestControlPlugins.responseCompletePlugins { loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName()) diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index cfc00ab3c..e9110b96c 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -712,6 +712,10 @@ func TestDirector_HandleResponseComplete(t *testing.T) { "total_tokens": 3 } }` + wantLLMResponse, err := schedulingtypes.NewLLMResponseFromBytes([]byte(chatCompletionJSON)) + if err != nil { + t.Fatalf("NewLLMResponseFromBytes failed with error: %v", err) + } reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ @@ -726,21 +730,15 @@ func TestDirector_HandleResponseComplete(t *testing.T) { TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, } - _, err := director.HandleResponseBodyComplete(ctx, reqCtx) + _, err = director.HandleResponseBodyComplete(ctx, reqCtx) if err != nil { t.Fatalf("HandleResponseBodyComplete() returned unexpected 
error: %v", err)
 	}
 
-	if diff := cmp.Diff("test-req-id-for-complete", pc1.lastRespOnComplete.RequestId); diff != "" {
-		t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff)
-	}
-	if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" {
-		t.Errorf("Scheduler.OnComplete response headers mismatch (-want +got):\n%s", diff)
-	}
 	if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" {
 		t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff)
 	}
-	if diff := cmp.Diff("Hello!", pc1.lastRespOnComplete.Body); diff != "" {
+	if diff := cmp.Diff(wantLLMResponse, pc1.lastRespOnComplete); diff != "" {
 		t.Errorf("Scheduler.OnComplete response body mismatch (-want +got):\n%s", diff)
 	}
 }
@@ -765,7 +763,7 @@ type testResponseStreaming struct {
 
 type testResponseComplete struct {
 	tn                      plugins.TypedName
-	lastRespOnComplete      *Response
+	lastRespOnComplete      *schedulingtypes.LLMResponse
 	lastTargetPodOnComplete string
 }
@@ -809,7 +807,7 @@ func (p *testResponseStreaming) ResponseStreaming(_ context.Context, _ *scheduli
 	p.lastTargetPodOnStreaming = targetPod.NamespacedName.String()
 }
 
-func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
+func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) {
 	p.lastRespOnComplete = response
 	p.lastTargetPodOnComplete = targetPod.NamespacedName.String()
 }
diff --git a/pkg/epp/requestcontrol/plugins.go b/pkg/epp/requestcontrol/plugins.go
index 30f31f070..0018e18ce 100644
--- a/pkg/epp/requestcontrol/plugins.go
+++ b/pkg/epp/requestcontrol/plugins.go
@@ -55,5 +55,5 @@ type ResponseStreaming interface {
 // ResponseComplete is called by the director after the complete response is sent.
 type ResponseComplete interface {
 	plugins.Plugin
-	ResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
+	ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod)
 }
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
index 40e333699..04ec09ece 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
@@ -123,6 +123,8 @@ type SchedulingContextState struct {
 	// If not empty, they are prepended to the response before it is hashed, so the combined
 	// prompt and response can be cached. This matters especially in multi-turn scenarios.
 	RestBytes []byte
+	// BlockSize is the block size used to calculate the hash of the request/response.
+	BlockSize int
 	// A map of server to its longest prefix cache match length.
 	PrefixCacheServers map[ServerID]int
 }
@@ -198,11 +200,13 @@ func (p *Plugin) WithName(name string) *Plugin {
 
 // Score returns the scoring result for the given list of pods based on context.
 func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
+	blockSize := getBlockSize(pods, p.config.DefaultBlockSize)
 	// pre score step, hashing prompt and find longest prefix match.
- hashes, restBytes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch) + hashes, restBytes := hashPrompt(ctx, request, blockSize, p.config.MaxPrefixBlocksToMatch) state := &SchedulingContextState{ PrefixHashes: hashes, RestBytes: restBytes, + BlockSize: blockSize, PrefixCacheServers: p.matchLongestPrefix(ctx, hashes), } @@ -233,7 +237,6 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String())) - p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it if err != nil { log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId) return @@ -251,9 +254,7 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche total := len(state.PrefixHashes) matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)] - - blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config.DefaultBlockSize) - metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize) + metrics.RecordPrefixCacheMatch(matchLen*state.BlockSize, total*state.BlockSize) } // matchLongestPrefix returns a map of servers and length of prefix that each server caches. @@ -375,19 +376,25 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) { } // must be chat-completions request at this point, return bytes of entire messages - return json.Marshal(request.Body.ChatCompletions.Messages) + return types.MarshalMessagesToJSON(request.Body.ChatCompletions.Messages...) } -func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { +func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod) { state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String())) if err != nil { log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId) return } p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it. 
+
+	responseForKVCache, err := response.FirstChoiceContent()
+	if err != nil {
+		log.FromContext(ctx).Error(err, "failed to get first choice content", "requestID", request.RequestId)
+		return
+	}
 	var input bytes.Buffer
 	input.Write(state.RestBytes)
-	input.Write([]byte(response.Body))
+	input.Write(responseForKVCache)
 
 	server := ServerID(targetPod.NamespacedName)
 	prevBlockHash := defaultPrevBlock(request)
@@ -396,8 +403,7 @@
 		prevBlockHash = state.PrefixHashes[len(state.PrefixHashes)-1]
 		prevBlockHashLength = len(state.PrefixHashes)
 	}
-	inputBytes := input.Bytes()
-	hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, inputBytes, p.config.DefaultBlockSize, p.config.MaxPrefixBlocksToMatch)
+	hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, input.Bytes(), state.BlockSize, p.config.MaxPrefixBlocksToMatch)
 	p.wg.Add(1)
 	go func() {
 		p.indexer.Add(hashBlocks, server)
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
index 83371d65e..4532bf287 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
@@ -30,7 +30,6 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
-	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 )
 
@@ -201,8 +200,9 @@ func TestPrefixPluginCompletion(t *testing.T) {
 }
 
 func TestPrefixPluginCompletionWithResponse(t *testing.T) {
+	const defaultBlockSize = 4
 	config := Config{
-		DefaultBlockSize:       4,
+		DefaultBlockSize:       defaultBlockSize,
 		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
 		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
 	}
@@ -231,6 +231,9 @@
 	// Total hashes = 1 (for the "aaaa" block) + 1 (for the model prefix).
 	assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect")
 	assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers yet")
+	// The last 2 characters are recorded in restBytes of the state.
+	assert.Equal(t, 2, len(state.RestBytes), "number of restBytes is incorrect")
+	assert.Equal(t, defaultBlockSize, state.BlockSize, "blockSize is incorrect")
 	assert.Equal(t, float64(0), scores[pod1], "score for pod1 should be 0 on first request")
 	assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0 on first request")
 
@@ -241,7 +244,7 @@
 			"default": {TargetPods: []types.Pod{pod1}},
 		},
 	}
-	plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
+	plugin.PreRequest(context.Background(), req1, schedulingResult)
 	plugin.wg.Wait()
 
 	// -- Simulate Response Completion --
@@ -251,7 +254,16 @@
 	// - Response Body: "bb"
 	// - Cached Sequence: "aaaaaabb" (length 8)
 	// This sequence creates two 4-character blocks to be cached: "aaaa" and "aabb".
-	plugin.ResponseComplete(context.Background(), req1, &requestcontrol.Response{Body: "bb"}, pod1.GetPod())
+	resp1 := &types.LLMResponse{
+		Completion: &types.CompletionResponse{
+			Choices: []types.CompletionChoice{
+				{
+					Text: "bb",
+				},
+			},
+		},
+	}
+	plugin.ResponseComplete(context.Background(), req1, resp1, pod1.GetPod())
 	plugin.wg.Wait()
 
 	// -- Second Request: Multi-turn Follow-up --
@@ -278,6 +290,9 @@
 	assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
 	// It should find a server (pod1) that has cached the prefixes.
 	assert.Equal(t, 1, len(state.PrefixCacheServers), "a cached server should have been found")
+	// The last 2 characters ("cc") are recorded in restBytes of the state.
+	assert.Equal(t, 2, len(state.RestBytes), "number of restBytes is incorrect")
+	assert.Equal(t, defaultBlockSize, state.BlockSize, "blockSize is incorrect")
 	// The score for pod1 should be 1.0 because both prompt blocks ("aaaa" and "aabb") were found in its cache.
 	assert.Equal(t, float64(1), scores[pod1], "score for pod1 should be a perfect match")
 	assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0")
@@ -362,6 +377,19 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 	plugin.PreRequest(context.Background(), req1, schedulingResult)
 	plugin.wg.Wait()
 
+	resp1 := &types.LLMResponse{
+		ChatCompletion: &types.ChatCompletionResponse{
+			Choices: []types.ChatChoice{
+				{
+					Message: types.Message{Role: "assistant", Content: types.Content{Raw: "I'm doing well, thank you! How can I help you today?"}},
+				},
+			},
+		},
+	}
+	// Simulate resp1 being recorded into the KV cache.
+	plugin.ResponseComplete(context.Background(), req1, resp1, pod1.GetPod())
+	plugin.wg.Wait()
+
 	// Second request adds assistant response and new user message (conversation grows)
 	req2 := &types.LLMRequest{
 		RequestId:   uuid.NewString(),
@@ -389,12 +417,26 @@
 	cachedBlocks := state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
 	expectedScore := float64(cachedBlocks) / float64(extendedHashCount)
 	assert.Equal(t, expectedScore, scores[pod1], "pod1 should have prefix cache hit")
+	assert.Greater(t, scores[pod1], float64(0.5), "given the response is also prefix cached, the cache hit should be well above 0.5")
 	assert.Equal(t, float64(0), scores[pod2], "pod2 should have no cache hit")
 
 	// Simulate pod1 was picked again
 	plugin.PreRequest(context.Background(), req2, schedulingResult)
 	plugin.wg.Wait()
 
+	resp2 := &types.LLMResponse{
+		ChatCompletion: &types.ChatCompletionResponse{
+			Choices: []types.ChatChoice{
+				{
+					Message: types.Message{Role: "assistant", Content: types.Content{Raw: "Prefix caching is a technique where..."}},
+				},
+			},
+		},
+	}
+	// Simulate resp2 being recorded into the KV cache.
+ plugin.ResponseComplete(context.Background(), req2, resp2, pod1.GetPod()) + plugin.wg.Wait() + // Third request continues the conversation even further req3 := &types.LLMRequest{ RequestId: uuid.NewString(), @@ -424,7 +466,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { cachedBlocks = state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)] expectedScore = float64(cachedBlocks) / float64(longHashCount) assert.Equal(t, expectedScore, scores[pod1], "pod1 should have higher prefix cache hit") - assert.Greater(t, scores[pod1], float64(0.5), "cache hit rate should be substantial for growing conversation") + assert.Greater(t, scores[pod1], float64(0.8), "cache hit rate should be substantial for growing conversation") assert.Equal(t, float64(0), scores[pod2], "pod2 should still have no cache hit") } diff --git a/pkg/epp/scheduling/types/llmresponse.go b/pkg/epp/scheduling/types/llmresponse.go index 1061f3ccc..950c995c8 100644 --- a/pkg/epp/scheduling/types/llmresponse.go +++ b/pkg/epp/scheduling/types/llmresponse.go @@ -18,6 +18,7 @@ package types import ( "encoding/json" + "errors" "fmt" ) @@ -26,19 +27,19 @@ import ( type LLMResponse struct { // ChatCompletion is the representation of the OpenAI /v1/chat/completions response body. ChatCompletion *ChatCompletionResponse `json:"chat_completion,omitempty"` - // LegacyCompletion is the representation of the OpenAI /v1/completions response body. - LegacyCompletion *LegacyCompletionResponse `json:"legacy_completion,omitempty"` + // Completion is the representation of the OpenAI /v1/completions response body. + Completion *CompletionResponse `json:"legacy_completion,omitempty"` } -// GetFirstChoiceContent extracts the primary text content from the first choice -// in either a ChatCompletion or a LegacyCompletion response. -func (res *LLMResponse) GetFirstChoiceContent() string { +// FirstChoiceContent extracts the first choice of the response. +func (res *LLMResponse) FirstChoiceContent() ([]byte, error) { if res.ChatCompletion != nil && len(res.ChatCompletion.Choices) > 0 { - return res.ChatCompletion.Choices[0].Message.Content - } else if res.LegacyCompletion != nil && len(res.LegacyCompletion.Choices) > 0 { - return res.LegacyCompletion.Choices[0].Text + return MarshalMessagesToJSON(res.ChatCompletion.Choices[0].Message) } - return "" + if res.Completion != nil && len(res.Completion.Choices) > 0 { + return []byte(res.Completion.Choices[0].Text), nil + } + return nil, errors.New("no choices found in the LLM response") } // ChatCompletionResponse represents the full response body for the chat completions API. @@ -53,15 +54,15 @@ func (r *ChatCompletionResponse) String() string { } contentLen := 0 if len(r.Choices) > 0 { - contentLen = len(r.Choices[0].Message.Content) + contentLen = len(r.Choices[0].Message.Content.Raw) } return fmt.Sprintf("{ContentLength: %d, Usage: %s}", contentLen, r.Usage) } // ChatChoice represents a single choice in the chat completion response. type ChatChoice struct { - Message ChatMessage `json:"message"` - FinishReason string `json:"finish_reason"` + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` } // ChatMessage represents the message object within a choice. @@ -70,13 +71,13 @@ type ChatMessage struct { Content string `json:"content"` } -// LegacyCompletionResponse represents the full response body for the legacy completions API. 
-type LegacyCompletionResponse struct { - Choices []LegacyChoice `json:"choices"` - Usage *Usage `json:"usage,omitempty"` +// CompletionResponse represents the full response body for the legacy completions API. +type CompletionResponse struct { + Choices []CompletionChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` } -func (r *LegacyCompletionResponse) String() string { +func (r *CompletionResponse) String() string { if r == nil { return nilString } @@ -87,8 +88,8 @@ func (r *LegacyCompletionResponse) String() string { return fmt.Sprintf("{TextLength: %d, Usage: %v}", textLen, r.Usage) } -// LegacyChoice represents a single choice in the legacy completion response. -type LegacyChoice struct { +// CompletionChoice represents a single choice in the legacy completion response. +type CompletionChoice struct { Text string `json:"text"` FinishReason string `json:"finish_reason"` } @@ -111,7 +112,7 @@ func (u *Usage) String() string { // as a chat completion and then as a legacy completion response. func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) { if len(body) == 0 { - return nil, fmt.Errorf("input bytes are empty") + return nil, errors.New("input bytes are empty") } // Attempt to unmarshal as a ChatCompletionResponse first. @@ -124,12 +125,12 @@ func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) { } // Try to unmarshal as a LegacyCompletionResponse. - var legacyResp LegacyCompletionResponse + var legacyResp CompletionResponse if err := json.Unmarshal(body, &legacyResp); err == nil { if len(legacyResp.Choices) > 0 { - return &LLMResponse{LegacyCompletion: &legacyResp}, nil + return &LLMResponse{Completion: &legacyResp}, nil } } - return nil, fmt.Errorf("failed to unmarshal body into any known LLM response format") + return nil, errors.New("failed to unmarshal body into any known LLM response format") } diff --git a/pkg/epp/scheduling/types/llmresponse_test.go b/pkg/epp/scheduling/types/llmresponse_test.go index 8904062a3..4aba140bf 100644 --- a/pkg/epp/scheduling/types/llmresponse_test.go +++ b/pkg/epp/scheduling/types/llmresponse_test.go @@ -109,9 +109,11 @@ func TestNewLLMResponseFromBytes(t *testing.T) { ChatCompletion: &ChatCompletionResponse{ Choices: []ChatChoice{ { - Message: ChatMessage{ - Role: "assistant", - Content: "Hello!", + Message: Message{ + Role: "assistant", + Content: Content{ + Raw: "Hello!", + }, }, FinishReason: "stop", }, @@ -129,8 +131,8 @@ func TestNewLLMResponseFromBytes(t *testing.T) { name: "valid legacy completion response", input: []byte(legacyCompletionJSON), want: &LLMResponse{ - LegacyCompletion: &LegacyCompletionResponse{ - Choices: []LegacyChoice{ + Completion: &CompletionResponse{ + Choices: []CompletionChoice{ { Text: "Hello there!", FinishReason: "stop", @@ -182,9 +184,11 @@ func TestNewLLMResponseFromBytes(t *testing.T) { ChatCompletion: &ChatCompletionResponse{ Choices: []ChatChoice{ { - Message: ChatMessage{ - Role: "assistant", - Content: "Hello!", + Message: Message{ + Role: "assistant", + Content: Content{ + Raw: "Hello!", + }, }, FinishReason: "stop", }, @@ -197,8 +201,8 @@ func TestNewLLMResponseFromBytes(t *testing.T) { name: "legacy completion with empty usage", input: []byte(legacyCompletionEmptyUsageJSON), want: &LLMResponse{ - LegacyCompletion: &LegacyCompletionResponse{ - Choices: []LegacyChoice{ + Completion: &CompletionResponse{ + Choices: []CompletionChoice{ { Text: "Hello there!", FinishReason: "stop", @@ -226,33 +230,37 @@ func TestNewLLMResponseFromBytes(t *testing.T) { } } -func 
TestGetFirstChoiceContent(t *testing.T) { +func TestFirstChoiceContent(t *testing.T) { testCases := []struct { - name string - res *LLMResponse - want string + name string + res *LLMResponse + want []byte + wantError bool }{ { name: "chatCompletion with choice", res: &LLMResponse{ ChatCompletion: &ChatCompletionResponse{ Choices: []ChatChoice{ - {Message: ChatMessage{Content: "Hello from Chat"}}, + {Message: Message{Role: "assistant", + Content: Content{ + Raw: "Hello from Chat", + }}}, }, }, }, - want: "Hello from Chat", + want: []byte(`{"role":"assistant","content":"Hello from Chat"},`), }, { name: "legacyCompletion with choice", res: &LLMResponse{ - LegacyCompletion: &LegacyCompletionResponse{ - Choices: []LegacyChoice{ + Completion: &CompletionResponse{ + Choices: []CompletionChoice{ {Text: "Hello from Legacy"}, }, }, }, - want: "Hello from Legacy", + want: []byte(`Hello from Legacy`), }, { name: "chatCompletion with no choices", @@ -261,37 +269,40 @@ func TestGetFirstChoiceContent(t *testing.T) { Choices: []ChatChoice{}, }, }, - want: "", + wantError: true, }, { name: "legacyCompletion with no choices", res: &LLMResponse{ - LegacyCompletion: &LegacyCompletionResponse{ - Choices: []LegacyChoice{}, + Completion: &CompletionResponse{ + Choices: []CompletionChoice{}, }, }, - want: "", + wantError: true, }, { name: "LLMResponse with all fields nil", res: &LLMResponse{ - ChatCompletion: nil, - LegacyCompletion: nil, + ChatCompletion: nil, + Completion: nil, }, - want: "", + wantError: true, }, { - name: "Empty LLMResponse struct", - res: &LLMResponse{}, - want: "", + name: "Empty LLMResponse struct", + res: &LLMResponse{}, + wantError: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got := tc.res.GetFirstChoiceContent() + got, err := tc.res.FirstChoiceContent() + if tc.wantError != (err != nil) { + t.Errorf("FirstChoiceContent() wantError is %v, but got error: %v", tc.wantError, err) + } if diff := cmp.Diff(tc.want, got); diff != "" { - t.Errorf("GetFirstChoiceContent() mismatch (-want +got):\n%s", diff) + t.Errorf("FirstChoiceContent() mismatch (-want +got):\n%s", diff) } }) } @@ -345,7 +356,7 @@ func TestChatCompletionResponse_String(t *testing.T) { name: "response with choices", r: &ChatCompletionResponse{ Choices: []ChatChoice{ - {Message: ChatMessage{Content: "hello"}}, + {Message: Message{Content: Content{Raw: "hello"}}}, }, Usage: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3}, }, @@ -362,10 +373,10 @@ func TestChatCompletionResponse_String(t *testing.T) { } func TestLegacyCompletionResponse_String(t *testing.T) { - var nilResp *LegacyCompletionResponse + var nilResp *CompletionResponse tests := []struct { name string - r *LegacyCompletionResponse + r *CompletionResponse want string }{ { @@ -375,13 +386,13 @@ func TestLegacyCompletionResponse_String(t *testing.T) { }, { name: "response with no choices", - r: &LegacyCompletionResponse{Choices: []LegacyChoice{}, Usage: &Usage{}}, + r: &CompletionResponse{Choices: []CompletionChoice{}, Usage: &Usage{}}, want: "{TextLength: 0, Usage: {Prompt: 0, Completion: 0, Total: 0}}", }, { name: "response with choices", - r: &LegacyCompletionResponse{ - Choices: []LegacyChoice{ + r: &CompletionResponse{ + Choices: []CompletionChoice{ {Text: "hello world"}, }, Usage: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3}, diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 6f9bec8ad..05a25b2d5 100644 --- a/pkg/epp/scheduling/types/types.go +++ 
b/pkg/epp/scheduling/types/types.go @@ -17,6 +17,7 @@ limitations under the License. package types import ( + "bytes" "encoding/json" "errors" "fmt" @@ -26,7 +27,10 @@ import ( backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" ) -const nilString = "" +const ( + nilString = "" + messageSplit = "," +) // LLMRequest is a structured representation of the fields we parse out of the LLMRequest body. type LLMRequest struct { @@ -187,6 +191,25 @@ func (mc Content) PlainText() string { return sb.String() } +// MarshalMessagesToJSON converts a slice of Message structs into a JSON byte slice. +// This is used to create a consistent byte representation for prefix caching calculations, +// allowing us to identify common prefixes between LLM requests and responses. +func MarshalMessagesToJSON(messages ...Message) ([]byte, error) { + if len(messages) == 0 { + return []byte{}, nil + } + var buf bytes.Buffer + for _, msg := range messages { + jsonBytes, err := json.Marshal(msg) + if err != nil { + return []byte{}, err + } + buf.Write(jsonBytes) + buf.WriteString(messageSplit) + } + return buf.Bytes(), nil +} + type Pod interface { GetPod() *backend.Pod GetMetrics() *backendmetrics.MetricsState diff --git a/pkg/epp/scheduling/types/types_test.go b/pkg/epp/scheduling/types/types_test.go new file mode 100644 index 000000000..59713a32a --- /dev/null +++ b/pkg/epp/scheduling/types/types_test.go @@ -0,0 +1,69 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "bytes" + "testing" +) + +func TestMarshalMessagesToJSON(t *testing.T) { + testCases := []struct { + name string + messages []Message + want []byte + wantErr bool + }{ + { + name: "empty messages", + messages: []Message{}, + want: []byte{}, + wantErr: false, + }, + { + name: "single message", + messages: []Message{ + {Role: "user", Content: Content{Raw: "Hello"}}, + }, + want: []byte(`{"role":"user","content":"Hello"},`), + wantErr: false, + }, + { + name: "multiple messages", + messages: []Message{ + {Role: "user", Content: Content{Raw: "Hello"}}, + {Role: "assistant", Content: Content{Raw: "Hi there!"}}, + }, + want: []byte(`{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},`), + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := MarshalMessagesToJSON(tc.messages...) + if (err != nil) != tc.wantErr { + t.Errorf("MarshalMessagesToJSON() error = %v, wantErr %v", err, tc.wantErr) + return + } + + if !bytes.Equal(got, tc.want) { + t.Errorf("MarshalMessagesToJSON() got = %s, want %s", string(got), string(tc.want)) + } + }) + } +} From 368d54fe7d1aa00e3b0e88e263b0b44161fc6932 Mon Sep 17 00:00:00 2001 From: bobzetian Date: Wed, 15 Oct 2025 08:23:18 +0000 Subject: [PATCH 5/5] Add streaming response process. 
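
Streamed (SSE) response bodies are now buffered chunk by chunk and, once the
final chunk arrives, aggregated into a types.LLMResponse by the new
NewLLMResponseFromStream helper instead of being string-scanned for usage.
A minimal usage sketch, written as a hypothetical Example-style test assumed
to live in the types package alongside llmresponse_test.go (needs an "fmt"
import; the stream content below is illustrative only):

	func ExampleNewLLMResponseFromStream() {
		// Two chat chunks followed by the [DONE] sentinel, as emitted by vLLM
		// when "stream_options": {"include_usage": true} is set.
		stream := []byte("data: {\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"Hi\"}}]}\n\n" +
			"data: {\"choices\":[],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":1,\"total_tokens\":6}}\n\n" +
			"data: [DONE]\n")
		resp, err := NewLLMResponseFromStream(stream)
		if err != nil {
			panic(err) // malformed or typeless streams surface here
		}
		fmt.Println(resp.ChatCompletion.Choices[0].Message.Content.Raw)
		fmt.Println(resp.Usage())
		// Output:
		// Hi
		// {Prompt: 5, Completion: 1, Total: 6}
	}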
--- pkg/epp/handlers/response.go | 89 +++------ pkg/epp/handlers/response_test.go | 57 ++++-- pkg/epp/handlers/server.go | 71 ++++--- pkg/epp/requestcontrol/director.go | 8 +- pkg/epp/requestcontrol/director_test.go | 4 +- pkg/epp/scheduling/types/llmresponse.go | 184 ++++++++++++++++++- pkg/epp/scheduling/types/llmresponse_test.go | 110 +++++++++++ 7 files changed, 400 insertions(+), 123 deletions(-) diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index 1cbacbae3..cf7d1ff60 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -17,16 +17,15 @@ limitations under the License. package handlers import ( + "bytes" "context" - "encoding/json" - "fmt" - "strings" configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -36,23 +35,19 @@ const ( ) // HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling. -func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, response map[string]any) (*RequestContext, error) { +func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, body []byte) (*RequestContext, error) { logger := log.FromContext(ctx) - responseBytes, err := json.Marshal(response) + llmResponse, err := types.NewLLMResponseFromBytes(body) if err != nil { - return reqCtx, fmt.Errorf("error marshalling responseBody - %w", err) - } - if response["usage"] != nil { - usg := response["usage"].(map[string]any) - usage := Usage{ - PromptTokens: int(usg["prompt_tokens"].(float64)), - CompletionTokens: int(usg["completion_tokens"].(float64)), - TotalTokens: int(usg["total_tokens"].(float64)), + logger.Error(err, "failed to create LLMResponse from bytes") + } else { + reqCtx.SchedulingResponse = llmResponse + if usage := reqCtx.SchedulingResponse.Usage(); usage != nil { + reqCtx.Usage = usage + logger.V(logutil.VERBOSE).Info("Response generated", "usage", usage) } - reqCtx.Usage = usage - logger.V(logutil.VERBOSE).Info("Response generated", "usage", reqCtx.Usage) } - reqCtx.ResponseSize = len(responseBytes) + reqCtx.ResponseSize = len(body) // ResponseComplete is to indicate the response is complete. In non-streaming // case, it will be set to be true once the response is processed; in // streaming case, it will be set to be true once the last chunk is processed. @@ -60,25 +55,36 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques // will add the processing for streaming case. reqCtx.ResponseComplete = true - reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true) + reqCtx.respBodyResp = generateResponseBodyResponses(body, true) return s.director.HandleResponseBodyComplete(ctx, reqCtx) } // The function is to handle streaming response if the modelServer is streaming. 
-func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) { +func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, streamBody []byte) { logger := log.FromContext(ctx) _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx) if err != nil { logger.Error(err, "error in HandleResponseBodyStreaming") } - if strings.Contains(responseText, streamingEndMsg) { +} + +func (s *StreamingServer) HandleResponseBodyModelStreamingComplete(ctx context.Context, reqCtx *RequestContext, streamBody []byte) { + logger := log.FromContext(ctx) + if bytes.Contains(streamBody, []byte(streamingEndMsg)) { reqCtx.ResponseComplete = true - resp := parseRespForUsage(ctx, responseText) - reqCtx.Usage = resp.Usage - metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens) - metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens) - _, err := s.director.HandleResponseBodyComplete(ctx, reqCtx) + resp, err := types.NewLLMResponseFromStream(streamBody) + if err != nil { + logger.Error(err, "error in converting stream response to LLMResponse.") + } else { + reqCtx.SchedulingResponse = resp + if usage := resp.Usage(); usage != nil { + reqCtx.Usage = usage + metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.PromptTokens) + metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.CompletionTokens) + } + } + _, err = s.director.HandleResponseBodyComplete(ctx, reqCtx) if err != nil { logger.Error(err, "error in HandleResponseBodyComplete") } @@ -153,41 +159,6 @@ func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*con return headers } -// Example message if "stream_options": {"include_usage": "true"} is included in the request: -// data: {"id":"...","object":"text_completion","created":1739400043,"model":"food-review-0","choices":[], -// "usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}} -// -// data: [DONE] -// -// Noticed that vLLM returns two entries in one response. -// We need to strip the `data:` prefix and next Data: [DONE] from the message to fetch response data. -// -// If include_usage is not included in the request, `data: [DONE]` is returned separately, which -// indicates end of streaming. 
-func parseRespForUsage(ctx context.Context, responseText string) ResponseBody { - response := ResponseBody{} - logger := log.FromContext(ctx) - - lines := strings.Split(responseText, "\n") - for _, line := range lines { - if !strings.HasPrefix(line, streamingRespPrefix) { - continue - } - content := strings.TrimPrefix(line, streamingRespPrefix) - if content == "[DONE]" { - continue - } - - byteSlice := []byte(content) - if err := json.Unmarshal(byteSlice, &response); err != nil { - logger.Error(err, "unmarshaling response body") - continue - } - } - - return response -} - type ResponseBody struct { Usage Usage `json:"usage"` } diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go index 63b2de0da..46d7644b6 100644 --- a/pkg/epp/handlers/response_test.go +++ b/pkg/epp/handlers/response_test.go @@ -18,12 +18,12 @@ package handlers import ( "context" - "encoding/json" "testing" "github.com/google/go-cmp/cmp" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -52,12 +52,33 @@ const ( } ` - streamingBodyWithoutUsage = `data: {"id":"cmpl-41764c93-f9d2-4f31-be08-3ba04fa25394","object":"text_completion","created":1740002445,"model":"food-review-0","choices":[],"usage":null} - ` + streamingBodyWithoutUsage = ` + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]} - streamingBodyWithUsage = `data: {"id":"cmpl-41764c93-f9d2-4f31-be08-3ba04fa25394","object":"text_completion","created":1740002445,"model":"food-review-0","choices":[],"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}} -data: [DONE] - ` + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":null} + + data: [DONE] + ` + + streamingBodyWithUsage = ` + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}} + + data: [DONE] + ` ) type mockDirector struct{} @@ -88,13 +109,13 @@ func TestHandleResponseBody(t *testing.T) { name string body []byte reqCtx *RequestContext - want Usage + want *types.Usage wantErr bool }{ { name: "success", body: []byte(body), - want: Usage{ + want: &types.Usage{ PromptTokens: 11, TotalTokens: 111, CompletionTokens: 100, @@ -110,12 +131,7 @@ func TestHandleResponseBody(t *testing.T) { if reqCtx == nil { reqCtx = &RequestContext{} } - var responseMap map[string]any - marshalErr := json.Unmarshal(test.body, &responseMap) - if marshalErr != nil { - t.Error(marshalErr, "Error unmarshaling request body") - } - _, err := 
server.HandleResponseBody(ctx, reqCtx, responseMap) + _, err := server.HandleResponseBody(ctx, reqCtx, test.body) if err != nil { if !test.wantErr { t.Fatalf("HandleResponseBody returned unexpected error: %v, want %v", err, test.wantErr) @@ -136,7 +152,7 @@ func TestHandleStreamedResponseBody(t *testing.T) { name string body string reqCtx *RequestContext - want Usage + want *types.Usage wantErr bool }{ { @@ -155,10 +171,10 @@ func TestHandleStreamedResponseBody(t *testing.T) { modelServerStreaming: true, }, wantErr: false, - want: Usage{ - PromptTokens: 7, - TotalTokens: 17, - CompletionTokens: 10, + want: &types.Usage{ + PromptTokens: 5, + TotalTokens: 12, + CompletionTokens: 7, }, }, } @@ -171,7 +187,8 @@ func TestHandleStreamedResponseBody(t *testing.T) { if reqCtx == nil { reqCtx = &RequestContext{} } - server.HandleResponseBodyModelStreaming(ctx, reqCtx, test.body) + server.HandleResponseBodyModelStreaming(ctx, reqCtx, []byte(test.body)) + server.HandleResponseBodyModelStreamingComplete(ctx, reqCtx, []byte(test.body)) if diff := cmp.Diff(test.want, reqCtx.Usage); diff != "" { t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff) diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index b589d66ec..90fdb1c2a 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -85,14 +85,15 @@ type RequestContext struct { RequestReceivedTimestamp time.Time ResponseCompleteTimestamp time.Time RequestSize int - Usage Usage + Usage *schedulingtypes.Usage ResponseSize int ResponseComplete bool ResponseStatusCode string RequestRunning bool Request *Request - SchedulingRequest *schedulingtypes.LLMRequest + SchedulingRequest *schedulingtypes.LLMRequest + SchedulingResponse *schedulingtypes.LLMResponse RequestState StreamRequestState modelServerStreaming bool @@ -115,7 +116,6 @@ type Request struct { } type Response struct { Headers map[string]string - Body []byte } type StreamRequestState int @@ -268,13 +268,13 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) reqCtx.respHeaderResp = s.generateResponseHeaderResponse(reqCtx) case *extProcPb.ProcessingRequest_ResponseBody: + body = append(body, v.ResponseBody.Body...) if reqCtx.modelServerStreaming { // Currently we punt on response parsing if the modelServer is streaming, and we just passthrough. - - responseText := string(v.ResponseBody.Body) - s.HandleResponseBodyModelStreaming(ctx, reqCtx, responseText) + s.HandleResponseBodyModelStreaming(ctx, reqCtx, v.ResponseBody.Body) if v.ResponseBody.EndOfStream { loggerTrace.Info("stream completed") + s.HandleResponseBodyModelStreamingComplete(ctx, reqCtx, body) reqCtx.ResponseCompleteTimestamp = time.Now() metrics.RecordRequestLatencies(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp) @@ -283,39 +283,36 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) } reqCtx.respBodyResp = generateResponseBodyResponses(v.ResponseBody.Body, v.ResponseBody.EndOfStream) - } else { - body = append(body, v.ResponseBody.Body...) - - // Message is buffered, we can read and decode. - if v.ResponseBody.EndOfStream { - loggerTrace.Info("stream completed") - // Don't send a 500 on a response error. Just let the message passthrough and log our error for debugging purposes. - // We assume the body is valid JSON, err messages are not guaranteed to be json, and so capturing and sending a 500 obfuscates the response message. 
-					// Using the standard 'err' var will send an immediate error response back to the caller.
-					var responseErr error
-					responseErr = json.Unmarshal(body, &responseBody)
-					if responseErr != nil {
-						if logger.V(logutil.DEBUG).Enabled() {
-							logger.V(logutil.DEBUG).Error(responseErr, "Error unmarshalling request body", "body", string(body))
-						} else {
-							logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshalling request body", "body", string(body))
-						}
-						reqCtx.respBodyResp = generateResponseBodyResponses(body, true)
-						break
+			} else if v.ResponseBody.EndOfStream {
+				loggerTrace.Info("stream completed")
+				// Don't send a 500 on a response error. Just let the message passthrough and log our error for debugging purposes.
+				// We assume the body is valid JSON, err messages are not guaranteed to be json, and so capturing and sending a 500 obfuscates the response message.
+				// Using the standard 'err' var will send an immediate error response back to the caller.
+				var responseErr error
+				responseErr = json.Unmarshal(body, &responseBody)
+				if responseErr != nil {
+					if logger.V(logutil.DEBUG).Enabled() {
+						logger.V(logutil.DEBUG).Error(responseErr, "Error unmarshalling response body", "body", string(body))
+					} else {
+						logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshalling response body", "body", string(body))
 					}
+					reqCtx.respBodyResp = generateResponseBodyResponses(body, true)
+					break
+				}
 
-					reqCtx.Response.Body = body
-					reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody)
-					if responseErr != nil {
-						if logger.V(logutil.DEBUG).Enabled() {
-							logger.V(logutil.DEBUG).Error(responseErr, "Failed to process response body", "request", req)
-						} else {
-							logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response body")
-						}
-					} else if reqCtx.ResponseComplete {
-						reqCtx.ResponseCompleteTimestamp = time.Now()
-						metrics.RecordRequestLatencies(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
-						metrics.RecordResponseSizes(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.ResponseSize)
+			reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, body)
+			if responseErr != nil {
+				if logger.V(logutil.DEBUG).Enabled() {
+					logger.V(logutil.DEBUG).Error(responseErr, "Failed to process response body", "request", req)
+				} else {
+					logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response body")
+				}
+			} else if reqCtx.ResponseComplete {
+				reqCtx.ResponseCompleteTimestamp = time.Now()
+				metrics.RecordRequestLatencies(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
+				metrics.RecordResponseSizes(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.ResponseSize)
+				if reqCtx.Usage != nil {
+					// Response completion does not guarantee that Usage is populated.
 					metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.Usage.PromptTokens)
 					metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.Usage.CompletionTokens)
 				}
diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go
index 0c88a0d0f..d3567f363 100644
--- a/pkg/epp/requestcontrol/director.go
+++ b/pkg/epp/requestcontrol/director.go
@@ -20,6 +20,7 @@ package requestcontrol
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"math/rand"
 	"net"
@@ -283,12 +284,11 @@ func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handl
 	requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey]
 	logger := log.FromContext(ctx).WithValues("stage", "bodyChunk", requtil.RequestIdHeaderKey, requestID)
 	logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
-	llmResponse, err := schedulingtypes.NewLLMResponseFromBytes(reqCtx.Response.Body)
-	if err != nil {
-		logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.")
+	if reqCtx.SchedulingResponse == nil {
+		err := errors.New("nil scheduling response from reqCtx")
 		return reqCtx, err
 	}
-	d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, llmResponse, reqCtx.TargetPod)
+	d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, reqCtx.SchedulingResponse, reqCtx.TargetPod)
 
 	logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
 	return reqCtx, nil
diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go
index e9110b96c..835f9d816 100644
--- a/pkg/epp/requestcontrol/director_test.go
+++ b/pkg/epp/requestcontrol/director_test.go
@@ -725,9 +725,9 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
 		},
 		Response: &handlers.Response{
 			Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"},
-			Body:    []byte(chatCompletionJSON),
 		},
-		TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
+		SchedulingResponse: wantLLMResponse,
+		TargetPod:          &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
 	}
 
 	_, err = director.HandleResponseBodyComplete(ctx, reqCtx)
diff --git a/pkg/epp/scheduling/types/llmresponse.go b/pkg/epp/scheduling/types/llmresponse.go
index 950c995c8..b2f1ec5e9 100644
--- a/pkg/epp/scheduling/types/llmresponse.go
+++ b/pkg/epp/scheduling/types/llmresponse.go
@@ -17,15 +17,22 @@ limitations under the License.
 package types
 
 import (
+	"bytes"
 	"encoding/json"
 	"errors"
 	"fmt"
+	"sort"
+)
+
+const (
+	// StreamDone is the special string indicating the end of a streaming response.
+	StreamDone = "[DONE]"
 )
 
 // LLMResponse is a structured representation of a parsed LLM response body.
 // An LLMResponse must contain exactly one of ChatCompletion or LegacyCompletion.
 type LLMResponse struct {
 	// ChatCompletion is the representation of the OpenAI /v1/chat/completions response body.
 	ChatCompletion *ChatCompletionResponse `json:"chat_completion,omitempty"`
 	// Completion is the representation of the OpenAI /v1/completions response body.
 	Completion *CompletionResponse `json:"legacy_completion,omitempty"`
@@ -42,6 +49,17 @@ func (res *LLMResponse) FirstChoiceContent() ([]byte, error) {
 	return nil, errors.New("no choices found in the LLM response")
 }
 
+// Usage returns the token usage reported by whichever response variant is set.
+func (res *LLMResponse) Usage() *Usage {
+	if res.ChatCompletion != nil {
+		return res.ChatCompletion.Usage
+	}
+	if res.Completion != nil {
+		return res.Completion.Usage
+	}
+	return nil
+}
+
 // ChatCompletionResponse represents the full response body for the chat completions API.
 type ChatCompletionResponse struct {
 	Choices []ChatChoice `json:"choices"`
@@ -108,6 +126,176 @@ func (u *Usage) String() string {
 	return fmt.Sprintf("{Prompt: %d, Completion: %d, Total: %d}", u.PromptTokens, u.CompletionTokens, u.TotalTokens)
 }
 
+// ChatCompletionStreamChoiceDelta represents the delta in a streaming choice.
+type ChatCompletionStreamChoiceDelta struct {
+	Content string `json:"content,omitempty"`
+	Role    string `json:"role,omitempty"`
+}
+
+// ChatCompletionStreamChoice represents a choice in a streaming response.
+type ChatCompletionStreamChoice struct {
+	Index        int                             `json:"index"`
+	Delta        ChatCompletionStreamChoiceDelta `json:"delta"`
+	FinishReason string                          `json:"finish_reason,omitempty"`
+}
+
+// ChatCompletionChunk represents a chunk of a streaming chat completion response.
+type ChatCompletionChunk struct {
+	Choices []ChatCompletionStreamChoice `json:"choices"`
+	Usage   *Usage                       `json:"usage,omitempty"`
+}
+
+// CompletionStreamChoice represents a choice in a streaming completion response.
+type CompletionStreamChoice struct {
+	Text         string `json:"text"`
+	Index        int    `json:"index"`
+	FinishReason string `json:"finish_reason,omitempty"`
+}
+
+// CompletionChunk represents a chunk of a streaming completion response.
+type CompletionChunk struct {
+	Choices []CompletionStreamChoice `json:"choices"`
+	Usage   *Usage                   `json:"usage,omitempty"`
+}
+
+// NewLLMResponseFromStream initializes an LLMResponse from a streaming response.
+func NewLLMResponseFromStream(body []byte) (*LLMResponse, error) {
+	if len(body) == 0 {
+		return nil, errors.New("input bytes are empty")
+	}
+
+	lines := bytes.Split(body, []byte("data: "))
+
+	// Determine stream type from the first data chunk.
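+	// A chat chunk carries a "delta" object while a legacy completion chunk
+	// carries a "text" field; the first chunk with either marker picks the parser.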
+ for _, line := range lines { + line = bytes.TrimSpace(line) + + jsonData := bytes.TrimPrefix(line, []byte("data: ")) + if len(jsonData) == 0 || string(jsonData) == StreamDone { + continue + } + + if bytes.Contains(jsonData, []byte(`"delta":`)) { + return processChatStream(lines) + } + if bytes.Contains(jsonData, []byte(`"text":`)) { + return processCompletionStream(lines) + } + } + + return nil, errors.New("failed to determine stream type or find choices") +} + +func processChatStream(lines [][]byte) (*LLMResponse, error) { + chatChoices := make(map[int]*ChatChoice) + var chatUsage *Usage + + for _, line := range lines { + line = bytes.TrimSpace(line) + jsonData := bytes.TrimPrefix(line, []byte("data: ")) + if len(jsonData) == 0 || string(jsonData) == StreamDone { + continue + } + + var chunk ChatCompletionChunk + if err := json.Unmarshal(jsonData, &chunk); err != nil { + continue // Ignore malformed chunks + } + + if chunk.Usage != nil { + chatUsage = chunk.Usage + } + for _, choiceChunk := range chunk.Choices { + if _, ok := chatChoices[choiceChunk.Index]; !ok { + chatChoices[choiceChunk.Index] = &ChatChoice{Message: Message{}} + } + choice := chatChoices[choiceChunk.Index] + choice.Message.Role += choiceChunk.Delta.Role + choice.Message.Content.Raw += choiceChunk.Delta.Content + if choiceChunk.FinishReason != "" { + choice.FinishReason = choiceChunk.FinishReason + } + } + } + + if len(chatChoices) == 0 && chatUsage == nil { + return nil, errors.New("no choices or usage found in chat stream") + } + + return aggregateChatStream(chatChoices, chatUsage), nil +} + +func processCompletionStream(lines [][]byte) (*LLMResponse, error) { + completionChoices := make(map[int]*CompletionChoice) + var completionUsage *Usage + + for _, line := range lines { + line = bytes.TrimSpace(line) + jsonData := bytes.TrimPrefix(line, []byte("data: ")) + if len(jsonData) == 0 || string(jsonData) == StreamDone { + continue + } + + var chunk CompletionChunk + if err := json.Unmarshal(jsonData, &chunk); err != nil { + continue // Ignore malformed chunks + } + + if chunk.Usage != nil { + completionUsage = chunk.Usage + } + for _, choiceChunk := range chunk.Choices { + if _, ok := completionChoices[choiceChunk.Index]; !ok { + completionChoices[choiceChunk.Index] = &CompletionChoice{} + } + choice := completionChoices[choiceChunk.Index] + choice.Text += choiceChunk.Text + if choiceChunk.FinishReason != "" { + choice.FinishReason = choiceChunk.FinishReason + } + } + } + + if len(completionChoices) == 0 && completionUsage == nil { + return nil, errors.New("no choices or usage found in completion stream") + } + + return aggregateCompletionStream(completionChoices, completionUsage), nil +} + +func aggregateChatStream(choices map[int]*ChatChoice, usage *Usage) *LLMResponse { + resp := &ChatCompletionResponse{Usage: usage} + keys := make([]int, 0, len(choices)) + for k := range choices { + keys = append(keys, k) + } + sort.Ints(keys) + finalChoices := make([]ChatChoice, len(keys)) + for i, k := range keys { + finalChoices[i] = *choices[k] + } + resp.Choices = finalChoices + + return &LLMResponse{ChatCompletion: resp} +} + +func aggregateCompletionStream(choices map[int]*CompletionChoice, usage *Usage) *LLMResponse { + resp := &CompletionResponse{Usage: usage} + keys := make([]int, 0, len(choices)) + for k := range choices { + keys = append(keys, k) + } + sort.Ints(keys) + finalChoices := make([]CompletionChoice, len(keys)) + for i, k := range keys { + finalChoices[i] = *choices[k] + } + resp.Choices = finalChoices + 
return &LLMResponse{Completion: resp} +} + // NewLLMResponseFromBytes initializes an LLMResponse by trying to parse the data // as a chat completion and then as a legacy completion response. func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) { diff --git a/pkg/epp/scheduling/types/llmresponse_test.go b/pkg/epp/scheduling/types/llmresponse_test.go index 4aba140bf..775ff3f1f 100644 --- a/pkg/epp/scheduling/types/llmresponse_test.go +++ b/pkg/epp/scheduling/types/llmresponse_test.go @@ -17,6 +17,7 @@ limitations under the License. package types import ( + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -230,6 +231,115 @@ func TestNewLLMResponseFromBytes(t *testing.T) { } } +func TestNewLLMResponseFromStream(t *testing.T) { + testCases := []struct { + name string + streamData []byte + want *LLMResponse + wantErr bool + errContains string + }{ + { + name: "valid chat stream with content and usage", + streamData: []byte(` + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}} + + data: [DONE] + `), + want: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{ + { + Message: Message{ + Role: "assistant", + Content: Content{Raw: "Hello world"}, + }, + FinishReason: "stop", + }, + }, + Usage: &Usage{ + PromptTokens: 5, + CompletionTokens: 7, + TotalTokens: 12, + }, + }, + }, + wantErr: false, + }, + { + name: "valid completion stream with content and usage", + streamData: []byte(` + data: {"id":"cmpl-1","object":"text_completion","choices":[{"index":0,"text":"Hello"}]} + + data: {"id":"cmpl-1","object":"text_completion","choices":[{"index":0,"text":" world"}]} + + data: {"id":"cmpl-1","object":"text_completion","choices":[{"index":0,"text":"","finish_reason":"stop"}]} + + data: {"id":"cmpl-1","object":"text_completion","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}} + + data: [DONE] + `), + want: &LLMResponse{ + Completion: &CompletionResponse{ + Choices: []CompletionChoice{ + { + Text: "Hello world", + FinishReason: "stop", + }, + }, + Usage: &Usage{ + PromptTokens: 5, + CompletionTokens: 7, + TotalTokens: 12, + }, + }, + }, + }, + { + name: "empty stream data", + streamData: []byte(""), + wantErr: true, + errContains: "input bytes are empty", + }, + { + name: "stream with no choices", + streamData: []byte(`data: [DONE]`), + wantErr: true, + errContains: "failed to determine stream type", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := NewLLMResponseFromStream(tc.streamData) + + if tc.wantErr { + if err == nil { + t.Errorf("Expected an error, but got nil") + } + if err != nil && tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("Expected error to contain '%s', but got '%s'", tc.errContains, err.Error()) + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("NewLLMResponseFromStream() mismatch (-want 
+got):\n%s", diff) + } + } + }) + } +} + func TestFirstChoiceContent(t *testing.T) { testCases := []struct { name string