8 changes: 6 additions & 2 deletions pkg/epp/backend/metrics/fake.go
@@ -32,8 +32,9 @@ import (

// FakePodMetrics is an implementation of PodMetrics that doesn't run the async refresh loop.
type FakePodMetrics struct {
Pod *backend.Pod
Metrics *MetricsState
Pod *backend.Pod
Metrics *MetricsState
Attributes *datalayer.Attributes
}

func (fpm *FakePodMetrics) String() string {
@@ -51,6 +52,9 @@ func (fpm *FakePodMetrics) GetMetrics() *MetricsState {
func (fpm *FakePodMetrics) UpdatePod(pod *datalayer.PodInfo) {
fpm.Pod = pod
}
func (fpm *FakePodMetrics) GetAttributes() *datalayer.Attributes {
return fpm.Attributes
}

func (*FakePodMetrics) Put(string, datalayer.Cloneable) {}
func (*FakePodMetrics) Get(string) (datalayer.Cloneable, bool) { return nil, false }
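For context, a minimal sketch (helper name hypothetical, assumed to live in the same metrics package as the fake) of how a test could populate the new field so that GetAttributes() returns a non-nil map:

// newFakeWithAttributes is a hypothetical test helper: it wires the new
// Attributes field so GetAttributes() never returns nil.
func newFakeWithAttributes(pod *backend.Pod, m *MetricsState) *FakePodMetrics {
	return &FakePodMetrics{
		Pod:        pod,
		Metrics:    m,
		Attributes: datalayer.NewAttributes(), // empty attribute map, same constructor used in director.go
	}
}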
3 changes: 3 additions & 0 deletions pkg/epp/backend/metrics/pod_metrics.go
@@ -126,6 +126,9 @@ func (pm *podMetrics) stopRefreshLoop() {
func (*podMetrics) Put(string, datalayer.Cloneable) {}
func (*podMetrics) Get(string) (datalayer.Cloneable, bool) { return nil, false }
func (*podMetrics) Keys() []string { return nil }
func (*podMetrics) GetAttributes() *datalayer.Attributes {
return nil
}

func (pm *podMetrics) UpdateMetrics(updated *MetricsState) {
updated.UpdateTime = time.Now()
5 changes: 5 additions & 0 deletions pkg/epp/datalayer/endpoint.go
@@ -25,6 +25,7 @@ import (
type EndpointPodState interface {
GetPod() *PodInfo
UpdatePod(*PodInfo)
GetAttributes() *Attributes
}

// EndpointMetricsState allows management of the Metrics related attributes.
@@ -89,6 +90,10 @@ func (srv *ModelServer) Keys() []string {
return srv.attributes.Keys()
}

func (srv *ModelServer) GetAttributes() *Attributes {
return srv.attributes
}

func (srv *ModelServer) Clone() *ModelServer {
clone := &ModelServer{
attributes: srv.attributes.Clone(),
144 changes: 123 additions & 21 deletions pkg/epp/requestcontrol/director.go
@@ -32,6 +32,7 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
@@ -41,6 +42,11 @@ import (
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
)

const (
prepareDataTimeout = 200 * time.Millisecond
prepareDataMaxRetries = 3
)
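With these defaults, each PrepareData plugin gets one initial attempt plus three retries, and each attempt is bounded by the 200ms timeout, so an unresponsive plugin can delay a request by roughly 800ms before data preparation is abandoned (see prepareDataWithRetriesAndTimeout below).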

// Datastore defines the interface required by the Director.
type Datastore interface {
PoolGet() (*v1.InferencePool, error)
@@ -89,16 +95,28 @@ type Director struct {
defaultPriority int
}

// HandleRequest orchestrates the request lifecycle.
// It always returns the requestContext even in the error case, as the request context is used in error handling.
func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx)
// getInferenceObjective fetches the InferenceObjective for reqCtx.ObjectiveKey from the datastore, falling back to a default objective when none is found.
func (d *Director) getInferenceObjective(ctx context.Context, reqCtx *handlers.RequestContext) *v1alpha2.InferenceObjective {
infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
if infObjective == nil {
log.FromContext(ctx).V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)
infObjective = &v1alpha2.InferenceObjective{
Spec: v1alpha2.InferenceObjectiveSpec{
Priority: &d.defaultPriority,
},
}
} else if infObjective.Spec.Priority == nil {
// Default to 0 if not specified.
infObjective.Spec.Priority = &d.defaultPriority
}
return infObjective
}

// Parse Request, Resolve Target Models, and Determine Parameters
// resolveTargetModel is a helper that updates reqCtx with the target model resolved from the request body.
func (d *Director) resolveTargetModel(reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
requestBodyMap := reqCtx.Request.Body
var ok bool
reqCtx.IncomingModelName, ok = requestBodyMap["model"].(string)

if !ok {
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request body"}
}
@@ -107,24 +125,28 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
reqCtx.TargetModelName = reqCtx.IncomingModelName
}
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
return reqCtx, nil
}

// HandleRequest orchestrates the request lifecycle.
// It always returns the requestContext even in the error case, as the request context is used in error handling.
func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx)

// Resolve target model and update req context.
reqCtx, err := d.resolveTargetModel(reqCtx)
if err != nil {
return reqCtx, err
}

// Parse request body.
requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body)
if err != nil {
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
}

infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
if infObjective == nil {
logger.V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)
infObjective = &v1alpha2.InferenceObjective{
Spec: v1alpha2.InferenceObjectiveSpec{
Priority: &d.defaultPriority,
},
}
} else if infObjective.Spec.Priority == nil {
// Default to 0 if not specified.
infObjective.Spec.Priority = &d.defaultPriority
}
// Resolve the inference objective.
infObjective := d.getInferenceObjective(ctx, reqCtx)

// Prepare LLMRequest (needed for both saturation detection and Scheduler)
reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
@@ -144,13 +166,24 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
if len(candidatePods) == 0 {
return reqCtx, errutil.Error{Code: errutil.ServiceUnavailable, Msg: "failed to find candidate pods for serving the request"}
}

if err := d.admissionController.Admit(ctx, reqCtx, candidatePods, *infObjective.Spec.Priority); err != nil {
logger.V(logutil.DEFAULT).Info("Request rejected by admission control", "error", err)
return reqCtx, err
}
snapshotOfCandidatePods := d.toSchedulerPodMetrics(candidatePods)

// Prepare per-request data by running PrepareData plugins.
if d.runPrepareDataPlugins(ctx, reqCtx.SchedulingRequest, snapshotOfCandidatePods) != nil {
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "failed to prepare request data"}
}

// Run admit request plugins
if !d.withAdmissionPlugins(ctx, reqCtx.SchedulingRequest, snapshotOfCandidatePods) {
logger.V(logutil.DEFAULT).Info("Request cannot be admitted")
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "request cannot be admitted"}
}

result, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, d.toSchedulerPodMetrics(candidatePods))
result, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, snapshotOfCandidatePods)
if err != nil {
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
}
@@ -244,7 +277,11 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC
func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []schedulingtypes.Pod {
pm := make([]schedulingtypes.Pod, len(pods))
for i, pod := range pods {
pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetPod().Clone(), MetricsState: pod.GetMetrics().Clone()}
if pod.GetAttributes() != nil {
pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetPod().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: pod.GetAttributes().Clone()}
} else {
pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetPod().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: datalayer.NewAttributes()}
}
}

return pm
@@ -315,6 +352,71 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
}
}

// executePlugins executes PrepareDataPlugins sequentially.
// TODO: Change to DAG execution in the following PRs.
func (d *Director) executePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod, plugins []PrepareDataPlugin) error {
for _, plugin := range plugins {
err := prepareDataWithRetriesAndTimeout(plugin, ctx, request, pods)
if err != nil {
return err
}
}
return nil
}

// prepareDataWithRetriesAndTimeout executes a single plugin's PrepareRequestData with a per-attempt timeout and up to prepareDataMaxRetries retries.
func prepareDataWithRetriesAndTimeout(plugin PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
currentTimeout := prepareDataTimeout
for i := 0; i <= prepareDataMaxRetries; i++ {
errCh := make(chan error, 1)
go func() {
errCh <- plugin.PrepareRequestData(ctx, request, pods)
}()

select {
case <-ctx.Done():
return ctx.Err()
case err := <-errCh:
if err != nil {
if i == prepareDataMaxRetries {
return fmt.Errorf("PrepareData plugin %s failed after %d retries: %w", plugin.TypedName().String(), prepareDataMaxRetries, err)
}
log.FromContext(ctx).V(logutil.DEBUG).Info("PrepareData plugin failed, retrying...", "plugin", plugin.TypedName(), "retry", i+1, "error", err)
continue
}
return nil // Success
case <-time.After(currentTimeout):
log.FromContext(ctx).V(logutil.DEBUG).Info("PrepareData plugin timed out, retrying...", "plugin", plugin.TypedName(), "retry", i+1, "timeout", currentTimeout)
if i == prepareDataMaxRetries {
return fmt.Errorf("PrepareData plugin %s failed after %d retries", plugin.TypedName().String(), prepareDataMaxRetries)
}
}
}
return nil
}

func (d *Director) runPrepareDataPlugins(ctx context.Context,
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
err := d.executePlugins(ctx, request, pods, d.requestControlPlugins.prepareDataPlugins)
if err != nil {
log.FromContext(ctx).Error(err, "failed to execute PrepareData plugins")
return err
}

return nil
}
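
For illustration, a minimal sketch of a data-preparation plugin as the director would drive it. The method set (TypedName plus PrepareRequestData) is inferred from the calls above; the plugins.TypedName return type, the plugin name, and the omitted imports are assumptions:

// promptStatsPlugin is a hypothetical PrepareDataPlugin.
type promptStatsPlugin struct{}

func (p *promptStatsPlugin) TypedName() plugins.TypedName {
	return plugins.TypedName{Type: "prepare-data", Name: "prompt-stats"}
}

// PrepareRequestData should tolerate being re-invoked: after a timeout the
// director retries while the previous attempt's goroutine may still be running.
func (p *promptStatsPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
	// Derive or fetch whatever per-request data downstream scorers need here.
	// Returning a non-nil error (or blocking past prepareDataTimeout) triggers
	// the retry handling in prepareDataWithRetriesAndTimeout above.
	return nil
}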

func (d *Director) withAdmissionPlugins(ctx context.Context,
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) bool {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.requestControlPlugins.admissionPlugins {
loggerDebug.Info("Running AdmitRequest plugin", "plugin", plugin.TypedName())
if denyReason := plugin.AdmitRequest(ctx, request, pods); denyReason != nil {
loggerDebug.Info("AdmitRequest plugin denied the request", "plugin", plugin.TypedName(), "reason", denyReason.Error())
return false
}
loggerDebug.Info("Completed running AdmitRequest plugin successfully", "plugin", plugin.TypedName())
}
return true
}
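
Similarly, a minimal sketch of an admission plugin; the deny reason is assumed to be a plain error, matching the denyReason.Error() call above, and the type name and omitted imports are illustrative only:

// emptyCandidateSetPlugin is a hypothetical admission plugin that denies a
// request when there are no candidate endpoints to schedule onto.
type emptyCandidateSetPlugin struct{}

func (p *emptyCandidateSetPlugin) TypedName() plugins.TypedName {
	return plugins.TypedName{Type: "admission", Name: "empty-candidate-set"}
}

func (p *emptyCandidateSetPlugin) AdmitRequest(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
	if len(pods) == 0 {
		return errors.New("no candidate endpoints available")
	}
	return nil // nil means the request is admitted
}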

func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.requestControlPlugins.responseReceivedPlugins {