diff --git a/pkg/epp/backend/metrics/fake.go b/pkg/epp/backend/metrics/fake.go
index 613ebf5ec..1c7a90528 100644
--- a/pkg/epp/backend/metrics/fake.go
+++ b/pkg/epp/backend/metrics/fake.go
@@ -32,8 +32,9 @@ import (
 
 // FakePodMetrics is an implementation of PodMetrics that doesn't run the async refresh loop.
 type FakePodMetrics struct {
-	Pod     *backend.Pod
-	Metrics *MetricsState
+	Pod        *backend.Pod
+	Metrics    *MetricsState
+	Attributes *datalayer.Attributes
 }
 
 func (fpm *FakePodMetrics) String() string {
@@ -51,6 +52,9 @@ func (fpm *FakePodMetrics) GetMetrics() *MetricsState {
 func (fpm *FakePodMetrics) UpdatePod(pod *datalayer.PodInfo) {
 	fpm.Pod = pod
 }
+func (fpm *FakePodMetrics) GetAttributes() *datalayer.Attributes {
+	return fpm.Attributes
+}
 
 func (*FakePodMetrics) Put(string, datalayer.Cloneable)        {}
 func (*FakePodMetrics) Get(string) (datalayer.Cloneable, bool) { return nil, false }
diff --git a/pkg/epp/backend/metrics/pod_metrics.go b/pkg/epp/backend/metrics/pod_metrics.go
index a1114aecf..4d22ef18c 100644
--- a/pkg/epp/backend/metrics/pod_metrics.go
+++ b/pkg/epp/backend/metrics/pod_metrics.go
@@ -126,6 +126,9 @@ func (pm *podMetrics) stopRefreshLoop() {
 func (*podMetrics) Put(string, datalayer.Cloneable)        {}
 func (*podMetrics) Get(string) (datalayer.Cloneable, bool) { return nil, false }
 func (*podMetrics) Keys() []string                         { return nil }
+func (*podMetrics) GetAttributes() *datalayer.Attributes {
+	return nil
+}
 
 func (pm *podMetrics) UpdateMetrics(updated *MetricsState) {
 	updated.UpdateTime = time.Now()
diff --git a/pkg/epp/datalayer/endpoint.go b/pkg/epp/datalayer/endpoint.go
index 74c11905e..96c5423d5 100644
--- a/pkg/epp/datalayer/endpoint.go
+++ b/pkg/epp/datalayer/endpoint.go
@@ -25,6 +25,7 @@ import (
 type EndpointPodState interface {
 	GetPod() *PodInfo
 	UpdatePod(*PodInfo)
+	GetAttributes() *Attributes
 }
 
 // EndpointMetricsState allows management of the Metrics related attributes.
@@ -89,6 +90,10 @@ func (srv *ModelServer) Keys() []string {
 	return srv.attributes.Keys()
 }
 
+func (srv *ModelServer) GetAttributes() *Attributes {
+	return srv.attributes
+}
+
 func (srv *ModelServer) Clone() *ModelServer {
 	clone := &ModelServer{
 		attributes: srv.attributes.Clone(),
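GetAttributes complements the existing per-key Put/Get/Keys accessors by exposing the endpoint's whole attribute map, which the Director below hands to the scheduling layer when it snapshots candidate pods. A minimal sketch of the intended use, assuming Attributes.Clone() returns *Attributes (as the snapshot code in director.go implies); copyEndpointAttributes is a hypothetical helper, not part of this change:

```go
// copyEndpointAttributes clones an endpoint's attributes so per-request
// consumers can read them without racing the live, periodically refreshed map.
// Implementations that track no attributes (e.g. the metrics-only podMetrics
// above) return nil, hence the empty-map fallback.
func copyEndpointAttributes(ep datalayer.EndpointPodState) *datalayer.Attributes {
	if attrs := ep.GetAttributes(); attrs != nil {
		return attrs.Clone()
	}
	return datalayer.NewAttributes()
}
```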
diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go
index f6f7deebe..e892eaf08 100644
--- a/pkg/epp/requestcontrol/director.go
+++ b/pkg/epp/requestcontrol/director.go
@@ -32,6 +32,7 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
@@ -41,6 +42,11 @@ import (
 	requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
 )
 
+const (
+	prepareDataTimeout    = 200 * time.Millisecond
+	prepareDataMaxRetries = 3
+)
+
 // Datastore defines the interface required by the Director.
 type Datastore interface {
 	PoolGet() (*v1.InferencePool, error)
@@ -89,16 +95,28 @@ type Director struct {
 	defaultPriority int
 }
 
-// HandleRequest orchestrates the request lifecycle.
-// It always returns the requestContext even in the error case, as the request context is used in error handling.
-func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
-	logger := log.FromContext(ctx)
+// getInferenceObjective fetches the InferenceObjective from the datastore, falling back to a default objective when none is found.
+func (d *Director) getInferenceObjective(ctx context.Context, reqCtx *handlers.RequestContext) *v1alpha2.InferenceObjective {
+	infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
+	if infObjective == nil {
+		log.FromContext(ctx).V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)
+		infObjective = &v1alpha2.InferenceObjective{
+			Spec: v1alpha2.InferenceObjectiveSpec{
+				Priority: &d.defaultPriority,
+			},
+		}
+	} else if infObjective.Spec.Priority == nil {
+		// Default the priority when not specified.
+		infObjective.Spec.Priority = &d.defaultPriority
+	}
+	return infObjective
+}
 
-	// Parse Request, Resolve Target Models, and Determine Parameters
+// resolveTargetModel resolves the target model from the request body and updates reqCtx accordingly.
+func (d *Director) resolveTargetModel(reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
 	requestBodyMap := reqCtx.Request.Body
 	var ok bool
 	reqCtx.IncomingModelName, ok = requestBodyMap["model"].(string)
-
 	if !ok {
 		return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request body"}
 	}
@@ -107,24 +125,28 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
 		reqCtx.TargetModelName = reqCtx.IncomingModelName
 	}
 	reqCtx.Request.Body["model"] = reqCtx.TargetModelName
+	return reqCtx, nil
+}
 
+// HandleRequest orchestrates the request lifecycle.
+// It always returns the requestContext even in the error case, as the request context is used in error handling.
+func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
+	logger := log.FromContext(ctx)
+
+	// Resolve the target model and update the request context.
+	reqCtx, err := d.resolveTargetModel(reqCtx)
+	if err != nil {
+		return reqCtx, err
+	}
+
+	// Parse the request body.
 	requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body)
 	if err != nil {
 		return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
 	}
 
-	infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
-	if infObjective == nil {
-		logger.V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)
-		infObjective = &v1alpha2.InferenceObjective{
-			Spec: v1alpha2.InferenceObjectiveSpec{
-				Priority: &d.defaultPriority,
-			},
-		}
-	} else if infObjective.Spec.Priority == nil {
-		// Default to 0 if not specified.
-		infObjective.Spec.Priority = &d.defaultPriority
-	}
+	// Resolve the inference objective.
+	infObjective := d.getInferenceObjective(ctx, reqCtx)
 
 	// Prepare LLMRequest (needed for both saturation detection and Scheduler)
 	reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
@@ -144,13 +166,25 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
 	if len(candidatePods) == 0 {
 		return reqCtx, errutil.Error{Code: errutil.ServiceUnavailable, Msg: "failed to find candidate pods for serving the request"}
 	}
-
 	if err := d.admissionController.Admit(ctx, reqCtx, candidatePods, *infObjective.Spec.Priority); err != nil {
 		logger.V(logutil.DEFAULT).Info("Request rejected by admission control", "error", err)
 		return reqCtx, err
 	}
+	snapshotOfCandidatePods := d.toSchedulerPodMetrics(candidatePods)
+
+	// Prepare per-request data by running the PrepareData plugins.
+	if err := d.runPrepareDataPlugins(ctx, reqCtx.SchedulingRequest, snapshotOfCandidatePods); err != nil {
+		// Don't fail the request if PrepareData plugins fail.
+		logger.V(logutil.DEFAULT).Error(err, "failed to prepare per-request data")
+	}
+
+	// Run the admission plugins.
+	if !d.runAdmissionPlugins(ctx, reqCtx.SchedulingRequest, snapshotOfCandidatePods) {
+		logger.V(logutil.DEFAULT).Info("Request cannot be admitted")
+		return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "request cannot be admitted"}
+	}
 
-	result, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, d.toSchedulerPodMetrics(candidatePods))
+	result, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, snapshotOfCandidatePods)
 	if err != nil {
 		return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
 	}
@@ -244,7 +278,11 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC
 func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []schedulingtypes.Pod {
 	pm := make([]schedulingtypes.Pod, len(pods))
 	for i, pod := range pods {
-		pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetPod().Clone(), MetricsState: pod.GetMetrics().Clone()}
+		if pod.GetAttributes() != nil {
+			pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetPod().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: pod.GetAttributes().Clone()}
+		} else {
+			pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetPod().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: datalayer.NewAttributes()}
+		}
 	}
 
 	return pm
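Because toSchedulerPodMetrics builds each entry from clones plus a fresh or cloned AttributeMap, the scheduling layer operates on a per-request snapshot: PrepareData writes are visible to later plugins handling the same request, but never leak into other requests or the shared datastore. A minimal sketch of that isolation property for the requestcontrol package; demoValue and isolationDemo are hypothetical:

```go
// demoValue is a trivial datalayer.Cloneable.
type demoValue struct{}

func (demoValue) Clone() datalayer.Cloneable { return demoValue{} }

// isolationDemo: a write to one snapshot never shows up in a fresh snapshot
// taken from the same candidate pods.
func isolationDemo(d *Director, candidates []backendmetrics.PodMetrics) bool {
	first := d.toSchedulerPodMetrics(candidates)
	first[0].Put("demo-key", demoValue{})
	second := d.toSchedulerPodMetrics(candidates) // fresh per-request snapshot
	_, leaked := second[0].Get("demo-key")
	return !leaked // true: per-request data does not cross snapshots
}
```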
@@ -315,6 +353,62 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
 	}
 }
 
+// prepareDataWithRetriesAndTimeout executes a single PrepareData plugin, retrying failed or
+// timed-out attempts until the budget of prepareDataMaxRetries retries is exhausted. Each
+// attempt gets a fresh prepareDataTimeout; errCh is buffered so a late attempt cannot leak.
+func prepareDataWithRetriesAndTimeout(ctx context.Context, plugin PrepareDataPlugin, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
+	for i := 0; i <= prepareDataMaxRetries; i++ {
+		errCh := make(chan error, 1)
+		go func() {
+			errCh <- plugin.PrepareRequestData(ctx, request, pods)
+		}()
+
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case err := <-errCh:
+			if err != nil {
+				log.FromContext(ctx).V(logutil.DEBUG).Info("PrepareData plugin failed, retrying...", "plugin", plugin.TypedName(), "retry", i+1, "error", err)
+				continue
+			}
+			return nil // Success.
+		case <-time.After(prepareDataTimeout):
+			log.FromContext(ctx).V(logutil.DEBUG).Info("PrepareData plugin timed out, retrying...", "plugin", plugin.TypedName(), "retry", i+1, "timeout", prepareDataTimeout)
+		}
+	}
+	// All attempts failed or timed out.
+	return fmt.Errorf("PrepareData plugin %s failed after %d retries", plugin.TypedName().String(), prepareDataMaxRetries)
+}
+
+// runPrepareDataPlugins executes the PrepareData plugins sequentially.
+// TODO: Execute plugins in parallel once DAG execution is supported.
+func (d *Director) runPrepareDataPlugins(ctx context.Context,
+	request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
+	for _, plugin := range d.requestControlPlugins.prepareDataPlugins {
+		if err := prepareDataWithRetriesAndTimeout(ctx, plugin, request, pods); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// runAdmissionPlugins runs the admission plugins and reports whether the request is admitted.
+func (d *Director) runAdmissionPlugins(ctx context.Context,
+	request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) bool {
+	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
+	for _, plugin := range d.requestControlPlugins.admissionPlugins {
+		loggerDebug.Info("Running AdmitRequest plugin", "plugin", plugin.TypedName())
+		if denyReason := plugin.AdmitRequest(ctx, request, pods); denyReason != nil {
+			loggerDebug.Info("AdmitRequest plugin denied the request", "plugin", plugin.TypedName(), "reason", denyReason.Error())
+			return false
+		}
+		loggerDebug.Info("Completed running AdmitRequest plugin successfully", "plugin", plugin.TypedName())
+	}
+	return true
+}
+
 func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
 	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
 	for _, plugin := range d.requestControlPlugins.responseReceivedPlugins {
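A test-style sketch of the helper's contract, assuming the revised ctx-first signature above and the imports already present in director_test.go below; flakyPrep is hypothetical and reuses the mock plugin that follows:

```go
// flakyPrep fails its first two attempts, then succeeds; with
// prepareDataMaxRetries = 3 the helper should report success.
type flakyPrep struct {
	mockPrepareDataPlugin // reuses the mock's TypedName/Produces/Consumes
	calls int
}

func (f *flakyPrep) PrepareRequestData(ctx context.Context, req *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
	f.calls++
	if f.calls <= 2 {
		return errors.New("transient failure") // first two attempts fail
	}
	return nil // third attempt succeeds, within the retry budget
}

func TestPrepareDataRetriesTransientFailure(t *testing.T) {
	p := &flakyPrep{mockPrepareDataPlugin: *newMockPrepareDataPlugin("flaky")}
	if err := prepareDataWithRetriesAndTimeout(context.Background(), p, nil, nil); err != nil {
		t.Fatalf("expected success after retries, got: %v", err)
	}
	if p.calls != 3 {
		t.Fatalf("expected 3 attempts, got %d", p.calls)
	}
}
```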
diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go
index 8cb9c91a5..e705beefd 100644
--- a/pkg/epp/requestcontrol/director_test.go
+++ b/pkg/epp/requestcontrol/director_test.go
@@ -20,6 +20,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"maps"
 	"testing"
 	"time"
 
@@ -37,6 +38,7 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
@@ -48,6 +50,10 @@ import (
 	testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
 )
 
+const (
+	mockProducedDataKey = "producedDataKey"
+)
+
 // --- Mocks ---
 
 type mockAdmissionController struct {
@@ -66,9 +72,16 @@ func (m *mockAdmissionController) Admit(
 
 type mockScheduler struct {
 	scheduleResults *schedulingtypes.SchedulingResult
 	scheduleErr     error
+	dataProduced    bool // denotes whether data production is expected.
 }
 
-func (m *mockScheduler) Schedule(_ context.Context, _ *schedulingtypes.LLMRequest, _ []schedulingtypes.Pod) (*schedulingtypes.SchedulingResult, error) {
+func (m *mockScheduler) Schedule(_ context.Context, _ *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) (*schedulingtypes.SchedulingResult, error) {
+	if len(pods) > 0 && m.dataProduced {
+		data, ok := pods[0].Get(mockProducedDataKey)
+		if !ok || data.(mockProducedDataType).value != 42 {
+			return nil, errors.New("expected produced data not found in pod")
+		}
+	}
 	return m.scheduleResults, m.scheduleErr
 }
 
@@ -93,6 +106,66 @@ func (ds *mockDatastore) PodList(predicate func(backendmetrics.PodMetrics) bool)
 	return res
 }
 
+type mockPrepareDataPlugin struct {
+	name     string
+	produces map[string]any
+	consumes map[string]any
+}
+
+func (m *mockPrepareDataPlugin) TypedName() plugins.TypedName {
+	return plugins.TypedName{Name: m.name, Type: "mock"}
+}
+
+func (m *mockPrepareDataPlugin) Produces() map[string]any {
+	return m.produces
+}
+
+func (m *mockPrepareDataPlugin) Consumes() map[string]any {
+	return m.consumes
+}
+
+func (m *mockPrepareDataPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
+	pods[0].Put(mockProducedDataKey, mockProducedDataType{value: 42})
+	return nil
+}
+
+func newMockPrepareDataPlugin(name string) *mockPrepareDataPlugin {
+	return &mockPrepareDataPlugin{
+		name: name,
+		// Declare the zero value of the type this mock actually writes in PrepareRequestData.
+		produces: map[string]any{mockProducedDataKey: mockProducedDataType{}},
+		consumes: map[string]any{},
+	}
+}
+
+type mockAdmissionPlugin struct {
+	tn          plugins.TypedName
+	denialError error
+}
+
+func newMockAdmissionPlugin(name string, denialError error) *mockAdmissionPlugin {
+	return &mockAdmissionPlugin{
+		tn:          plugins.TypedName{Type: "mock-admit-data", Name: name},
+		denialError: denialError,
+	}
+}
+
+func (m *mockAdmissionPlugin) TypedName() plugins.TypedName {
+	return m.tn
+}
+
+func (m *mockAdmissionPlugin) AdmitRequest(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
+	return m.denialError
+}
+
+type mockProducedDataType struct {
+	value int
+}
+
+// Clone implements datalayer.Cloneable.
+func (m mockProducedDataType) Clone() datalayer.Cloneable {
+	return mockProducedDataType{value: m.value}
+}
+
 func TestDirector_HandleRequest(t *testing.T) {
 	ctx := logutil.NewTestLoggerIntoContext(context.Background())
 
@@ -167,6 +240,7 @@ func TestDirector_HandleRequest(t *testing.T) {
 		TargetPods: []schedulingtypes.Pod{
 			&schedulingtypes.ScoredPod{
 				Pod: &schedulingtypes.PodMetrics{
+					AttributeMap: datalayer.NewAttributes(),
 					Pod: &backend.Pod{
 						Address: "192.168.1.100",
 						Port:    "8000",
@@ -177,6 +251,7 @@ func TestDirector_HandleRequest(t *testing.T) {
 			},
 			&schedulingtypes.ScoredPod{
 				Pod: &schedulingtypes.PodMetrics{
+					AttributeMap: datalayer.NewAttributes(),
 					Pod: &backend.Pod{
 						Address: "192.168.2.100",
 						Port:    "8000",
@@ -187,6 +262,7 @@ func TestDirector_HandleRequest(t *testing.T) {
 			},
 			&schedulingtypes.ScoredPod{
 				Pod: &schedulingtypes.PodMetrics{
+					AttributeMap: datalayer.NewAttributes(),
 					Pod: &backend.Pod{
 						Address: "192.168.4.100",
 						Port:    "8000",
@@ -211,6 +287,8 @@ func TestDirector_HandleRequest(t *testing.T) {
 		wantReqCtx           *handlers.RequestContext // Fields to check in the returned RequestContext
 		wantMutatedBodyModel string                   // Expected model in reqCtx.Request.Body after PostDispatch
 		targetModelName      string                   // Expected model name after target model resolution
+		admitRequestDenialError error                 // Expected denial error from the admission plugin
+		prepareDataPlugin       *mockPrepareDataPlugin
 	}{
 		{
 			name: "successful completions request",
@@ -265,6 +343,85 @@ func TestDirector_HandleRequest(t *testing.T) {
 			wantMutatedBodyModel: model,
 			targetModelName:      model,
 		},
+		{
+			name: "successful chat completions request with prepare data plugins",
+			reqBodyMap: map[string]any{
+				"model": model,
+				"messages": []any{
+					map[string]any{
+						"role":    "user",
+						"content": "critical prompt",
+					},
+				},
+			},
+			mockAdmissionController: &mockAdmissionController{admitErr: nil},
+			schedulerMockSetup: func(m *mockScheduler) {
+				m.scheduleResults = defaultSuccessfulScheduleResults
+				m.dataProduced = true
+			},
+			wantReqCtx: &handlers.RequestContext{
+				TargetModelName: model,
+				TargetPod: &backend.Pod{
+					NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+					Address:        "192.168.1.100",
+					Port:           "8000",
+					MetricsHost:    "192.168.1.100:8000",
+				},
+				TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000",
+			},
+			wantMutatedBodyModel: model,
+			targetModelName:      model,
+			prepareDataPlugin:    newMockPrepareDataPlugin("test-plugin"),
+		},
+		{
+			name: "successful chat completions request with admit request plugins",
+			reqBodyMap: map[string]any{
+				"model": model,
+				"messages": []any{
+					map[string]any{
+						"role":    "user",
+						"content": "critical prompt",
+					},
+				},
+			},
+			mockAdmissionController: &mockAdmissionController{admitErr: nil},
+			schedulerMockSetup: func(m *mockScheduler) {
+				m.scheduleResults = defaultSuccessfulScheduleResults
+			},
+			wantReqCtx: &handlers.RequestContext{
+				TargetModelName: model,
+				TargetPod: &backend.Pod{
+					NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+					Address:        "192.168.1.100",
+					Port:           "8000",
+					MetricsHost:    "192.168.1.100:8000",
+				},
+				TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000",
+			},
+			wantMutatedBodyModel:    model,
+			targetModelName:         model,
+			admitRequestDenialError: nil,
+		},
+		{
+			name: "denied request by admit request plugin",
+			reqBodyMap: map[string]any{
+				"model": model,
+				"messages": []any{
+					map[string]any{
+						"role":    "user",
+						"content": "critical prompt",
+					},
+				},
+			},
+			mockAdmissionController: &mockAdmissionController{admitErr: nil},
+			schedulerMockSetup: func(m *mockScheduler) {
+				m.scheduleResults = defaultSuccessfulScheduleResults
+			},
+			wantMutatedBodyModel:    model,
+			targetModelName:         model,
+			admitRequestDenialError: errors.New("denied by admit plugin"),
+			wantErrCode:             errutil.Internal,
+		},
 		{
 			name: "successful chat completions request with multiple messages",
 			reqBodyMap: map[string]any{
@@ -414,7 +571,12 @@ func TestDirector_HandleRequest(t *testing.T) {
 			if test.schedulerMockSetup != nil {
 				test.schedulerMockSetup(mockSched)
 			}
-			director := NewDirectorWithConfig(ds, mockSched, test.mockAdmissionController, NewConfig())
+			config := NewConfig()
+			if test.prepareDataPlugin != nil {
+				config = config.WithPrepareDataPlugins(test.prepareDataPlugin)
+			}
+			config = config.WithAdmissionPlugins(newMockAdmissionPlugin("test-admit-plugin", test.admitRequestDenialError))
+			director := NewDirectorWithConfig(ds, mockSched, test.mockAdmissionController, config)
 
 			reqCtx := &handlers.RequestContext{
 				Request: &handlers.Request{
@@ -428,9 +590,7 @@ func TestDirector_HandleRequest(t *testing.T) {
 				TargetModelName: test.targetModelName,
 			}
 			// Deep copy the body map.
-			for k, v := range test.reqBodyMap {
-				reqCtx.Request.Body[k] = v
-			}
+			maps.Copy(reqCtx.Request.Body, test.reqBodyMap)
 
 			returnedReqCtx, err := director.HandleRequest(ctx, reqCtx)
diff --git a/pkg/epp/requestcontrol/plugins.go b/pkg/epp/requestcontrol/plugins.go
index 30f31f070..8c6602049 100644
--- a/pkg/epp/requestcontrol/plugins.go
+++ b/pkg/epp/requestcontrol/plugins.go
@@ -57,3 +57,21 @@ type ResponseComplete interface {
 	plugins.Plugin
 	ResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
 }
+
+// PrepareDataPlugin is implemented by data producers that populate per-request data from different sources.
+// Its PrepareRequestData method is called by the Director before a request is scheduled.
+type PrepareDataPlugin interface {
+	plugins.ProducerPlugin
+	plugins.ConsumerPlugin
+	PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error
+}
+
+// AdmissionPlugin is called by the Director after the prepare-data phase and before scheduling.
+// When a request passes through multiple AdmissionPlugins, it is admitted only if every plugin admits it.
+type AdmissionPlugin interface {
+	plugins.Plugin
+	// AdmitRequest returns the denial reason, wrapped as an error, if the request is denied.
+	// If the request is allowed, it returns nil.
+	AdmitRequest(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error
+}
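A minimal sketch of the two interfaces working together: a producer stashes a per-pod datum on each candidate snapshot, and an admission gate consumes it. Everything here is hypothetical (the package, the ttftEstimateMillis key, the types, and the 50ms value); only the interfaces, the Pod attribute accessors, and the import paths mirror this PR, and it assumes plugins.ProducerPlugin/ConsumerPlugin add no methods beyond those shown:

```go
package example

import (
	"context"
	"errors"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

const estimateKey = "ttftEstimateMillis" // hypothetical attribute key

// estimate is the produced per-pod datum; it implements datalayer.Cloneable.
type estimate struct{ millis int }

func (e estimate) Clone() datalayer.Cloneable { return e }

// estimatePrep produces an estimate for every candidate pod before scheduling.
type estimatePrep struct{}

func (estimatePrep) TypedName() plugins.TypedName {
	return plugins.TypedName{Type: "ttft-estimator", Name: "default"}
}
func (estimatePrep) Produces() map[string]any { return map[string]any{estimateKey: estimate{}} }
func (estimatePrep) Consumes() map[string]any { return map[string]any{} }

func (estimatePrep) PrepareRequestData(ctx context.Context, req *types.LLMRequest, pods []types.Pod) error {
	for _, pod := range pods {
		pod.Put(estimateKey, estimate{millis: 50}) // a real producer would compute this
	}
	return nil
}

// estimateGate consumes the estimates and denies requests no pod can serve in time.
type estimateGate struct{ limitMillis int }

func (estimateGate) TypedName() plugins.TypedName {
	return plugins.TypedName{Type: "ttft-gate", Name: "default"}
}

func (g estimateGate) AdmitRequest(ctx context.Context, req *types.LLMRequest, pods []types.Pod) error {
	for _, pod := range pods {
		if v, ok := pod.Get(estimateKey); ok && v.(estimate).millis <= g.limitMillis {
			return nil // at least one pod meets the budget; admit
		}
	}
	return errors.New("no candidate pod within latency budget") // denial reason
}
```

The pair would be wired up the same way the tests above wire their mocks: NewConfig().WithPrepareDataPlugins(estimatePrep{}).WithAdmissionPlugins(estimateGate{limitMillis: 100}).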
diff --git a/pkg/epp/requestcontrol/request_control_config.go b/pkg/epp/requestcontrol/request_control_config.go
index ffa6c6609..9701be999 100644
--- a/pkg/epp/requestcontrol/request_control_config.go
+++ b/pkg/epp/requestcontrol/request_control_config.go
@@ -23,6 +23,8 @@ import (
 // NewConfig creates a new Config object and returns its pointer.
 func NewConfig() *Config {
 	return &Config{
+		admissionPlugins:         []AdmissionPlugin{},
+		prepareDataPlugins:       []PrepareDataPlugin{},
 		preRequestPlugins:        []PreRequest{},
 		responseReceivedPlugins:  []ResponseReceived{},
 		responseStreamingPlugins: []ResponseStreaming{},
@@ -32,6 +34,8 @@ func NewConfig() *Config {
 
 // Config provides a configuration for the requestcontrol plugins.
 type Config struct {
+	admissionPlugins         []AdmissionPlugin
+	prepareDataPlugins       []PrepareDataPlugin
 	preRequestPlugins        []PreRequest
 	responseReceivedPlugins  []ResponseReceived
 	responseStreamingPlugins []ResponseStreaming
@@ -66,10 +70,21 @@ func (c *Config) WithResponseCompletePlugins(plugins ...ResponseComplete) *Confi
 	return c
 }
 
+// WithPrepareDataPlugins sets the given plugins as the PrepareData plugins.
+func (c *Config) WithPrepareDataPlugins(plugins ...PrepareDataPlugin) *Config {
+	c.prepareDataPlugins = plugins
+	return c
+}
+
+// WithAdmissionPlugins sets the given plugins as the Admission plugins.
+func (c *Config) WithAdmissionPlugins(plugins ...AdmissionPlugin) *Config {
+	c.admissionPlugins = plugins
+	return c
+}
+
 // AddPlugins adds the given plugins to the Config.
 // The type of each plugin is checked and added to the corresponding list of plugins in the Config.
 // If a plugin implements multiple plugin interfaces, it will be added to each corresponding list.
-
 func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) {
 	for _, plugin := range pluginObjects {
 		if preRequestPlugin, ok := plugin.(PreRequest); ok {
@@ -84,5 +99,11 @@ func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) {
 		if responseCompletePlugin, ok := plugin.(ResponseComplete); ok {
 			c.responseCompletePlugins = append(c.responseCompletePlugins, responseCompletePlugin)
 		}
+		if prepareDataPlugin, ok := plugin.(PrepareDataPlugin); ok {
+			c.prepareDataPlugins = append(c.prepareDataPlugins, prepareDataPlugin)
+		}
+		if admissionPlugin, ok := plugin.(AdmissionPlugin); ok {
+			c.admissionPlugins = append(c.admissionPlugins, admissionPlugin)
+		}
 	}
 }
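Since AddPlugins type-switches on every request-control interface a plugin implements (including, with the registration case added above, AdmissionPlugin), one object can sign up for several phases in a single call. A hedged sketch; the interface-typed parameter stands in for any concrete multi-phase plugin:

```go
// wire registers a plugin that implements both PrepareDataPlugin and
// AdmissionPlugin; AddPlugins places it in both plugin lists, equivalent
// to chaining WithPrepareDataPlugins and WithAdmissionPlugins explicitly.
func wire(multiPhase interface {
	PrepareDataPlugin
	AdmissionPlugin
}) *Config {
	cfg := NewConfig()
	cfg.AddPlugins(multiPhase)
	return cfg
}
```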
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
index 9def3e4e9..8f2c84c1f 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
@@ -57,6 +57,8 @@ const (
 	DefaultLRUCapacityPerServer = 31250
 
 	PrefixCachePluginType = "prefix-cache-scorer"
+
+	PrefixCacheMatchKey = "PrefixCacheMatchKey"
 )
 
 const (
@@ -195,8 +197,17 @@ func (p *Plugin) WithName(name string) *Plugin {
 	return p
 }
 
-// Score returns the scoring result for the given list of pods based on context.
-func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
+func (p *Plugin) Consumes() map[string]any {
+	return map[string]any{}
+}
+
+func (p *Plugin) Produces() map[string]any {
+	return map[string]any{
+		PrefixCacheMatchKey: &SchedulingContextState{},
+	}
+}
+
+func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error {
 	// pre score step, hashing prompt and find longest prefix match.
 	hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
 	state := &SchedulingContextState{
@@ -204,12 +215,33 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
 		PrefixHashes:       hashes,
 		PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
 	}
 
-	cycleState.Write(plugins.StateKey(p.TypedName().String()), state)
+	// TODO: Instead, store this in the pod attribute map to avoid global state in the plugin.
 	p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state)
+	return nil
+}
+
+// Score returns the scoring result for the given list of pods based on context.
+func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
+	// TODO(rahulgurnani): Remove duplication with PrepareRequestData after testing.
+	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
+	if err != nil {
+		// This should not happen, but if it does, recalculate the state.
+		// In unit tests it doesn't, as PrepareRequestData is always called before Score.
+		// TODO: When the prefix plugin is split into separate score and pre-request plugins,
+		// remove this recalculation.
+		log.FromContext(ctx).Error(err, "failed to read prefix plugin state, recalculating")
+		hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
+		state = &SchedulingContextState{
+			PrefixHashes:       hashes,
+			PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
+		}
+		p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state)
+	}
+	// TODO(rahulgurnani): Clean up the cycleState write once callers migrate; llm-d-scheduler presently relies on cycleState.
+	cycleState.Write(plugins.StateKey(p.TypedName().String()), state)
+
 	log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "cached-servers", state.PrefixCacheServers, "hashes", state.PrefixHashes)
 
 	// calculate the scores of pods
 	scores := make(map[types.Pod]float64, len(pods))
 	total := len(state.PrefixHashes)
 	podScoreFunc := func(pod types.Pod) float64 {
 		if total == 0 {
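With the scorer split this way, callers run PrepareRequestData before Score, exactly as the updated tests and benchmarks below do, and downstream plugins can declare a dependency on the produced state. A sketch in the prefix package; runPrefix and cacheAwareConsumer are hypothetical:

```go
// runPrefix shows the expected two-step call order after this change.
func runPrefix(ctx context.Context, p *Plugin, req *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
	_ = p.PrepareRequestData(ctx, req, pods) // produce per-request state first
	return p.Score(ctx, types.NewCycleState(), req, pods)
}

// cacheAwareConsumer declares a dependency on the state this plugin produces.
type cacheAwareConsumer struct{}

func (cacheAwareConsumer) Consumes() map[string]any {
	return map[string]any{PrefixCacheMatchKey: &SchedulingContextState{}}
}
```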
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
index 59a09db52..7a27dce7f 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
@@ -55,6 +55,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
 			},
 		},
 	}
+	plugin.PrepareRequestData(context.Background(), req1, pods)
 	scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
 	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
 	assert.NoError(t, err)
@@ -87,6 +88,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
 			},
 		},
 	}
+	plugin.PrepareRequestData(context.Background(), req2, pods)
 	scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
 	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
 	assert.NoError(t, err)
@@ -118,6 +120,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
 			},
 		},
 	}
+	plugin.PrepareRequestData(context.Background(), req3, pods)
 	scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods)
 	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String()))
 	assert.NoError(t, err)
@@ -148,6 +151,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
 			},
 		},
 	}
+	plugin.PrepareRequestData(context.Background(), req4, pods)
 	scores = plugin.Score(context.Background(), types.NewCycleState(), req4, pods)
 	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, plugins.StateKey(plugin.TypedName().String()))
 	assert.NoError(t, err)
@@ -178,6 +182,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
 			},
 		},
 	}
+	plugin.PrepareRequestData(context.Background(), req5, pods)
 	scores = plugin.Score(context.Background(), types.NewCycleState(), req5, pods)
 	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, plugins.StateKey(plugin.TypedName().String()))
 	assert.NoError(t, err)
@@ -223,6 +228,7 @@ func TestPrefixPluginChatCompletions(t *testing.T) {
 			},
 		},
 	}
+	plugin.PrepareRequestData(context.Background(), req1, pods)
 	scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
 	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
 	assert.NoError(t, err)
@@ -258,6 +264,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 			},
 		},
 	}
+	plugin.PrepareRequestData(context.Background(), req1, pods)
 	scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
 	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
 	assert.NoError(t, err)
@@ -293,6 +300,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 			},
 		},
 	}
+	plugin.PrepareRequestData(context.Background(), req2, pods)
 	scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
 	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
 	assert.NoError(t, err)
@@ -328,6 +336,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 			},
 		},
 	}
+	plugin.PrepareRequestData(context.Background(), req3, pods)
 	scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods)
 	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String()))
 	assert.NoError(t, err)
@@ -387,6 +396,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
 	}
 	b.ResetTimer()
 
+	plugin.PrepareRequestData(context.Background(), req, pods)
 	// Benchmark the scoring operation
 	scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
 	_ = scores // Use the result to prevent optimization
@@ -468,8 +478,9 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
 	}
 	b.ResetTimer()
 
-	for i := 0; i < b.N; i++ {
+	for b.Loop() {
 		// Benchmark the scoring operation
+		plugin.PrepareRequestData(context.Background(), req, pods)
 		scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
 		_ = scores // Use the result to prevent optimization
diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go
index 6f9bec8ad..8e0553fae 100644
--- a/pkg/epp/scheduling/types/types.go
+++ b/pkg/epp/scheduling/types/types.go
@@ -24,6 +24,7 @@ import (
 
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
 )
 
 const nilString = "<nil>"
@@ -191,6 +192,9 @@ type Pod interface {
 	GetPod() *backend.Pod
 	GetMetrics() *backendmetrics.MetricsState
 	String() string
+	Get(string) (datalayer.Cloneable, bool)
+	Put(string, datalayer.Cloneable)
+	Keys() []string
 }
 
 type ScoredPod struct {
@@ -217,6 +221,7 @@ func (pm *PodMetrics) GetMetrics() *backendmetrics.MetricsState {
 type PodMetrics struct {
 	*backend.Pod
 	*backendmetrics.MetricsState
+	datalayer.AttributeMap
 }
 
 // ProfileRunResult captures the profile run result.
@@ -229,3 +234,8 @@ type SchedulingResult struct {
 	ProfileResults     map[string]*ProfileRunResult
 	PrimaryProfileName string
 }
+
+// Cloneable types support cloning of the value.
+type Cloneable interface {
+	Clone() Cloneable
+}
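With the Pod interface now embedding the attribute accessors, any scheduling-layer plugin can attach and read per-request, per-pod data. A self-contained sketch; the package, the tag type, and the key are hypothetical:

```go
package example

import (
	"fmt"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// tag is a minimal datalayer.Cloneable value.
type tag struct{ s string }

func (t tag) Clone() datalayer.Cloneable { return t }

// roundTrip attaches per-request data to a pod snapshot and reads it back;
// every types.Pod supports this via the embedded AttributeMap.
func roundTrip(pod types.Pod) {
	pod.Put("example/tag", tag{s: "prefix-hot"})
	if v, ok := pod.Get("example/tag"); ok {
		fmt.Println(v.(tag).s) // prints "prefix-hot"
	}
}
```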