Skip to content

Commit 2d25426

Browse files
Break PostResponse requestcontrol plugin into 3 separate plugins to add streamed request functionality (#1661)
* Break out PostResponse plugin into 3 constituent plugins for request recieved, streaming, and complete * Fix typo in variable names * Log typed name in director.go and remove redundant director nil check in response.go * Renamed the post response plugins to not include the word post. * Fix function comment and pass existing logger into HandleResponseBodyStreaming * Update pkg/epp/requestcontrol/plugins.go Co-authored-by: Nir Rozenbaum <nirro@il.ibm.com> * Update pkg/epp/requestcontrol/request_control_config.go Co-authored-by: Nir Rozenbaum <nirro@il.ibm.com> * Update pkg/epp/requestcontrol/director.go Co-authored-by: Nir Rozenbaum <nirro@il.ibm.com> * Fix comments andlogs, simplify Director defintion to take in config * Revert logging parameter addition, keeping consistent with existing format for plugins --------- Co-authored-by: Nir Rozenbaum <nirro@il.ibm.com>
1 parent 6c1a6e9 commit 2d25426

File tree

9 files changed

+310
-56
lines changed

9 files changed

+310
-56
lines changed

pkg/epp/handlers/response.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,27 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
6161
reqCtx.ResponseComplete = true
6262

6363
reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true)
64-
return reqCtx, nil
64+
65+
return s.director.HandleResponseBodyComplete(ctx, reqCtx)
6566
}
6667

6768
// The function is to handle streaming response if the modelServer is streaming.
6869
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
70+
logger := log.FromContext(ctx)
71+
_, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx)
72+
if err != nil {
73+
logger.Error(err, "error in HandleResponseBodyStreaming")
74+
}
6975
if strings.Contains(responseText, streamingEndMsg) {
76+
reqCtx.ResponseComplete = true
7077
resp := parseRespForUsage(ctx, responseText)
7178
reqCtx.Usage = resp.Usage
7279
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens)
7380
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens)
81+
_, err := s.director.HandleResponseBodyComplete(ctx, reqCtx)
82+
if err != nil {
83+
logger.Error(err, "error in HandleResponseBodyComplete")
84+
}
7485
}
7586
}
7687

@@ -83,7 +94,7 @@ func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *Req
8394
}
8495
}
8596

86-
reqCtx, err := s.director.HandleResponse(ctx, reqCtx)
97+
reqCtx, err := s.director.HandleResponseReceived(ctx, reqCtx)
8798

8899
return reqCtx, err
89100
}

pkg/epp/handlers/response_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
"github.com/google/go-cmp/cmp"
2525

26+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
2627
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2728
)
2829

@@ -59,6 +60,27 @@ data: [DONE]
5960
`
6061
)
6162

63+
type mockDirector struct{}
64+
65+
func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
66+
return reqCtx, nil
67+
}
68+
func (m *mockDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
69+
return reqCtx, nil
70+
}
71+
func (m *mockDirector) HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
72+
return reqCtx, nil
73+
}
74+
func (m *mockDirector) HandlePreRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
75+
return reqCtx, nil
76+
}
77+
func (m *mockDirector) GetRandomPod() *backend.Pod {
78+
return &backend.Pod{}
79+
}
80+
func (m *mockDirector) HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
81+
return reqCtx, nil
82+
}
83+
6284
func TestHandleResponseBody(t *testing.T) {
6385
ctx := logutil.NewTestLoggerIntoContext(context.Background())
6486

@@ -83,6 +105,7 @@ func TestHandleResponseBody(t *testing.T) {
83105
for _, test := range tests {
84106
t.Run(test.name, func(t *testing.T) {
85107
server := &StreamingServer{}
108+
server.director = &mockDirector{}
86109
reqCtx := test.reqCtx
87110
if reqCtx == nil {
88111
reqCtx = &RequestContext{}
@@ -143,6 +166,7 @@ func TestHandleStreamedResponseBody(t *testing.T) {
143166
for _, test := range tests {
144167
t.Run(test.name, func(t *testing.T) {
145168
server := &StreamingServer{}
169+
server.director = &mockDirector{}
146170
reqCtx := test.reqCtx
147171
if reqCtx == nil {
148172
reqCtx = &RequestContext{}

pkg/epp/handlers/server.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ func NewStreamingServer(datastore Datastore, director Director) *StreamingServer
5454

5555
type Director interface {
5656
HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
57-
HandleResponse(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
57+
HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
58+
HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
59+
HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
5860
GetRandomPod() *backend.Pod
5961
}
6062

@@ -121,7 +123,7 @@ const (
121123
HeaderRequestResponseComplete StreamRequestState = 1
122124
BodyRequestResponsesComplete StreamRequestState = 2
123125
TrailerRequestResponsesComplete StreamRequestState = 3
124-
ResponseRecieved StreamRequestState = 4
126+
ResponseReceived StreamRequestState = 4
125127
HeaderResponseResponseComplete StreamRequestState = 5
126128
BodyResponseResponsesComplete StreamRequestState = 6
127129
TrailerResponseResponsesComplete StreamRequestState = 7
@@ -251,7 +253,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
251253
loggerTrace.Info("model server is streaming response")
252254
}
253255
}
254-
reqCtx.RequestState = ResponseRecieved
256+
reqCtx.RequestState = ResponseReceived
255257

256258
var responseErr error
257259
reqCtx, responseErr = s.HandleResponseHeaders(ctx, reqCtx, v)
@@ -377,7 +379,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces
377379
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
378380
}
379381
}
380-
if r.RequestState == ResponseRecieved && r.respHeaderResp != nil {
382+
if r.RequestState == ResponseReceived && r.respHeaderResp != nil {
381383
loggerTrace.Info("Sending response header response", "obj", r.respHeaderResp)
382384
if err := srv.Send(r.respHeaderResp); err != nil {
383385
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)

pkg/epp/requestcontrol/director.go

Lines changed: 72 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,11 @@ func NewDirectorWithConfig(
6262
config *Config,
6363
) *Director {
6464
return &Director{
65-
datastore: datastore,
66-
scheduler: scheduler,
67-
admissionController: admissionController,
68-
preRequestPlugins: config.preRequestPlugins,
69-
postResponsePlugins: config.postResponsePlugins,
70-
defaultPriority: 0, // define default priority explicitly
65+
datastore: datastore,
66+
scheduler: scheduler,
67+
admissionController: admissionController,
68+
requestControlPlugins: *config,
69+
defaultPriority: 0, // define default priority explicitly
7170
}
7271
}
7372

@@ -81,11 +80,10 @@ func NewDirectorWithConfig(
8180
// - Preparing the request context for the Envoy ext_proc filter to route the request.
8281
// - Running PostResponse plugins.
8382
type Director struct {
84-
datastore Datastore
85-
scheduler Scheduler
86-
admissionController AdmissionController
87-
preRequestPlugins []PreRequest
88-
postResponsePlugins []PostResponse
83+
datastore Datastore
84+
scheduler Scheduler
85+
admissionController AdmissionController
86+
requestControlPlugins Config
8987
// we just need a pointer to an int variable since priority is a pointer in InferenceObjective
9088
// no need to set this in the constructor, since the value we want is the default int val
9189
// and value types cannot be nil
@@ -261,19 +259,49 @@ func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []sch
261259
return pm
262260
}
263261

264-
func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
262+
// HandleResponseReceived is called when the response headers are received.
263+
func (d *Director) HandleResponseReceived(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
265264
response := &Response{
266265
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
267266
Headers: reqCtx.Response.Headers,
268267
}
269268

270269
// TODO: to extend fallback functionality, handle cases where target pod is unavailable
271270
// https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1224
272-
d.runPostResponsePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
271+
d.runResponseReceivedPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
273272

274273
return reqCtx, nil
275274
}
276275

276+
// HandleResponseBodyStreaming is called every time a chunk of the response body is received.
277+
func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
278+
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
279+
logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk")
280+
response := &Response{
281+
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
282+
Headers: reqCtx.Response.Headers,
283+
}
284+
285+
d.runResponseStreamingPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
286+
logger.V(logutil.TRACE).Info("Exiting HandleResponseBodyChunk")
287+
return reqCtx, nil
288+
}
289+
290+
// HandleResponseBodyComplete is called when the response body is fully received.
291+
func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
292+
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
293+
logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
294+
response := &Response{
295+
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
296+
Headers: reqCtx.Response.Headers,
297+
}
298+
299+
d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
300+
301+
logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
302+
return reqCtx, nil
303+
}
304+
277305
func (d *Director) GetRandomPod() *backend.Pod {
278306
pods := d.datastore.PodList(backendmetrics.AllPodsPredicate)
279307
if len(pods) == 0 {
@@ -287,22 +315,44 @@ func (d *Director) GetRandomPod() *backend.Pod {
287315
func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest,
288316
schedulingResult *schedulingtypes.SchedulingResult, targetPort int) {
289317
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
290-
for _, plugin := range d.preRequestPlugins {
291-
loggerDebug.Info("Running pre-request plugin", "plugin", plugin.TypedName())
318+
for _, plugin := range d.requestControlPlugins.preRequestPlugins {
319+
loggerDebug.Info("Running PreRequest plugin", "plugin", plugin.TypedName())
292320
before := time.Now()
293321
plugin.PreRequest(ctx, request, schedulingResult, targetPort)
294322
metrics.RecordPluginProcessingLatency(PreRequestExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
295-
loggerDebug.Info("Completed running pre-request plugin successfully", "plugin", plugin.TypedName())
323+
loggerDebug.Info("Completed running PreRequest plugin successfully", "plugin", plugin.TypedName())
324+
}
325+
}
326+
327+
func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
328+
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
329+
for _, plugin := range d.requestControlPlugins.responseReceivedPlugins {
330+
loggerDebug.Info("Running ResponseReceived plugin", "plugin", plugin.TypedName())
331+
before := time.Now()
332+
plugin.ResponseReceived(ctx, request, response, targetPod)
333+
metrics.RecordPluginProcessingLatency(ResponseReceivedExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
334+
loggerDebug.Info("Completed running ResponseReceived plugin successfully", "plugin", plugin.TypedName())
335+
}
336+
}
337+
338+
func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
339+
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
340+
for _, plugin := range d.requestControlPlugins.responseStreamingPlugins {
341+
loggerTrace.Info("Running ResponseStreaming plugin", "plugin", plugin.TypedName())
342+
before := time.Now()
343+
plugin.ResponseStreaming(ctx, request, response, targetPod)
344+
metrics.RecordPluginProcessingLatency(ResponseStreamingExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
345+
loggerTrace.Info("Completed running ResponseStreaming plugin successfully", "plugin", plugin.TypedName())
296346
}
297347
}
298348

299-
func (d *Director) runPostResponsePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
349+
func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
300350
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
301-
for _, plugin := range d.postResponsePlugins {
302-
loggerDebug.Info("Running post-response plugin", "plugin", plugin.TypedName())
351+
for _, plugin := range d.requestControlPlugins.responseCompletePlugins {
352+
loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName())
303353
before := time.Now()
304-
plugin.PostResponse(ctx, request, response, targetPod)
305-
metrics.RecordPluginProcessingLatency(PostResponseExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
306-
loggerDebug.Info("Completed running post-response plugin successfully", "plugin", plugin.TypedName())
354+
plugin.ResponseComplete(ctx, request, response, targetPod)
355+
metrics.RecordPluginProcessingLatency(ResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
356+
loggerDebug.Info("Completed running ResponseComplete plugin successfully", "plugin", plugin.TypedName())
307357
}
308358
}

0 commit comments

Comments
 (0)