diff --git a/Dockerfile b/Dockerfile index d2fd2300b..70038a077 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,6 +24,7 @@ COPY internal ./internal COPY apix ./apix COPY api ./api COPY version ./version +COPY sidecars ./sidecars WORKDIR /src/cmd/epp RUN go build -ldflags="-X sigs.k8s.io/gateway-api-inference-extension/version.CommitSHA=${COMMIT_SHA} -X sigs.k8s.io/gateway-api-inference-extension/version.BuildRef=${BUILD_REF}" -o /epp diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index da81fdf46..79d90f24c 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -19,6 +19,7 @@ package runner import ( "context" "crypto/tls" + "encoding/json" "errors" "flag" "fmt" @@ -61,6 +62,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/scorer" @@ -68,6 +70,7 @@ import ( runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" "sigs.k8s.io/gateway-api-inference-extension/version" ) @@ -108,6 +111,7 @@ var ( "then a self-signed certificate is used.") // metric flags totalQueuedRequestsMetric = flag.String("total-queued-requests-metric", runserver.DefaultTotalQueuedRequestsMetric, "Prometheus metric for the number of queued requests.") + totalRunningRequestsMetric = flag.String("total-running-requests-metric", runserver.DefaultTotalRunningRequestsMetric, "Prometheus metric for the number of running requests.") kvCacheUsagePercentageMetric = flag.String("kv-cache-usage-percentage-metric", runserver.DefaultKvCacheUsagePercentageMetric, "Prometheus metric for the fraction of KV-cache blocks currently in use (from 0 to 1).") // LoRA metrics loraInfoMetric = flag.String("lora-info-metric", runserver.DefaultLoraInfoMetric, "Prometheus metric for the LoRA info metrics (must be in vLLM label format).") @@ -127,7 +131,10 @@ var ( modelServerMetricsScheme = flag.String("model-server-metrics-scheme", "http", "Scheme to scrape metrics from pods") modelServerMetricsHttpsInsecureSkipVerify = flag.Bool("model-server-metrics-https-insecure-skip-verify", true, "When using 'https' scheme for 'model-server-metrics-scheme', configure 'InsecureSkipVerify' (default to true)") haEnableLeaderElection = flag.Bool("ha-enable-leader-election", false, "Enables leader election for high availability. 
When enabled, readiness probes will only pass on the leader.")
-	tracing                    = flag.Bool("tracing", true, "Enables emitting traces")
+
+	// Latency Predictor Flag
+	enableLatencyPredictor = flag.Bool("enable-latency-predictor", false, "Enable the regression-based latency predictor and scheduler scorer.")
+	tracing                = flag.Bool("tracing", true, "Enables emitting traces")

	setupLog = ctrl.Log.WithName("setup")
)
@@ -297,9 +304,29 @@ func (r *Runner) Run(ctx context.Context) error {
		runtime.SetBlockProfileRate(1)
	}

-	err = r.parsePluginsConfiguration(ctx, datastore)
+	// ===================================================================
+	// == Latency Predictor Integration
+	// ===================================================================
+	var predictor latencypredictor.PredictorInterface // Use the interface type
+	if *enableLatencyPredictor {
+		setupLog.Info("Latency predictor is enabled. Initializing...")
+		predictor = latencypredictor.New(latencypredictor.ConfigFromEnv(), ctrl.Log.WithName("latency-predictor"))
+
+		// The predictor runnable needs the concrete type, so type-assert the interface back.
+		concretePredictor := predictor.(*latencypredictor.Predictor)
+		if err := mgr.Add(runnable.NoLeaderElection(&predictorRunnable{predictor: concretePredictor})); err != nil {
+			setupLog.Error(err, "Failed to register latency predictor runnable")
+			return err
+		}
+	} else {
+		setupLog.Info("Latency predictor is disabled.")
+		predictor = nil // Keep the interface itself nil so callers can use a plain nil check.
+	}
+	// ===================================================================
+
+	err = r.parsePluginsConfiguration(ctx, predictor, datastore)
	if err != nil {
-		setupLog.Error(err, "Failed to parse plugins configuration")
+		setupLog.Error(err, "Failed to parse the configuration")
		return err
	}
@@ -368,6 +395,7 @@ func (r *Runner) Run(ctx context.Context) error {
		Director:                   director,
		SaturationDetector:         saturationDetector,
		UseExperimentalDatalayerV2: useDatalayerV2, // pluggable data layer feature flag
+		LatencyPredictor:           predictor,
	}
	if err := serverRunner.SetupWithManager(ctx, mgr); err != nil {
		setupLog.Error(err, "Failed to setup EPP controllers")
@@ -410,7 +438,14 @@ func (r *Runner) registerInTreePlugins() {
	plugins.Register(testfilter.HeaderBasedTestingFilterType, testfilter.HeaderBasedTestingFilterFactory)
}

-func (r *Runner) parsePluginsConfiguration(ctx context.Context, ds datastore.Datastore) error {
+func (r *Runner) registerLatencyPredictorPlugins(predictor latencypredictor.PredictorInterface) {
+	plugins.Register(slo_aware_router.SLOAwareRouterPluginType, func(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+		return slo_aware_router.NewSLOAwareRouter(predictor, slo_aware_router.HeadroomSelectionStrategy).WithName(name), nil
+	})
+	plugins.Register(profile.SLOAwareProfileHandlerType, profile.SLOAwareProfileHandlerFactory)
+}
+
+func (r *Runner) parsePluginsConfiguration(ctx context.Context, predictor latencypredictor.PredictorInterface, ds datastore.Datastore) error {
	if *configText == "" && *configFile == "" {
		return nil // configuring through code, not through file
	}
@@ -429,6 +464,12 @@ func (r *Runner) parsePluginsConfiguration(ctx context.Context, ds datastore.Dat
	}
	r.registerInTreePlugins()
+	// If the latency predictor is enabled and non-nil, register the
+	// predictor-backed plugins (the SLO-aware router and its profile handler).
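+	// Once registered, a plugins configuration file can reference these
+	// plugins by type. Illustrative sketch only; the exact type strings are
+	// whatever SLOAwareRouterPluginType and SLOAwareProfileHandlerType
+	// resolve to:
+	//
+	//   plugins:
+	//   - type: slo-aware-router
+	//   - type: slo-aware-profile-handler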
+ if *enableLatencyPredictor && predictor != nil { + setupLog.Info("Registering latency predictor plugins") + r.registerLatencyPredictorPlugins(predictor) + } handle := plugins.NewEppHandle(ctx, ds.PodList) config, err := loader.LoadConfig(configBytes, handle, logger) @@ -459,6 +500,7 @@ func (r *Runner) setupMetricsCollection(setupLog logr.Logger, useExperimentalDat func setupMetricsV1(setupLog logr.Logger) (datalayer.EndpointFactory, error) { mapping, err := backendmetrics.NewMetricMapping( *totalQueuedRequestsMetric, + *totalRunningRequestsMetric, *kvCacheUsagePercentageMetric, *loraInfoMetric, *cacheInfoMetric, @@ -502,6 +544,7 @@ func setupDatalayer() (datalayer.EndpointFactory, error) { *modelServerMetricsHttpsInsecureSkipVerify, nil) extractor, err := dlmetrics.NewExtractor(*totalQueuedRequestsMetric, + *totalRunningRequestsMetric, *kvCacheUsagePercentageMetric, *loraInfoMetric, *cacheInfoMetric) @@ -613,3 +656,21 @@ func setupPprofHandlers(mgr ctrl.Manager) error { } return nil } + +// =================================================================== +// == Latency Predictor Plugin and Helpers +// =================================================================== + +// predictorRunnable implements controller-runtime's Runnable interface to manage the predictor's lifecycle. +type predictorRunnable struct { + predictor *latencypredictor.Predictor +} + +func (p *predictorRunnable) Start(ctx context.Context) error { + setupLog.Info("Starting latency predictor...") + p.predictor.Start(ctx) + <-ctx.Done() + setupLog.Info("Stopping latency predictor...") + p.predictor.Stop() + return nil +} diff --git a/go.mod b/go.mod index ab54e64f3..c32d4b5f8 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/elastic/crd-ref-docs v0.2.0 github.com/envoyproxy/go-control-plane/envoy v1.35.0 github.com/go-logr/logr v1.4.3 + github.com/go-logr/zapr v1.3.0 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/hashicorp/golang-lru/v2 v2.0.7 @@ -61,7 +62,6 @@ require ( github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/go-logr/zapr v1.3.0 // indirect github.com/go-openapi/jsonpointer v0.21.2 // indirect github.com/go-openapi/jsonreference v0.21.0 // indirect github.com/go-openapi/swag v0.23.1 // indirect diff --git a/pkg/epp/backend/metrics/metrics.go b/pkg/epp/backend/metrics/metrics.go index 5d2a85e96..d2d9b60f9 100644 --- a/pkg/epp/backend/metrics/metrics.go +++ b/pkg/epp/backend/metrics/metrics.go @@ -97,6 +97,15 @@ func (p *PodMetricsClientImpl) promToPodMetrics( } } + if p.MetricMapping.TotalRunningRequests != nil { + running, err := p.getMetric(metricFamilies, *p.MetricMapping.TotalRunningRequests) + if err == nil { + updated.RunningQueueSize = int(running.GetGauge().GetValue()) + } else { + errs = multierr.Append(errs, err) + } + } + if p.MetricMapping.KVCacheUtilization != nil { usage, err := p.getMetric(metricFamilies, *p.MetricMapping.KVCacheUtilization) if err == nil { diff --git a/pkg/epp/backend/metrics/metrics_spec.go b/pkg/epp/backend/metrics/metrics_spec.go index 7407f4ed7..b3c26db2c 100644 --- a/pkg/epp/backend/metrics/metrics_spec.go +++ b/pkg/epp/backend/metrics/metrics_spec.go @@ -29,10 +29,11 @@ type MetricSpec struct { // MetricMapping holds named MetricSpecs. 
type MetricMapping struct { - TotalQueuedRequests *MetricSpec - KVCacheUtilization *MetricSpec - LoraRequestInfo *MetricSpec - CacheConfigInfo *MetricSpec + TotalQueuedRequests *MetricSpec + TotalRunningRequests *MetricSpec + KVCacheUtilization *MetricSpec + LoraRequestInfo *MetricSpec + CacheConfigInfo *MetricSpec } // stringToMetricSpec converts a string to a MetricSpec. @@ -94,11 +95,15 @@ func stringToMetricSpec(specStr string) (*MetricSpec, error) { } // NewMetricMapping creates a MetricMapping from string values. -func NewMetricMapping(queuedStr, kvUsageStr, loraReqInfoStr, cacheInfoMetric string) (*MetricMapping, error) { +func NewMetricMapping(queuedStr, runningStr, kvUsageStr, loraReqInfoStr, cacheInfoMetric string) (*MetricMapping, error) { queuedSpec, err := stringToMetricSpec(queuedStr) if err != nil { return nil, fmt.Errorf("error parsing WaitingRequests: %w", err) } + runningSpec, err := stringToMetricSpec(runningStr) + if err != nil { + return nil, fmt.Errorf("error parsing RunningRequests: %w", err) + } kvUsageSpec, err := stringToMetricSpec(kvUsageStr) if err != nil { return nil, fmt.Errorf("error parsing KVCacheUsage: %w", err) @@ -114,10 +119,11 @@ func NewMetricMapping(queuedStr, kvUsageStr, loraReqInfoStr, cacheInfoMetric str } mapping := &MetricMapping{ - TotalQueuedRequests: queuedSpec, - KVCacheUtilization: kvUsageSpec, - LoraRequestInfo: loraReqInfoSpec, - CacheConfigInfo: cacheInfoSpec, + TotalQueuedRequests: queuedSpec, + TotalRunningRequests: runningSpec, + KVCacheUtilization: kvUsageSpec, + LoraRequestInfo: loraReqInfoSpec, + CacheConfigInfo: cacheInfoSpec, } return mapping, nil diff --git a/pkg/epp/datalayer/metrics/extractor.go b/pkg/epp/datalayer/metrics/extractor.go index f6142c494..95baec1b9 100644 --- a/pkg/epp/datalayer/metrics/extractor.go +++ b/pkg/epp/datalayer/metrics/extractor.go @@ -64,8 +64,8 @@ func Produces() map[string]any { // configured with the given metrics' specifications. // These are mandatory metrics per the MSP specification, and are used // as the basis for the built-in scheduling plugins. -func NewExtractor(queueSpec, kvusageSpec, loraSpec, cacheInfoSpec string) (*Extractor, error) { - mapping, err := NewMapping(queueSpec, kvusageSpec, loraSpec, cacheInfoSpec) +func NewExtractor(queueSpec, runningSpec, kvusageSpec, loraSpec, cacheInfoSpec string) (*Extractor, error) { + mapping, err := NewMapping(queueSpec, runningSpec, kvusageSpec, loraSpec, cacheInfoSpec) if err != nil { return nil, fmt.Errorf("failed to create extractor metrics Mapping - %w", err) } @@ -107,6 +107,15 @@ func (ext *Extractor) Extract(ctx context.Context, data any, ep datalayer.Endpoi } } + if spec := ext.mapping.TotalRunningRequests; spec != nil { // extract running requests + if metric, err := spec.getLatestMetric(families); err != nil { + errs = append(errs, err) + } else { + clone.RunningQueueSize = int(extractValue(metric)) + updated = true + } + } + if spec := ext.mapping.KVCacheUtilization; spec != nil { // extract KV cache usage if metric, err := spec.getLatestMetric(families); err != nil { errs = append(errs, err) diff --git a/pkg/epp/datalayer/metrics/mapping.go b/pkg/epp/datalayer/metrics/mapping.go index fab6cf75f..7b1fed9c1 100644 --- a/pkg/epp/datalayer/metrics/mapping.go +++ b/pkg/epp/datalayer/metrics/mapping.go @@ -23,20 +23,25 @@ import ( // Mapping holds specifications for the well-known metrics defined // in the Model Server Protocol. 
type Mapping struct { - TotalQueuedRequests *Spec - KVCacheUtilization *Spec - LoraRequestInfo *LoRASpec - CacheInfo *Spec + TotalQueuedRequests *Spec + TotalRunningRequests *Spec + KVCacheUtilization *Spec + LoraRequestInfo *LoRASpec + CacheInfo *Spec } // NewMapping creates a metrics.Mapping from the input specification strings. -func NewMapping(queue, kvusage, lora, cacheInfo string) (*Mapping, error) { +func NewMapping(queue, running, kvusage, lora, cacheInfo string) (*Mapping, error) { var errs []error queueSpec, err := parseStringToSpec(queue) if err != nil { errs = append(errs, err) } + runningSpec, err := parseStringToSpec(running) + if err != nil { + errs = append(errs, err) + } kvusageSpec, err := parseStringToSpec(kvusage) if err != nil { errs = append(errs, err) @@ -55,9 +60,10 @@ func NewMapping(queue, kvusage, lora, cacheInfo string) (*Mapping, error) { return nil, errors.Join(errs...) } return &Mapping{ - TotalQueuedRequests: queueSpec, - KVCacheUtilization: kvusageSpec, - LoraRequestInfo: loraSpec, - CacheInfo: cacheInfoSpec, + TotalQueuedRequests: queueSpec, + TotalRunningRequests: runningSpec, + KVCacheUtilization: kvusageSpec, + LoraRequestInfo: loraSpec, + CacheInfo: cacheInfoSpec, }, nil } diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index dade69469..5dcd0f4a0 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -63,7 +63,7 @@ type Datastore interface { // PodList lists pods matching the given predicate. PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool - PodDelete(podNAme string) + PodDelete(podName string) // Clears the store state, happens when the pool gets deleted. Clear() diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go index 59c8976cd..af422f2b5 100644 --- a/pkg/epp/metrics/metrics.go +++ b/pkg/epp/metrics/metrics.go @@ -63,6 +63,193 @@ var ( []string{"model_name", "target_model_name", "error_code"}, ) + requestTTFT = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_seconds", + Help: metricsutil.HelpMsgWithStability("Inference model TTFT distribution in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.005, 0.025, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 2, 3, + 4, 5, 6, 8, 10, 15, 20, 30, 45, 60, 120, 180, 240, 300, 360, 480, 600, 900, 1200, 1800, 2700, 3600, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestTTFTGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_seconds_gauge", + Help: metricsutil.HelpMsgWithStability("Inference model TTFT gauge in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestPredictedTTFT = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_predicted_ttft_seconds", + Help: metricsutil.HelpMsgWithStability("Inference model Predicted TTFT distribution in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.005, 0.025, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 2, 3, + 4, 5, 6, 8, 10, 15, 20, 30, 45, 60, 120, 180, 240, 300, 360, 480, 600, 900, 1200, 1800, 2700, 3600, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestPredictedTTFTGauge = prometheus.NewGaugeVec( + 
prometheus.GaugeOpts{
+			Subsystem: InferenceObjectiveComponent,
+			Name:      "request_predicted_ttft_seconds_gauge",
+			Help:      metricsutil.HelpMsgWithStability("Inference model Predicted TTFT gauge in seconds for each model and target model.", compbasemetrics.ALPHA),
+		},
+		[]string{"model_name", "target_model_name"},
+	)
+
+	// New metrics for TTFT prediction duration
+	requestTTFTPredictionDuration = prometheus.NewHistogramVec(
+		prometheus.HistogramOpts{
+			Subsystem: InferenceObjectiveComponent,
+			Name:      "request_ttft_prediction_duration_seconds",
+			Help:      metricsutil.HelpMsgWithStability("Duration taken to generate TTFT predictions in seconds for each model and target model.", compbasemetrics.ALPHA),
+			Buckets: []float64{
+				0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0,
+			},
+		},
+		[]string{"model_name", "target_model_name"},
+	)
+
+	requestTTFTPredictionDurationGauge = prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Subsystem: InferenceObjectiveComponent,
+			Name:      "request_ttft_prediction_duration_seconds_gauge",
+			Help:      metricsutil.HelpMsgWithStability("Latest duration taken to generate TTFT predictions in seconds for each model and target model.", compbasemetrics.ALPHA),
+		},
+		[]string{"model_name", "target_model_name"},
+	)
+
+	requestTPOT = prometheus.NewHistogramVec(
+		prometheus.HistogramOpts{
+			Subsystem: InferenceObjectiveComponent,
+			Name:      "request_tpot_seconds",
+			Help:      metricsutil.HelpMsgWithStability("Inference model TPOT distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
+			Buckets: []float64{
+				0.0005, 0.0025, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.125, 0.15, 0.2, 0.3,
+				0.4, 0.5, 0.6, 0.8, 1, 1.5, 2, 3, 4.5, 6, 12, 18, 24, 30, 36, 48, 60, 90, 120, 180, 270, 360,
+			},
+		},
+		[]string{"model_name", "target_model_name"},
+	)
+
+	requestTPOTGauge = prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Subsystem: InferenceObjectiveComponent,
+			Name:      "request_tpot_seconds_gauge",
+			Help:      metricsutil.HelpMsgWithStability("Inference model TPOT gauge in seconds for each model and target model.", compbasemetrics.ALPHA),
+		},
+		[]string{"model_name", "target_model_name"},
+	)
+	requestPredictedTPOT = prometheus.NewHistogramVec(
+		prometheus.HistogramOpts{
+			Subsystem: InferenceObjectiveComponent,
+			Name:      "request_predicted_tpot_seconds",
+			Help:      metricsutil.HelpMsgWithStability("Inference model Predicted TPOT distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
+			Buckets: []float64{
+				0.0005, 0.0025, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.125, 0.15, 0.2, 0.3,
+				0.4, 0.5, 0.6, 0.8, 1, 1.5, 2, 3, 4.5, 6, 12, 18, 24, 30, 36, 48, 60, 90, 120, 180, 270, 360,
+			},
+		},
+		[]string{"model_name", "target_model_name"},
+	)
+
+	requestPredictedTPOTGauge = prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Subsystem: InferenceObjectiveComponent,
+			Name:      "request_predicted_tpot_seconds_gauge",
+			Help:      metricsutil.HelpMsgWithStability("Inference model Predicted TPOT gauge in seconds for each model and target model.", compbasemetrics.ALPHA),
+		},
+		[]string{"model_name", "target_model_name"},
+	)
+
+	// New metrics for TPOT prediction duration
+	requestTPOTPredictionDuration = prometheus.NewHistogramVec(
+		prometheus.HistogramOpts{
+			Subsystem: InferenceObjectiveComponent,
+			Name:      "request_tpot_prediction_duration_seconds",
+			Help:      metricsutil.HelpMsgWithStability("Duration taken to generate TPOT predictions in seconds for each model and target model.", compbasemetrics.ALPHA),
+			Buckets: []float64{
+ 0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOTPredictionDurationGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_prediction_duration_seconds_gauge", + Help: metricsutil.HelpMsgWithStability("Latest duration taken to generate TPOT predictions in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + // SLO Violation Metrics + requestTTFTSLOViolation = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_slo_violation", + Help: metricsutil.HelpMsgWithStability("Boolean indicator (0 or 1) of whether the last TTFT measurement violated the SLO threshold for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestTTFTSLOViolationCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_slo_violation_total", + Help: metricsutil.HelpMsgWithStability("Counter of TTFT SLO violations for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOTSLOViolation = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_slo_violation", + Help: metricsutil.HelpMsgWithStability("Boolean indicator (0 or 1) of whether the last TPOT measurement violated the SLO threshold for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOTSLOViolationCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_slo_violation_total", + Help: metricsutil.HelpMsgWithStability("Counter of TPOT SLO violations for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + // SLO threshold gauges (for dynamic threshold management) + requestTTFTSLOThreshold = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_slo_threshold_seconds", + Help: metricsutil.HelpMsgWithStability("Current TTFT SLO threshold in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOTSLOThreshold = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_slo_threshold_seconds", + Help: metricsutil.HelpMsgWithStability("Current TPOT SLO threshold in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + requestLatencies = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Subsystem: InferenceObjectiveComponent, @@ -282,6 +469,32 @@ var registerMetrics sync.Once // Register all metrics. 
func Register(customCollectors ...prometheus.Collector) { registerMetrics.Do(func() { + metrics.Registry.MustRegister(requestTPOT) + metrics.Registry.MustRegister(requestTTFT) + + metrics.Registry.MustRegister(requestTPOTGauge) + metrics.Registry.MustRegister(requestTTFTGauge) + + metrics.Registry.MustRegister(requestPredictedTPOT) + metrics.Registry.MustRegister(requestPredictedTTFT) + + metrics.Registry.MustRegister(requestPredictedTPOTGauge) + metrics.Registry.MustRegister(requestPredictedTTFTGauge) + + // Register new prediction duration metrics + metrics.Registry.MustRegister(requestTPOTPredictionDuration) + metrics.Registry.MustRegister(requestTPOTPredictionDurationGauge) + metrics.Registry.MustRegister(requestTTFTPredictionDuration) + metrics.Registry.MustRegister(requestTTFTPredictionDurationGauge) + + // Register SLO violation metrics + metrics.Registry.MustRegister(requestTTFTSLOViolation) + metrics.Registry.MustRegister(requestTTFTSLOViolationCounter) + metrics.Registry.MustRegister(requestTPOTSLOViolation) + metrics.Registry.MustRegister(requestTPOTSLOViolationCounter) + metrics.Registry.MustRegister(requestTTFTSLOThreshold) + metrics.Registry.MustRegister(requestTPOTSLOThreshold) + metrics.Registry.MustRegister(requestCounter) metrics.Registry.MustRegister(requestErrCounter) metrics.Registry.MustRegister(requestLatencies) @@ -332,6 +545,30 @@ func Reset() { PrefixCacheHitLength.Reset() flowControlRequestQueueDuration.Reset() flowControlQueueSize.Reset() + + requestTPOT.Reset() + requestTTFT.Reset() + requestTPOTGauge.Reset() + requestTTFTGauge.Reset() + + requestPredictedTPOT.Reset() + requestPredictedTTFT.Reset() + requestPredictedTPOTGauge.Reset() + requestPredictedTTFTGauge.Reset() + + // Reset new prediction duration metrics + requestTPOTPredictionDuration.Reset() + requestTPOTPredictionDurationGauge.Reset() + requestTTFTPredictionDuration.Reset() + requestTTFTPredictionDurationGauge.Reset() + + // Reset SLO violation metrics + requestTTFTSLOViolation.Reset() + requestTTFTSLOViolationCounter.Reset() + requestTPOTSLOViolation.Reset() + requestTPOTSLOViolationCounter.Reset() + requestTTFTSLOThreshold.Reset() + requestTPOTSLOThreshold.Reset() } // RecordRequstCounter records the number of requests. @@ -363,6 +600,123 @@ func RecordRequestLatencies(ctx context.Context, modelName, targetModelName stri return true } +func RecordRequestTPOT(ctx context.Context, modelName, targetModelName string, tpot float64) bool { + if tpot < 0 { + log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TPOT value must be non-negative", + "modelName", modelName, "targetModelName", targetModelName, "tpot", tpot) + return false + } + requestTPOT.WithLabelValues(modelName, targetModelName).Observe(tpot) + requestTPOTGauge.WithLabelValues(modelName, targetModelName).Set(tpot) + return true +} + +// RecordRequestTPOTWithSLO records TPOT and checks for SLO violation. +// If tpot exceeds the threshold, it records a violation (sets gauge to 1 and increments counter). +// If tpot is within limits, it sets gauge to 0. 
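+// Example with hypothetical values: given sloThreshold = 0.05, a measured
+// tpot of 0.04 leaves the violation gauge at 0, while a tpot of 0.08 sets it
+// to 1 and increments request_tpot_slo_violation_total.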
+func RecordRequestTPOTWithSLO(ctx context.Context, modelName, targetModelName string, tpot float64, sloThreshold float64) bool {
+	if tpot < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TPOT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "tpot", tpot)
+		return false
+	}
+
+	// Check for SLO violation (tpot exceeds threshold)
+	if tpot > sloThreshold {
+		requestTPOTSLOViolation.WithLabelValues(modelName, targetModelName).Set(1)
+		requestTPOTSLOViolationCounter.WithLabelValues(modelName, targetModelName).Inc()
+		log.FromContext(ctx).V(logutil.DEFAULT).Info("TPOT SLO violation detected",
+			"modelName", modelName, "targetModelName", targetModelName, "tpot", tpot, "threshold", sloThreshold)
+	} else {
+		requestTPOTSLOViolation.WithLabelValues(modelName, targetModelName).Set(0)
+	}
+
+	return true
+}
+
+// RecordRequestPredictedTPOT records the predicted TPOT of a request.
+func RecordRequestPredictedTPOT(ctx context.Context, modelName, targetModelName string, predictedTPOT float64) bool {
+	if predictedTPOT < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "Predicted TPOT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "predictedTPOT", predictedTPOT)
+		return false
+	}
+	requestPredictedTPOT.WithLabelValues(modelName, targetModelName).Observe(predictedTPOT)
+	requestPredictedTPOTGauge.WithLabelValues(modelName, targetModelName).Set(predictedTPOT)
+	return true
+}
+
+// RecordRequestTPOTPredictionDuration records the duration taken to generate TPOT predictions.
+func RecordRequestTPOTPredictionDuration(ctx context.Context, modelName, targetModelName string, duration float64) bool {
+	if duration < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TPOT prediction duration must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "duration", duration)
+		return false
+	}
+	requestTPOTPredictionDuration.WithLabelValues(modelName, targetModelName).Observe(duration)
+	requestTPOTPredictionDurationGauge.WithLabelValues(modelName, targetModelName).Set(duration)
+	return true
+}
+
+// RecordRequestTTFT records the time to first token (TTFT) of a request.
+func RecordRequestTTFT(ctx context.Context, modelName, targetModelName string, ttft float64) bool {
+	if ttft < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TTFT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "ttft", ttft)
+		return false
+	}
+	requestTTFT.WithLabelValues(modelName, targetModelName).Observe(ttft)
+	requestTTFTGauge.WithLabelValues(modelName, targetModelName).Set(ttft)
+	return true
+}
+
+// RecordRequestTTFTWithSLO records TTFT and checks for SLO violation.
+// If ttft exceeds the threshold, it records a violation (sets gauge to 1 and increments counter).
+// If ttft is within limits, it sets gauge to 0.
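+// Note: the violation gauge only reflects the most recent request per
+// (model, target model) pair, while the companion counter accumulates, so a
+// rate() over request_ttft_slo_violation_total is usually the better alerting signal.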
+func RecordRequestTTFTWithSLO(ctx context.Context, modelName, targetModelName string, ttft float64, sloThreshold float64) bool {
+	if ttft < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TTFT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "ttft", ttft)
+		return false
+	}
+
+	// Check for SLO violation (ttft exceeds threshold)
+	if ttft > sloThreshold {
+		requestTTFTSLOViolation.WithLabelValues(modelName, targetModelName).Set(1)
+		requestTTFTSLOViolationCounter.WithLabelValues(modelName, targetModelName).Inc()
+		log.FromContext(ctx).V(logutil.DEFAULT).Info("TTFT SLO violation detected",
+			"modelName", modelName, "targetModelName", targetModelName, "ttft", ttft, "threshold", sloThreshold)
+	} else {
+		requestTTFTSLOViolation.WithLabelValues(modelName, targetModelName).Set(0)
+	}
+
+	return true
+}
+
+// RecordRequestPredictedTTFT records the predicted TTFT of a request.
+func RecordRequestPredictedTTFT(ctx context.Context, modelName, targetModelName string, predictedTTFT float64) bool {
+	if predictedTTFT < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "Predicted TTFT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "predictedTTFT", predictedTTFT)
+		return false
+	}
+	requestPredictedTTFT.WithLabelValues(modelName, targetModelName).Observe(predictedTTFT)
+	requestPredictedTTFTGauge.WithLabelValues(modelName, targetModelName).Set(predictedTTFT)
+	return true
+}
+
+// RecordRequestTTFTPredictionDuration records the duration taken to generate TTFT predictions.
+func RecordRequestTTFTPredictionDuration(ctx context.Context, modelName, targetModelName string, duration float64) bool {
+	if duration < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TTFT prediction duration must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "duration", duration)
+		return false
+	}
+	requestTTFTPredictionDuration.WithLabelValues(modelName, targetModelName).Observe(duration)
+	requestTTFTPredictionDurationGauge.WithLabelValues(modelName, targetModelName).Set(duration)
+	return true
+}
+
 // RecordResponseSizes records the response sizes.
 func RecordResponseSizes(modelName, targetModelName string, size int) {
 	responseSizes.WithLabelValues(modelName, targetModelName).Observe(float64(size))
@@ -480,3 +834,15 @@ func IncFlowControlQueueSize(fairnessID, priority string) {
 func DecFlowControlQueueSize(fairnessID, priority string) {
 	flowControlQueueSize.WithLabelValues(fairnessID, priority).Dec()
 }
+
+// SetTTFTSLOThreshold sets the TTFT SLO threshold for a model.
+// This allows dynamic threshold management and makes the threshold visible in metrics.
+func SetTTFTSLOThreshold(modelName, targetModelName string, threshold float64) {
+	requestTTFTSLOThreshold.WithLabelValues(modelName, targetModelName).Set(threshold)
+}
+
+// SetTPOTSLOThreshold sets the TPOT SLO threshold for a model.
+// This allows dynamic threshold management and makes the threshold visible in metrics.
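+// Example with hypothetical values: SetTPOTSLOThreshold("llama", "llama-a", 0.05)
+// publishes a 50ms TPOT target via request_tpot_slo_threshold_seconds, which
+// dashboards can overlay on the observed request_tpot_seconds distribution.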
+func SetTPOTSLOThreshold(modelName, targetModelName string, threshold float64) {
+	requestTPOTSLOThreshold.WithLabelValues(modelName, targetModelName).Set(threshold)
+}
diff --git a/pkg/epp/metrics/metrics_test.go b/pkg/epp/metrics/metrics_test.go
index 7d4168183..754d6d294 100644
--- a/pkg/epp/metrics/metrics_test.go
+++ b/pkg/epp/metrics/metrics_test.go
@@ -46,6 +46,8 @@ const (
 	KVCacheAvgUsageMetric    = InferencePoolComponent + "_average_kv_cache_utilization"
 	QueueAvgSizeMetric       = InferencePoolComponent + "_average_queue_size"
 	PerPodQueueSizeMetrics   = InferencePoolComponent + "_per_pod_queue_size"
+	RequestTTFTSecondsMetric = InferenceObjectiveComponent + "_request_ttft_seconds"
+	RequestTPOTSecondsMetric = InferenceObjectiveComponent + "_request_tpot_seconds"
 )

 func TestMain(m *testing.M) {
diff --git a/pkg/epp/metrics/testdata/request_tpot_seconds_metric b/pkg/epp/metrics/testdata/request_tpot_seconds_metric
new file mode 100644
index 000000000..beee50271
--- /dev/null
+++ b/pkg/epp/metrics/testdata/request_tpot_seconds_metric
@@ -0,0 +1,40 @@
+# HELP inference_model_request_tpot_seconds [ALPHA] Inference model TPOT distribution in seconds for each model and target model.
+# TYPE inference_model_request_tpot_seconds histogram
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.0005"} 0
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.0025"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.005"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.01"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.02"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.04"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.06"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.08"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.1"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.125"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.15"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.2"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.3"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.4"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.5"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.6"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.8"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="1"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="1.5"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="2"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="3"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="4.5"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="6"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="12"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="18"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="24"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="30"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="36"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="48"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="60"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="90"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="120"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="180"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="270"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="360"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="+Inf"} 2
+inference_model_request_tpot_seconds_sum{model_name="m20", target_model_name="t10"} 0.161
+inference_model_request_tpot_seconds_count{model_name="m20", target_model_name="t10"} 2
\ No newline at end of file
diff --git a/pkg/epp/metrics/testdata/request_ttft_seconds_metric b/pkg/epp/metrics/testdata/request_ttft_seconds_metric
new file mode 100644
index 000000000..315490727
--- /dev/null
+++ b/pkg/epp/metrics/testdata/request_ttft_seconds_metric
@@ -0,0 +1,116 @@
+# HELP inference_model_request_ttft_seconds [ALPHA] Inference model TTFT distribution in seconds for each model and target model.
+# TYPE inference_model_request_ttft_seconds histogram
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="0.005"} 0
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="0.025"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="0.05"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="0.1"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="0.2"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="0.4"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="0.6"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="0.8"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="1"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="1.25"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="1.5"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="2"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="3"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="4"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="5"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="6"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="8"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="10"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="15"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="20"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="30"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="45"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="60"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="120"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="180"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="240"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="300"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="360"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="480"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="600"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="900"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="1200"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="1800"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="2700"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="3600"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t10",le="+Inf"} 2
+inference_model_request_ttft_seconds_sum{model_name="m10",target_model_name="t10"} 1.61
+inference_model_request_ttft_seconds_count{model_name="m10",target_model_name="t10"} 2
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.005"} 0
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.025"} 0
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.05"} 0
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.1"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.2"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.4"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.6"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.8"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="1"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="1.25"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="1.5"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="2"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="3"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="4"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="5"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="6"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="8"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="10"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="15"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="20"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="30"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="45"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="60"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="120"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="180"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="240"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="300"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="360"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="480"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="600"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="900"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="1200"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="1800"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="2700"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="3600"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="+Inf"} 1 +inference_model_request_ttft_seconds_sum{model_name="m10",target_model_name="t11"} 0.06 +inference_model_request_ttft_seconds_count{model_name="m10",target_model_name="t11"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.005"} 0 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.025"} 0 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.05"} 0 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.1"} 0 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.2"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.4"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.6"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.8"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="1"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="1.25"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="1.5"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="2"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="3"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="4"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="5"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="6"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="8"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="10"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="15"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="20"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="30"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="45"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="60"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="120"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="180"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="240"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="300"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="360"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="480"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="600"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="900"} 1 
+inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="1200"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="1800"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="2700"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="3600"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="+Inf"} 1
+inference_model_request_ttft_seconds_sum{model_name="m20",target_model_name="t20"} 0.12
+inference_model_request_ttft_seconds_count{model_name="m20",target_model_name="t20"} 1
diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go
index f6f7deebe..55eec0694 100644
--- a/pkg/epp/requestcontrol/director.go
+++ b/pkg/epp/requestcontrol/director.go
@@ -269,8 +269,9 @@ func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *hand
 	logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
 	logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk")
 	response := &Response{
-		RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
-		Headers:   reqCtx.Response.Headers,
+		RequestId:   reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
+		Headers:     reqCtx.Response.Headers,
+		EndOfStream: reqCtx.ResponseComplete,
 	}

 	d.runResponseStreamingPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go
new file mode 100644
index 000000000..fcb4b7223
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go
@@ -0,0 +1,191 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains the SLO-aware routing scorer and the
+// helpers that decouple the latency-predictor logic from request control.
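+//
+// All tuning knobs in this file are read once at init from environment
+// variables (for example SAMPLING_MEAN, MAX_SAMPLED_TOKENS, SLO_BUFFER_FACTOR,
+// HEADROOM_SELECTION_STRATEGY, STICKY_EPSILON, AFFINITY_GATE_TAU and
+// POD_SELECTION_MODE); unset or unparsable values fall back to the defaults
+// noted inline.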
+package slo_aware_router
+
+import (
+	"os"
+	"strconv"
+	"strings"
+)
+
+var DefaultSamplingMean = func() float64 {
+	if value, exists := os.LookupEnv("SAMPLING_MEAN"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue > 0 {
+			return parsedValue
+		}
+	}
+	return 100.0 // default value
+}()
+
+var MaxSampledTokens = func() int {
+	if value, exists := os.LookupEnv("MAX_SAMPLED_TOKENS"); exists {
+		if parsedValue, err := strconv.Atoi(value); err == nil && parsedValue > 0 {
+			return parsedValue
+		}
+	}
+	return 20 // default value
+}()
+
+var SLOBufferFactor = func() float64 {
+	if value, exists := os.LookupEnv("SLO_BUFFER_FACTOR"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil {
+			return parsedValue
+		}
+	}
+	return 1.0 // default value
+}()
+
+var NegHeadroomTTFTWeight = func() float64 {
+	if value, exists := os.LookupEnv("NEG_HEADROOM_TTFT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.8 // default: TTFT dominates when violating SLOs
+}()
+
+var NegHeadroomTPOTWeight = func() float64 {
+	if value, exists := os.LookupEnv("NEG_HEADROOM_TPOT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.2 // default: TPOT matters less for workloads with very short outputs
+}()
+
+var HeadroomTTFTWeight = func() float64 {
+	if value, exists := os.LookupEnv("HEADROOM_TTFT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.8 // default
+}()
+
+var HeadroomTPOTWeight = func() float64 {
+	if value, exists := os.LookupEnv("HEADROOM_TPOT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.2 // default
+}()
+
+var HeadroomSelectionStrategy = func() HeadroomStrategy {
+	if value, exists := os.LookupEnv("HEADROOM_SELECTION_STRATEGY"); exists {
+		switch strings.ToLower(value) {
+		case "least":
+			return HeadroomStrategyLeast
+		case "most":
+			return HeadroomStrategyMost
+		case "composite-least":
+			return HeadroomStrategyCompositeLeast
+		case "composite-most":
+			return HeadroomStrategyCompositeMost
+		case "composite-only":
+			return HeadroomStrategyCompositeOnly
+		}
+	}
+	return HeadroomStrategyLeast // default to least (better packing)
+}()
+
+// If using composite headroom, weights for each component. Not used by default.
+var CompositeKVWeight = func() float64 {
+	if v, ok := os.LookupEnv("COMPOSITE_KV_WEIGHT"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 {
+			return f
+		}
+	}
+	return 1
+}()
+
+var CompositeQueueWeight = func() float64 {
+	if v, ok := os.LookupEnv("COMPOSITE_QUEUE_WEIGHT"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 {
+			return f
+		}
+	}
+	return 1
+}()
+
+var CompositePrefixWeight = func() float64 {
+	if v, ok := os.LookupEnv("COMPOSITE_PREFIX_WEIGHT"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 {
+			return f
+		}
+	}
+	return 1
+}()
+
+// With probability ε, explore (ignore affinity gate); otherwise exploit.
+var EpsilonExploreSticky = func() float64 {
+	// Read the exploration probability from the environment.
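+	// Illustrative: STICKY_EPSILON=0.05 would send roughly 5% of requests past
+	// the affinity gate to keep exploring other pods.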
+	if v, ok := os.LookupEnv("STICKY_EPSILON"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.01 // default 1% exploration
+}()
+
+var EpsilonExploreNeg = func() float64 {
+	if v, ok := os.LookupEnv("NEG_HEADROOM_EPSILON"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.01 // default 1% exploration
+}()
+
+// τ for the per-path affinity gate (aka the "stickiness" threshold).
+var AffinityGateTau = func() float64 {
+	if v, ok := os.LookupEnv("AFFINITY_GATE_TAU"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.80
+}()
+
+// Global τ for the overall candidate set (previously "overall stickiness").
+var AffinityGateTauGlobal = func() float64 {
+	if v, ok := os.LookupEnv("AFFINITY_GATE_TAU_GLOBAL"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.99
+}()
+
+// Read once at init. Values: "linear" (default) or "max".
+var SelectionMode = func() PodSelectionMode {
+	if v, ok := os.LookupEnv("POD_SELECTION_MODE"); ok {
+		switch strings.ToLower(v) {
+		case "max":
+			return PodSelectionMax
+		case "linear":
+			fallthrough
+		default:
+			return PodSelectionLinear
+		}
+	}
+	return PodSelectionLinear
+}()
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go
new file mode 100644
index 000000000..8574ec41b
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go
@@ -0,0 +1,70 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Helpers for parsing SLO-related request headers.
+package slo_aware_router
+
+import (
+	"fmt"
+	"strconv"
+
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
+)
+
+// parseFloatHeader retrieves a header by name and parses it as a float64. It
+// returns the parsed value, whether the header was present, and an error if
+// the header was present but invalid.
+func parseFloatHeader(request schedulingtypes.LLMRequest, headerName string) (float64, bool, error) {
+	// 1. Get the header value from the map
+	headerValue, ok := request.Headers[headerName]
+	if !ok {
+		return 0, false, nil // header not found
+	}
+
+	// 2. Parse the header value to a float64
+	parsedFloat, err := strconv.ParseFloat(headerValue, 64)
+	if err != nil {
+		return 0, false, errutil.Error{
+			Code: errutil.BadRequest,
+			Msg:  fmt.Sprintf("%s must be a float", headerName),
+		}
+	}
+
+	// 3. Return the successfully parsed value
+	return parsedFloat, true, nil
+}
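+
+// A hedged usage sketch for these parsers; the header name and the sloCtx
+// field below are illustrative, not the router's actual wire contract:
+//
+//	ttft, ok, err := parseFloatHeader(request, "x-slo-ttft-ms")
+//	if err != nil {
+//		return err // BadRequest: header present but not a float
+//	}
+//	if ok {
+//		sloCtx.TTFTSLO = ttft
+//	}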
+
+// parseBoolHeader retrieves a header by name and parses it as a bool. It
+// returns the parsed value, or an error if the header is present but invalid.
+func parseBoolHeader(request schedulingtypes.LLMRequest, headerName string) (bool, error) {
+	// 1. Get the header value from the map
+	headerValue, ok := request.Headers[headerName]
+	if !ok {
+		return false, nil // header not found
+	}
+
+	// 2. Parse the header value to a bool
+	parsedBool, err := strconv.ParseBool(headerValue)
+	if err != nil {
+		return false, errutil.Error{
+			Code: errutil.BadRequest,
+			Msg:  fmt.Sprintf("%s must be a bool", headerName),
+		}
+	}
+
+	// 3. Return the successfully parsed value
+	return parsedBool, nil
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go
new file mode 100644
index 000000000..1d5568243
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go
@@ -0,0 +1,145 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package slo_aware_router
+
+import (
+	"context"
+	"math"
+	"math/rand"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds []PodPredictionResult, r *rand.Rand, strategy HeadroomStrategy) schedulingtypes.Pod {
+	total := 0
+	choices := s.buildCompositeChoices(
+		ctx, allPreds, CompositeKVWeight, CompositeQueueWeight, CompositePrefixWeight, &total,
+	)
+	if strategy == HeadroomStrategyCompositeLeast {
+		// Invert weights for the "least" strategy, and recompute the total so
+		// the weighted draw below matches the inverted weights.
+		total = 0
+		for i := range choices {
+			choices[i].Weight = minWeight + Wmax - choices[i].Weight
+			total += choices[i].Weight
+		}
+	}
+	return s.performWeightedRandomSelection(choices, total, allPreds, r)
+}
+
+func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []Choice, total int, candidates []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	if total == 0 {
+		return nil
+	}
+	logger := log.FromContext(context.Background())
+	// When POD_SELECTION_MODE is "max", pick the single highest-weight pod.
+	if SelectionMode == PodSelectionMax {
+		logger.V(logutil.DEBUG).Info("Pod selection mode: MAX - selecting pod with highest weight")
+		maxWeight := 0
+		var selectedPod schedulingtypes.Pod
+		for _, c := range weightedChoices {
+			if c.Weight > maxWeight {
+				maxWeight = c.Weight
+				selectedPod = c.PodName
+			}
+		}
+		if selectedPod != nil {
+			return selectedPod
+		}
+		// Fall back to the first pod if no selection was made
+		return candidates[0].Pod
+	}
+
+	// Weighted random selection (linear mode)
+	logger.V(logutil.DEBUG).Info("Pod selection mode: LINEAR - performing weighted random selection")
+	idx := r.Intn(total)
+	var selectedPod schedulingtypes.Pod
+
+	for _, c := range weightedChoices {
+		if idx < c.Weight {
+			selectedPod = c.PodName
+			break
+		}
+		idx -= c.Weight
+	}
+
+	// If no pod was selected (shouldn't happen), fall back to the first pod
+	if selectedPod == nil {
+		selectedPod = candidates[0].Pod
+	}
+
+	return selectedPod
+}
+
+func (s *SLOAwareRouter) buildCompositeChoices(
+	ctx context.Context,
+	candidates []PodPredictionResult,
+	wkv, wq, wpref float64,
+	total *int,
+) []Choice {
+	// Normalize weights
+	sumw := wkv + wq + wpref
+	if sumw <= 0 {
+		wkv, wq, wpref = 1, 0, 0
+	} else {
+		wkv /= sumw
+		wq /= sumw
+		wpref /= sumw
+	}
+
+	// Precompute queue stats
+	minQ, maxQ := math.MaxInt32, -1
+	queueCounts := make(map[string]int, len(candidates))
+	for _, p := range candidates {
+		q := p.Pod.GetMetrics().WaitingQueueSize
+		queueCounts[p.Pod.GetPod().String()] = q
+		if q < minQ {
+			minQ = q
+		}
+		if q > maxQ {
+			maxQ = q
+		}
+	}
+	den := float64(maxQ - minQ)
+
+	choices := make([]Choice, 0, len(candidates))
+	for _, p := range candidates {
+		q := queueCounts[p.Pod.GetPod().String()]
+		relQueue := 1.0
+		if den > 0 {
+			relQueue = float64(maxQ-q) / den
+		}
+
+		kvUsage := p.Pod.GetMetrics().KVCacheUsagePercent
+		kvFree := 1.0 - kvUsage
+		prefix := p.PrefixCacheScore
+
+		composite := wkv*kvFree + wq*relQueue + wpref*prefix
+		w := int(math.Round(float64(minWeight) + (float64(Wmax-minWeight) * composite)))
+		*total += w
+		choices = append(choices, Choice{PodName: p.Pod, Weight: w})
+
+		log.FromContext(ctx).V(logutil.TRACE).Info("Composite (neg/pos) score",
+			"pod", p.Pod.GetPod().String(),
+			"kvUsage", kvUsage, "kvFree", kvFree,
+			"queue", q, "relQueue", relQueue,
+			"prefix", prefix,
+			"wkv", wkv, "wq", wq, "wprefix", wpref,
+			"composite", composite, "weight", w)
+	}
+	return choices
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go
new file mode 100644
index 000000000..aa47f93c9
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go
@@ -0,0 +1,439 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Helpers that decouple latency-predictor bookkeeping from the router's
+// request-control hooks.
+package slo_aware_router
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"time"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
+
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+	requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
+	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
+)
+
+// RefreshLastSeenMetrics updates sloCtx.LastSeenMetrics from the latest scheduling result.
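+// The map is keyed by scheduling profile name. A hedged sketch of reading a
+// snapshot back (profile names such as "default" vary by deployment):
+//
+//	if m, ok := sloCtx.LastSeenMetrics["default"]; ok {
+//		_ = m.KVCacheUsagePercent // metrics snapshot taken at scheduling time
+//	}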
+func RefreshLastSeenMetrics(ctx context.Context, sloCtx *SLORequestContext) {
+	if sr := sloCtx.SchedulingResult; sr != nil {
+		if pr := sr.ProfileResults[sr.PrimaryProfileName]; pr != nil && pr.TargetPods != nil {
+			for profileName, profileResult := range sr.ProfileResults {
+				if profileResult != nil && len(profileResult.TargetPods) > 0 {
+					sloCtx.LastSeenMetrics[profileName] = profileResult.TargetPods[0].GetMetrics().Clone()
+				}
+			}
+		}
+	} else {
+		log.FromContext(ctx).V(logutil.DEBUG).Info("No scheduling result found, skipping metrics refresh")
+	}
+}
+
+// GetLatestMetricsForProfile retrieves the metrics snapshot for the primary
+// profile from sloCtx.LastSeenMetrics.
+func GetLatestMetricsForProfile(ctx context.Context, sloCtx *SLORequestContext) (*backendmetrics.MetricsState, error) {
+	if len(sloCtx.LastSeenMetrics) == 0 {
+		return nil, fmt.Errorf("no last seen metrics available for prediction")
+	}
+
+	primaryProfileName := sloCtx.SchedulingResult.PrimaryProfileName
+	if metrics, exists := sloCtx.LastSeenMetrics[primaryProfileName]; exists {
+		return metrics, nil
+	}
+
+	return nil, fmt.Errorf("no metrics found for primary profile %s", primaryProfileName)
+}
+
+// ProcessHeaderForLatencyPrediction refreshes metrics, applies the TTFT
+// prediction, and updates sloCtx.PredictedTTFT and the token timestamp.
+func ProcessHeaderForLatencyPrediction(
+	ctx context.Context,
+	predictor latencypredictor.PredictorInterface,
+	sloCtx *SLORequestContext,
+) error {
+	logger := log.FromContext(ctx)
+
+	// Build the prediction request
+	m, err := GetLatestMetricsForProfile(ctx, sloCtx)
+	if err != nil {
+		logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err)
+		return err
+	}
+
+	targetPod := sloCtx.TargetPod
+	prefixCacheScore := sloCtx.PrefixCacheScoresForPods[targetPod.String()]
+
+	in := latencypredictor.PredictionRequest{
+		KVCachePercentage:  m.KVCacheUsagePercent,
+		InputTokenLength:   len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
+		NumRequestWaiting:  m.WaitingQueueSize,
+		NumRequestRunning:  m.RunningQueueSize,
+		NumTokensGenerated: 0,
+		PrefixCacheScore:   prefixCacheScore,
+	}
+
+	// Predict TTFT
+	start := time.Now()
+	p, err := predictor.Predict(ctx, in)
+	dur := time.Since(start)
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(err, "header TTFT predict failed", "duration_ms", dur.Milliseconds())
+		sloCtx.PredictedTTFT = 0
+	} else if p == nil {
+		logger.V(logutil.DEBUG).Info("header TTFT predict returned nil", "duration_ms", dur.Milliseconds())
+		sloCtx.PredictedTTFT = 0
+	} else {
+		logger.V(logutil.DEBUG).Info("header TTFT predict succeeded", "value_ms", p.TTFT, "duration_ms", dur.Milliseconds())
+		metrics.RecordRequestTTFTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds())
+		sloCtx.PredictedTTFT = p.TTFT
+	}
+
+	// Advance the timestamp used as the first-token reference
+	sloCtx.LastTokenTimestamp = time.Now()
+	RefreshLastSeenMetrics(ctx, sloCtx)
+	return err
+}
+
+// ProcessFirstTokenForLatencyPrediction records the actual TTFT, trains the
+// predictor, predicts the first TPOT, updates sloCtx, and advances the token
+// timestamp.
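+//
+// A hedged sketch of the expected dispatch from the streaming hook: the first
+// chunk (no TTFT recorded yet) goes here, later chunks go to
+// ProcessTokenForLatencyPrediction, mirroring the ResponseStreaming hook in
+// requestcontrol_hooks.go later in this change:
+//
+//	if sloCtx.TTFT == 0 {
+//		ProcessFirstTokenForLatencyPrediction(ctx, predictor, sloCtx, now)
+//	} else {
+//		ProcessTokenForLatencyPrediction(ctx, predictor, sloCtx, now)
+//	}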
+func ProcessFirstTokenForLatencyPrediction( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + sloCtx *SLORequestContext, + now time.Time, +) { + logger := log.FromContext(ctx) + + // Initialize sampler + if sloCtx.TokenSampler == nil { + requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey] + sloCtx.TokenSampler = NewTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens) + logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken()) + } + + // Actual TTFT + sloCtx.TTFT = float64(now.Sub(sloCtx.RequestReceivedTimestamp).Milliseconds()) + sloCtx.GeneratedTokenCount = 1 + m, err := GetLatestMetricsForProfile(ctx, sloCtx) + if err != nil { + logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err) + return + } + targetPod := sloCtx.TargetPod + prefix_cache_score := sloCtx.PrefixCacheScoresForPods[targetPod.String()] + + // Train TTFT + entry := latencypredictor.TrainingEntry{ + KVCachePercentage: m.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)), + ActualTTFT: sloCtx.TTFT, + ActualTPOT: 0, + Timestamp: now, + NumRequestWaiting: m.WaitingQueueSize, + NumRequestRunning: m.RunningQueueSize, + NumTokensGenerated: 0, + PrefixCacheScore: prefix_cache_score, + } + if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil { + logger.V(logutil.DEBUG).Error(err, "record TTFT training failed") + } + m, err = GetLatestMetricsForProfile(ctx, sloCtx) + if err != nil { + logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics", + "error", err) + return + } + + // Predict first TPOT + in := latencypredictor.PredictionRequest{ + KVCachePercentage: m.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)), + NumRequestWaiting: m.WaitingQueueSize, + NumRequestRunning: m.RunningQueueSize, + NumTokensGenerated: sloCtx.GeneratedTokenCount, + PrefixCacheScore: 0, + } + start := time.Now() + p, err := predictor.Predict(ctx, in) + dur := time.Since(start) + if err != nil || p == nil { + logger.V(logutil.DEBUG).Error(err, "first TPOT predict failed", "duration_ms", dur.Milliseconds()) + sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, 0) + sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, 0, len(sloCtx.PredictedTPOTObservations)) + } else { + logger.V(logutil.DEBUG).Info("first TPOT succeeded", "value_ms", p.TPOT, "duration_ms", dur.Milliseconds()) + sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, p.TPOT) + sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, p.TPOT, len(sloCtx.PredictedTPOTObservations)) + } + metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds()) + + // Advance timestamp + sloCtx.LastTokenTimestamp = now + // Refresh metrics + RefreshLastSeenMetrics(ctx, sloCtx) +} + +// ProcessToken records actual inter-token latency, trains, predicts sampled TPOT, updates sloCtx, and advances timestamp. 
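+//
+// Every generated token contributes a training sample, but predictions are
+// only issued at sampled positions (roughly one every DefaultSamplingMean
+// tokens, capped at MaxSampledTokens per request); this reading is inferred
+// from how the TokenSampler is constructed below, not a hard contract.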
+func ProcessTokenForLatencyPrediction(
+	ctx context.Context,
+	predictor latencypredictor.PredictorInterface,
+	sloCtx *SLORequestContext,
+	now time.Time,
+) {
+	logger := log.FromContext(ctx)
+
+	// Initialize the sampler if it has not been initialized yet
+	if sloCtx.TokenSampler == nil {
+		requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey]
+		sloCtx.TokenSampler = NewTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens)
+		logger.V(logutil.DEBUG).Info("Initialized token sampler for subsequent tokens", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken())
+	}
+
+	// Inter-token latency
+	latencyMs := float64(now.Sub(sloCtx.LastTokenTimestamp).Milliseconds())
+	sloCtx.GeneratedTokenCount++
+
+	// Record the observed inter-token latency at sampled positions. Token 2 is
+	// handled explicitly because the next sample token is always one past the
+	// current token.
+	if sloCtx.GeneratedTokenCount == 2 || sloCtx.TokenSampler.ShouldPredict(sloCtx.GeneratedTokenCount) {
+		sloCtx.TPOTObservations = append(sloCtx.TPOTObservations, latencyMs)
+		sloCtx.AvgTPOT = calculateRunningAverage(sloCtx.AvgTPOT, latencyMs, len(sloCtx.TPOTObservations))
+	}
+
+	m, err := GetLatestMetricsForProfile(ctx, sloCtx)
+	if err != nil {
+		logger.V(logutil.DEBUG).Info("Skipping TPOT training and prediction due to missing metrics",
+			"error", err)
+		return
+	}
+	// Train on the actual TPOT observation
+	entry := latencypredictor.TrainingEntry{
+		KVCachePercentage:  m.KVCacheUsagePercent,
+		InputTokenLength:   len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
+		ActualTTFT:         0,
+		ActualTPOT:         latencyMs,
+		Timestamp:          now,
+		NumRequestWaiting:  m.WaitingQueueSize,
+		NumRequestRunning:  m.RunningQueueSize,
+		NumTokensGenerated: sloCtx.GeneratedTokenCount - 1,
+		PrefixCacheScore:   0, // TPOT does not use the prefix cache score
+	}
+	if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
+		logger.V(logutil.DEBUG).Error(err, "record TPOT training failed")
+	}
+
+	// Predict at sampled positions only
+	if sloCtx.TokenSampler.ShouldPredict(sloCtx.GeneratedTokenCount) {
+		in := latencypredictor.PredictionRequest{
+			KVCachePercentage:  m.KVCacheUsagePercent,
+			InputTokenLength:   len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
+			NumRequestWaiting:  m.WaitingQueueSize,
+			NumRequestRunning:  m.RunningQueueSize,
+			NumTokensGenerated: sloCtx.GeneratedTokenCount,
+			PrefixCacheScore:   0, // TPOT does not use the prefix cache score
+		}
+		start := time.Now()
+		p, err := predictor.Predict(ctx, in)
+		dur := time.Since(start)
+		if err != nil || p == nil {
+			logger.V(logutil.DEBUG).Error(err, "TPOT predict failed", "duration_ms", dur.Milliseconds())
+			sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, 0)
+			sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, 0, len(sloCtx.PredictedTPOTObservations))
+		} else {
+			logger.V(logutil.DEBUG).Info("TPOT predict succeeded", "value_ms", p.TPOT, "duration_ms", dur.Milliseconds())
+			sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, p.TPOT)
+			sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, p.TPOT, len(sloCtx.PredictedTPOTObservations))
+		}
+		metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds())
+
+		sloCtx.TokenSampler.RecordPrediction(sloCtx.GeneratedTokenCount)
+	}
+
+	// Advance the timestamp and refresh metrics
+	sloCtx.LastTokenTimestamp = now
+	RefreshLastSeenMetrics(ctx, sloCtx)
+}
+
+// PredictWithMetrics predicts
TTFT or TPOT based on provided metrics state and token count. +func PredictWithMetrics( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + metricsState *backendmetrics.MetricsState, + prompt string, + generatedTokenCount int, + prefixcachescore float64, +) (*latencypredictor.PredictionResponse, error) { + logger := log.FromContext(ctx) + + if metricsState == nil { + return nil, fmt.Errorf("metrics state cannot be nil") + } + + // Build prediction request + in := latencypredictor.PredictionRequest{ + KVCachePercentage: metricsState.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(prompt)), + NumRequestWaiting: metricsState.WaitingQueueSize, + NumRequestRunning: metricsState.RunningQueueSize, + NumTokensGenerated: generatedTokenCount, + PrefixCacheScore: prefixcachescore, + } + + // Perform prediction + start := time.Now() + result, err := predictor.Predict(ctx, in) + duration := time.Since(start) + + if err != nil { + logger.V(logutil.DEBUG).Error(err, "prediction failed", + "duration_ms", duration.Milliseconds(), + "input_tokens", in.InputTokenLength, + "generated_tokens", generatedTokenCount, + "kv_cache_percent", in.KVCachePercentage, + "waiting_queue", in.NumRequestWaiting, + "running_queue", in.NumRequestRunning, + "prefix_cache_score", in.PrefixCacheScore) + return nil, err + } + + if result == nil { + logger.V(logutil.DEBUG).Info("prediction returned nil", + "duration_ms", duration.Milliseconds()) + return nil, fmt.Errorf("prediction returned nil result") + } + + logger.V(logutil.DEBUG).Info("prediction succeeded", + "tpot_ms", result.TPOT, + "ttft_ms", result.TTFT, + "duration_ms", duration.Milliseconds(), + "input_tokens", in.InputTokenLength, + "generated_tokens", generatedTokenCount, + "kv_cache_percent", in.KVCachePercentage, + "waiting_queue", in.NumRequestWaiting, + "running_queue", in.NumRequestRunning, + "prefix_cache_score", in.PrefixCacheScore) + + return result, nil +} + +// BulkPredictWithMetrics performs bulk predictions for multiple pods using their metrics states. +// Returns predictions in the same order as the input slices. 
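+//
+// A hedged call sketch with two candidate pods (podA and podB are
+// illustrative; all four slices must be the same length, one entry per pod):
+//
+//	states := []*backendmetrics.MetricsState{podA.GetMetrics(), podB.GetMetrics()}
+//	prompts := []string{prompt, prompt}
+//	tokenCounts := []int{1, 1}
+//	prefixScores := []float64{0.8, 0.1}
+//	preds, err := BulkPredictWithMetrics(ctx, predictor, states, prompts, tokenCounts, prefixScores)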
+func BulkPredictWithMetrics( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + metricsStates []*backendmetrics.MetricsState, + prompts []string, + generatedTokenCounts []int, + prefixCacheScores []float64, +) ([]*latencypredictor.PredictionResponse, error) { + logger := log.FromContext(ctx) + + // Validate input lengths + if len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) { + return nil, fmt.Errorf("input slice lengths must match: metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d", + len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores)) + } + + if len(metricsStates) == 0 { + return []*latencypredictor.PredictionResponse{}, nil + } + + // Validate that no metrics state is nil + for i, metricsState := range metricsStates { + if metricsState == nil { + return nil, fmt.Errorf("metrics state at index %d cannot be nil", i) + } + } + + // Build bulk prediction requests + bulkRequests := make([]latencypredictor.PredictionRequest, len(metricsStates)) + for i := range metricsStates { + bulkRequests[i] = latencypredictor.PredictionRequest{ + KVCachePercentage: metricsStates[i].KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(prompts[i])), + NumRequestWaiting: metricsStates[i].WaitingQueueSize, + NumRequestRunning: metricsStates[i].RunningQueueSize, + NumTokensGenerated: generatedTokenCounts[i], + PrefixCacheScore: prefixCacheScores[i], + } + } + + // Perform bulk prediction + start := time.Now() + bulkResponse, err := predictor.PredictBulkStrict(ctx, bulkRequests) + duration := time.Since(start) + + if err != nil { + logger.V(logutil.DEBUG).Error(err, "bulk prediction failed", + "duration_ms", duration.Milliseconds(), + "request_count", len(bulkRequests)) + return nil, err + } + + if bulkResponse == nil { + logger.V(logutil.DEBUG).Info("bulk prediction returned nil", + "duration_ms", duration.Milliseconds()) + return nil, fmt.Errorf("bulk prediction returned nil result") + } + + // Convert to pointer slice for consistency with single prediction + results := make([]*latencypredictor.PredictionResponse, len(bulkResponse.Predictions)) + for i := range bulkResponse.Predictions { + results[i] = &bulkResponse.Predictions[i] + } + + logger.V(logutil.DEBUG).Info("bulk prediction succeeded", + "duration_ms", duration.Milliseconds(), + "request_count", len(bulkRequests), + "successful_predictions", bulkResponse.SuccessfulPredictions, + "failed_predictions", bulkResponse.FailedPredictions, + "processing_time_ms", bulkResponse.ProcessingTimeMs) + + // Log detailed results if at trace level + if logger.V(logutil.TRACE).Enabled() { + for i, result := range results { + logger.V(logutil.TRACE).Info("bulk prediction result", + "index", i, + "ttft_ms", result.TTFT, + "tpot_ms", result.TPOT, + "input_tokens", bulkRequests[i].InputTokenLength, + "generated_tokens", bulkRequests[i].NumTokensGenerated, + "kv_cache_percent", bulkRequests[i].KVCachePercentage, + "waiting_queue", bulkRequests[i].NumRequestWaiting, + "running_queue", bulkRequests[i].NumRequestRunning, + "prefix_cache_score", bulkRequests[i].PrefixCacheScore) + } + } + + return results, nil +} + +// calculateRunningAverage calculates the running average efficiently +func calculateRunningAverage(currentAvg float64, newValue float64, count int) float64 { + if count == 0 { + return 0 + } + if count == 1 { + return newValue + } + return currentAvg + (newValue-currentAvg)/float64(count) +} diff --git 
a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go
new file mode 100644
index 000000000..0c2cfa0a9
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go
@@ -0,0 +1,138 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Per-pod latency prediction generation and validation for the SLO-aware router.
+package slo_aware_router
+
+import (
+	"context"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
+)
+
+type PodPredictionResult struct {
+	Pod              schedulingtypes.Pod
+	TTFT             float64
+	TPOT             float64
+	TTFTValid        bool
+	TPOTValid        bool
+	IsValid          bool
+	Error            error
+	Headroom         float64 // TPOT headroom for the pod, if applicable
+	TTFTHeadroom     float64 // TTFT headroom for the pod
+	PrefixCacheScore float64 // Prefix cache score for the pod
+}
+
+// generatePredictions creates prediction results for all candidate pods.
+func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *SLORequestContext, candidatePods []schedulingtypes.Pod) ([]PodPredictionResult, error) {
+	logger := log.FromContext(ctx)
+	predictions := make([]PodPredictionResult, 0, len(candidatePods))
+
+	for _, pod := range candidatePods {
+		predResult := PodPredictionResult{Pod: pod}
+
+		logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String())
+
+		// Get the prefix cache score for the pod
+		prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
+
+		sloCtx.PrefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore
+
+		logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore)
+
+		// Generate the prediction; a failure for any pod aborts prediction
+		// generation for the whole request.
+		prediction, err := PredictWithMetrics(ctx, s.latencypredictor, pod.GetMetrics(), request.Body.Completions.Prompt, 1, prefixCacheScore)
+		if err != nil {
+			logger.V(logutil.DEBUG).Error(err, "Prediction failed for pod, aborting prediction generation", "pod", pod.GetPod().String())
+			predResult.Error = err
+			return nil, err
+		}
+		predResult.PrefixCacheScore = prefixCacheScore
+		predResult.TTFT = prediction.TTFT
+		predResult.TPOT = prediction.TPOT
+		podMinTPOTSLO := s.getPodMinTPOTSLO(pod)
+		predResult.TTFTValid, predResult.TPOTValid, predResult.IsValid, predResult.Headroom, predResult.TTFTHeadroom = s.validatePrediction(prediction, sloCtx, podMinTPOTSLO)
+
+		logger.V(logutil.DEBUG).Info("Prediction for scheduling",
+			"pod", pod.GetPod().String(),
+			"prefixCacheScore", prefixCacheScore,
+			"TTFT", prediction.TTFT,
+			"TPOT", prediction.TPOT,
+			"buffer", SLOBufferFactor,
+			"podMinTPOTSLO", podMinTPOTSLO,
+			"ttftSLO", sloCtx.TTFTSLO,
+			"requestTPOTSLO", sloCtx.AvgTPOTSLO,
+			"tpotHeadroom", predResult.Headroom,
+			"ttftHeadroom", predResult.TTFTHeadroom,
+			"tpotValid", predResult.TPOTValid,
+			"ttftValid", predResult.TTFTValid,
+			"headroomStrategy", s.headroomStrategy)
+
+		predictions = append(predictions, predResult)
+	}
+
+	return predictions, nil
+}
+
+// updateRequestContextWithPredictions updates the request context with prediction data.
+func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *SLORequestContext, predictions []PodPredictionResult) {
+	for _, pred := range predictions {
+		if pred.Error == nil {
+			podKey := pred.Pod.GetPod().String()
+			if sloCtx.PredictedTTFTForScheduling == nil {
+				sloCtx.PredictedTTFTForScheduling = make(map[string]float64)
+			}
+			if sloCtx.PredictedTPOTForScheduling == nil {
+				sloCtx.PredictedTPOTForScheduling = make(map[string]float64)
+			}
+			sloCtx.PredictedTTFTForScheduling[podKey] = pred.TTFT
+			sloCtx.PredictedTPOTForScheduling[podKey] = pred.TPOT
+		}
+	}
+}
+
+func (s *SLOAwareRouter) validatePrediction(
+	pred *latencypredictor.PredictionResponse,
+	sloCtx *SLORequestContext,
+	podMinTPOTSLO float64,
+) (ttftOk, tpotOk, isValid bool, headroom float64, ttftHeadroom float64) {
+	bufferedTPOT := sloCtx.AvgTPOTSLO * SLOBufferFactor
+	// A podMinTPOTSLO of 0 means either no running requests, or no TPOT SLOs
+	// specified on the running requests.
+	if podMinTPOTSLO > 0 {
+		if podMinTPOTSLO < sloCtx.AvgTPOTSLO {
+			log.FromContext(context.Background()).V(logutil.DEBUG).Info("Pod min TPOT SLO is less than the request SLO, adjusting", "podMinTPOTSLO", podMinTPOTSLO, "requestAvgTPOTSLO", sloCtx.AvgTPOTSLO)
+		}
+		bufferedTPOT = min(bufferedTPOT, podMinTPOTSLO*SLOBufferFactor)
+	}
+
+	tpotOk = pred.TPOT < bufferedTPOT
+	ttftOk = pred.TTFT < sloCtx.TTFTSLO
+
+	isValid = ttftOk && tpotOk
+	headroom = bufferedTPOT - pred.TPOT
+	ttftHeadroom = sloCtx.TTFTSLO - pred.TTFT
+	return
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go
new file mode 100644
index 000000000..f865bbeb3
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go
@@ -0,0 +1,262 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package slo_aware_router
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/go-logr/logr"
+	"sigs.k8s.io/controller-runtime/pkg/log"
+
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
+	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+	requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
+)
+
+var _ requestcontrol.PreRequest = &SLOAwareRouter{}
+var _ requestcontrol.ResponseReceived = &SLOAwareRouter{}
+var _ requestcontrol.ResponseStreaming = &SLOAwareRouter{}
+var _ requestcontrol.ResponseComplete = &SLOAwareRouter{}
+
+type SLORequestContext struct {
+	SchedulingRequest         schedulingtypes.LLMRequest
+	TargetPod                 *backend.Pod
+	SchedulingResult          *schedulingtypes.SchedulingResult
+	LastSeenMetrics           map[string]*backendmetrics.MetricsState
+	LastTokenTimestamp        time.Time
+	RequestReceivedTimestamp  time.Time
+	GeneratedTokenCount       int
+	IncomingModelName         string
+	TTFT                      float64
+	PredictedTTFT             float64
+	AvgTPOT                   float64
+	AvgPredictedTPOT          float64
+	TokenSampler              *TokenSampler
+	TPOTObservations          []float64
+	PredictedTPOTObservations []float64
+
+	PrefixCacheScoresForPods map[string]float64
+
+	// TTFTSLO is the target time-to-first-token SLO for the request.
+	TTFTSLO float64
+	// AvgTPOTSLO is the target average time-per-output-token SLO for the request.
+	AvgTPOTSLO float64
+
+	// PredictorBasedScheduling indicates whether to use predictor-based scheduling.
+	PredictorBasedScheduling bool
+	// PredictedTTFTForScheduling maps pod names to predicted TTFT values for scheduling.
+	PredictedTTFTForScheduling map[string]float64
+	// PredictedTPOTForScheduling maps pod names to predicted TPOT values for scheduling.
+ PredictedTPOTForScheduling map[string]float64 + + // boolean set if request has valid pod based on predictions + HasValidPod bool +} + +func NewSLORequestContext(request *schedulingtypes.LLMRequest) *SLORequestContext { + return &SLORequestContext{ + SchedulingRequest: *request, + LastSeenMetrics: make(map[string]*backendmetrics.MetricsState), + PrefixCacheScoresForPods: make(map[string]float64), + PredictedTTFTForScheduling: make(map[string]float64), + PredictedTPOTForScheduling: make(map[string]float64), + } +} + +func (s *SLOAwareRouter) getSLOContextForRequest(request *schedulingtypes.LLMRequest) (*SLORequestContext, error) { + id := request.Headers[requtil.RequestIdHeaderKey] + if ctx, exists := s.sloContextStore.Load(id); exists { + return ctx.(*SLORequestContext), nil + } + return nil, fmt.Errorf("SLO context not found for request ID: %s", id) +} + +func (s *SLOAwareRouter) setSLOContextForRequest(request *schedulingtypes.LLMRequest, ctx *SLORequestContext) { + id := request.Headers[requtil.RequestIdHeaderKey] + s.sloContextStore.Store(id, ctx) +} + +func (s *SLOAwareRouter) deleteSLOContextForRequest(request *schedulingtypes.LLMRequest) { + id := request.Headers[requtil.RequestIdHeaderKey] + s.sloContextStore.Delete(id) +} + +// --- RequestControl Hooks --- + +func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult) { + logger := log.FromContext(ctx) + + if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping PreRequest because no scheduling result was provided.") + return + } + + targetPod := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName].TargetPods[0].GetPod() + if !t.CheckPredictor(logger, targetPod) { + return + } + + podName := types.NamespacedName{ + Name: targetPod.NamespacedName.Name, + Namespace: targetPod.NamespacedName.Namespace, + } + + logger.V(logutil.TRACE).Info("request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey], "podName", podName) + if request.Headers[requtil.RequestIdHeaderKey] == "" { + logger.V(logutil.DEBUG).Error(fmt.Errorf("missing request ID"), "SLOAwareRouter.PreRequest: Request is missing request ID header") + } + + id := request.Headers[requtil.RequestIdHeaderKey] + podRequestList, ok := t.runningRequestLists[podName] + if !ok { + podRequestList = NewRequestPriorityQueue() + t.runningRequestLists[podName] = podRequestList + } + + sloCtx, err := t.getSLOContextForRequest(request) + if err != nil { + id := request.Headers[requtil.RequestIdHeaderKey] + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.PreRequest: Failed to get SLO context for request", "requestID", id) + return + } + + added := podRequestList.Add(id, sloCtx.AvgTPOTSLO) + if !added { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Item already exists in queue", "podName", podName, "requestID", id) + } + + // Set up SLO request context + sloCtx.TargetPod = targetPod + sloCtx.SchedulingResult = schedulingResult + sloCtx.RequestReceivedTimestamp = time.Now() + RefreshLastSeenMetrics(ctx, sloCtx) + t.setSLOContextForRequest(request, sloCtx) +} + +func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { + logger := log.FromContext(ctx) + if !t.CheckPredictor(logger, targetPod) { + return + } + + id := request.Headers[requtil.RequestIdHeaderKey] + + sloCtx, err := 
t.getSLOContextForRequest(request)
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Failed to get SLO context for request", "requestID", id)
+		return
+	}
+
+	if err := ProcessHeaderForLatencyPrediction(ctx, t.latencypredictor, sloCtx); err != nil {
+		logger.V(logutil.DEBUG).Error(err, "ProcessHeader in latencypredictor failed")
+	}
+}
+
+func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) {
+	logger := log.FromContext(ctx)
+	if !t.CheckPredictor(logger, pod) || response.EndOfStream {
+		return
+	}
+
+	now := time.Now()
+	sloCtx, err := t.getSLOContextForRequest(request)
+	if err != nil {
+		id := request.Headers[requtil.RequestIdHeaderKey]
+		logger.V(logutil.TRACE).Error(err, "SLOAwareRouter.ResponseStreaming: Failed to get SLO context for request", "requestID", id)
+		return
+	}
+
+	if sloCtx.TTFT == 0 {
+		ProcessFirstTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now)
+	} else {
+		ProcessTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now)
+	}
+}
+
+func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) {
+	logger := log.FromContext(ctx)
+	targetPod := pod
+	if !t.CheckPredictor(logger, targetPod) {
+		return
+	}
+
+	sloCtx, err := t.getSLOContextForRequest(request)
+	if err != nil {
+		id := request.Headers[requtil.RequestIdHeaderKey]
+		logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.ResponseComplete: Failed to get SLO context for request", "requestID", id)
+		return
+	}
+
+	if sloCtx.TTFT > 0 {
+		logger.V(logutil.TRACE).Info("TTFT averages calculated", "avgActualTTFT", sloCtx.TTFT, "avgPredictedTTFT", sloCtx.PredictedTTFT)
+		metrics.RecordRequestTTFT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.TTFT/1000)
+		metrics.RecordRequestPredictedTTFT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.PredictedTTFT/1000)
+		if sloCtx.TTFTSLO > 0 {
+			metrics.RecordRequestTTFTWithSLO(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.TTFT, sloCtx.TTFTSLO)
+		}
+	}
+
+	if sloCtx.AvgTPOT > 0 {
+		logger.V(logutil.TRACE).Info("TPOT averages calculated", "avgActualTPOT", sloCtx.AvgTPOT, "avgPredictedTPOT", sloCtx.AvgPredictedTPOT)
+		metrics.RecordRequestTPOT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgTPOT/1000)
+		metrics.RecordRequestPredictedTPOT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgPredictedTPOT/1000)
+		if sloCtx.AvgTPOTSLO > 0 {
+			metrics.RecordRequestTPOTWithSLO(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgTPOT, sloCtx.AvgTPOTSLO)
+		}
+	}
+
+	logger.V(logutil.TRACE).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", sloCtx.PredictorBasedScheduling)
+
+	podName := types.NamespacedName{
+		Name:      targetPod.NamespacedName.Name,
+		Namespace: targetPod.NamespacedName.Namespace,
+	}
+
+	id := request.Headers[requtil.RequestIdHeaderKey]
+	podRequestList, ok := t.runningRequestLists[podName]
+	if !ok {
+		// No queue for this pod; there is nothing to remove, and calling
+		// Remove on a nil queue would panic.
+		err := fmt.Errorf("no running request list found for pod %s", podName.String())
+		logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Failed to remove request from queue", "requestID", id)
+	} else if _, removed := podRequestList.Remove(id); !removed {
+		logger.V(logutil.TRACE).Info("SLOAwareRouter: Item not found in queue", "podName", podName, "requestID", id)
+	}
+	t.deleteSLOContextForRequest(request)
+}
+
+func (t *SLOAwareRouter) CheckPredictor(logger logr.Logger, targetPod
*backend.Pod) bool { + if targetPod == nil { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because no target pod was provided.") + return false + } + if t.latencypredictor == nil { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because predictor missing") + return false + } + return true +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go new file mode 100644 index 000000000..96999af2f --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go @@ -0,0 +1,945 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/go-logr/logr" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/types" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" +) + +// Helper functions + +func createTestSchedulingResult(pod *backend.Pod, kvUsage float64, runningQueue int, waitingQueue int) *schedulingtypes.SchedulingResult { + + mockPod := createTestPod(pod.NamespacedName.Name, kvUsage, runningQueue, waitingQueue) + + return &schedulingtypes.SchedulingResult{ + PrimaryProfileName: "default", + ProfileResults: map[string]*schedulingtypes.ProfileRunResult{ + "default": { + TargetPods: []schedulingtypes.Pod{mockPod}, + }, + }, + } +} + +func createTestRouter() *SLOAwareRouter { + return &SLOAwareRouter{ + sloContextStore: sync.Map{}, + runningRequestLists: make(map[types.NamespacedName]*RequestPriorityQueue), + latencypredictor: nil, + } +} + +// Test cases + +func TestNewSLORequestContext(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + + ctx := NewSLORequestContext(request) + + assert.NotNil(t, ctx) + assert.Equal(t, *request, ctx.SchedulingRequest) + assert.NotNil(t, ctx.LastSeenMetrics) + assert.NotNil(t, ctx.PrefixCacheScoresForPods) + assert.NotNil(t, ctx.PredictedTTFTForScheduling) + assert.NotNil(t, ctx.PredictedTPOTForScheduling) + assert.Empty(t, ctx.LastSeenMetrics) + assert.Empty(t, ctx.PrefixCacheScoresForPods) +} + +func TestSLOAwareRouter_SetAndGetSLOContext(t *testing.T) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + sloCtx := NewSLORequestContext(request) + + // Set context + router.setSLOContextForRequest(request, sloCtx) + + // Get context + retrievedCtx, err := router.getSLOContextForRequest(request) + + require.NoError(t, err) + assert.Equal(t, sloCtx, retrievedCtx) +} + +func 
TestSLOAwareRouter_GetSLOContext_NotFound(t *testing.T) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + + // Try to get context that doesn't exist + ctx, err := router.getSLOContextForRequest(request) + + assert.Error(t, err) + assert.Nil(t, ctx) + assert.Contains(t, err.Error(), "SLO context not found") +} + +func TestSLOAwareRouter_DeleteSLOContext(t *testing.T) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + sloCtx := NewSLORequestContext(request) + + // Set and then delete context + router.setSLOContextForRequest(request, sloCtx) + router.deleteSLOContextForRequest(request) + + // Verify it's deleted + ctx, err := router.getSLOContextForRequest(request) + assert.Error(t, err) + assert.Nil(t, ctx) +} + +func TestSLOAwareRouter_PreRequest_NoSchedulingResult(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + + // Call PreRequest with nil scheduling result + router.PreRequest(ctx, request, nil) + + // Should not create SLO context + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_PreRequest_EmptySchedulingResult(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + + schedulingResult := &schedulingtypes.SchedulingResult{ + ProfileResults: map[string]*schedulingtypes.ProfileRunResult{}, + } + + // Call PreRequest with empty scheduling result + router.PreRequest(ctx, request, schedulingResult) + + // Should not create SLO context + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_PreRequest_Success(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create and set initial SLO context + sloCtx := NewSLORequestContext(request) + sloCtx.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request, sloCtx) + + // Initialize the request priority queue + router.runningRequestLists[pod.GetPod().NamespacedName] = NewRequestPriorityQueue() + + beforeTime := time.Now() + router.PreRequest(ctx, request, schedulingResult) + afterTime := time.Now() + + // Verify SLO context was updated + retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.Equal(t, pod.GetPod(), retrievedCtx.TargetPod) + assert.Equal(t, schedulingResult, retrievedCtx.SchedulingResult) + assert.True(t, retrievedCtx.RequestReceivedTimestamp.After(beforeTime) || + retrievedCtx.RequestReceivedTimestamp.Equal(beforeTime)) + assert.True(t, retrievedCtx.RequestReceivedTimestamp.Before(afterTime) || + retrievedCtx.RequestReceivedTimestamp.Equal(afterTime)) +} + +func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create and set initial SLO context + sloCtx := NewSLORequestContext(request) + sloCtx.AvgTPOTSLO = 50 + 
router.setSLOContextForRequest(request, sloCtx) + + // PreRequest should create the queue + router.PreRequest(ctx, request, schedulingResult) + + // Verify queue was created and request was added + queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] + assert.True(t, exists, "Queue should be created for pod") + assert.NotNil(t, queue) +} + +func TestSLOAwareRouter_PreRequest_QueueAlreadyExists(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request1 := createTestLLMRequest("test-id-1", 100, 50, true) + request2 := createTestLLMRequest("test-id-2", 100, 50, true) + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create and set initial SLO contexts + sloCtx1 := NewSLORequestContext(request1) + sloCtx1.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request1, sloCtx1) + + sloCtx2 := NewSLORequestContext(request2) + sloCtx2.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request2, sloCtx2) + + // Add first request + router.PreRequest(ctx, request1, schedulingResult) + + // Add second request to same pod + router.PreRequest(ctx, request2, schedulingResult) + + // Verify both are in the same queue + queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] + assert.True(t, exists) + assert.NotNil(t, queue) +} + +func TestSLOAwareRouter_ResponseReceived_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic and should return early + router.ResponseReceived(ctx, request, response, pod.GetPod()) + + // Context should still exist + _, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) +} + +func TestSLOAwareRouter_ResponseReceived_NoPod(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic with nil pod + router.ResponseReceived(ctx, request, response, nil) + + // Predictor should not be called + +} + +func TestSLOAwareRouter_ResponseReceived_NoContext(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Don't set SLO context + router.ResponseReceived(ctx, request, response, pod.GetPod()) + + // Should handle missing context gracefully + +} + +func TestSLOAwareRouter_ResponseStreaming_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic and should 
return early + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // Context should still exist + _, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) +} +func TestSLOAwareRouter_ResponseStreaming_FirstToken(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + sloCtx := NewSLORequestContext(request) + sloCtx.RequestReceivedTimestamp = time.Now() + sloCtx.SchedulingResult = schedulingResult + sloCtx.SchedulingRequest = *request + sloCtx.TTFTSLO = 100 + sloCtx.AvgTPOTSLO = 50 + sloCtx.IncomingModelName = "test-model" + sloCtx.PredictedTTFT = 80.0 + sloCtx.AvgPredictedTPOT = 30.0 + // ADD THIS - populate metrics + sloCtx.LastSeenMetrics["prefill"] = &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 1, + RunningQueueSize: 1, + } + sloCtx.LastSeenMetrics["default"] = &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 1, + RunningQueueSize: 1, + } + router.setSLOContextForRequest(request, sloCtx) + + // Initialize the queue and add the request + queue := NewRequestPriorityQueue() + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + + beforeTime := time.Now() + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + afterTime := time.Now() + + // Verify first token timestamp was set + retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.True(t, retrievedCtx.LastTokenTimestamp.After(beforeTime) || + retrievedCtx.LastTokenTimestamp.Equal(beforeTime)) + assert.True(t, retrievedCtx.LastTokenTimestamp.Before(afterTime) || + retrievedCtx.LastTokenTimestamp.Equal(afterTime)) +} + +func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + sloCtx := NewSLORequestContext(request) + sloCtx.RequestReceivedTimestamp = time.Now() + sloCtx.SchedulingResult = schedulingResult + sloCtx.SchedulingRequest = *request + sloCtx.TTFTSLO = 100 + sloCtx.AvgTPOTSLO = 50 + sloCtx.IncomingModelName = "test-model" + sloCtx.PredictedTTFT = 80.0 + sloCtx.AvgPredictedTPOT = 30.0 + // ADD THIS - populate metrics + sloCtx.LastSeenMetrics["prefill"] = &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 1, + RunningQueueSize: 1, + } + sloCtx.LastSeenMetrics["default"] = &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 1, + RunningQueueSize: 1, + } + firstTokenTime := time.Now().Add(-100 * time.Millisecond) + + router.setSLOContextForRequest(request, sloCtx) + + // Initialize the queue and add the request + queue := NewRequestPriorityQueue() + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // Verify token timestamp was 
updated + retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.True(t, retrievedCtx.LastTokenTimestamp.After(firstTokenTime)) +} + +func TestSLOAwareRouter_ResponseComplete_QueueNotFound(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + sloCtx.IncomingModelName = "test-model" + sloCtx.TargetPod = pod.GetPod() // ADD THIS to avoid other issues + router.setSLOContextForRequest(request, sloCtx) + + // Create an EMPTY queue (not nil, but empty) to test queue.Remove behavior + router.runningRequestLists[pod.GetPod().NamespacedName] = NewRequestPriorityQueue() + + // Should handle gracefully when request is not in queue + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Context should be deleted + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} +func TestSLOAwareRouter_ResponseStreaming_NoContext(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Don't set SLO context - should handle gracefully + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // Should not panic + +} + +func TestSLOAwareRouter_ResponseComplete_Success(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Create queue and add request + queue := NewRequestPriorityQueue() + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) + + sloCtx := NewSLORequestContext(request) + sloCtx.TTFT = 80 + sloCtx.AvgTPOT = 30 + sloCtx.PredictedTTFT = 85 + sloCtx.AvgPredictedTPOT = 32 + sloCtx.TTFTSLO = 100 + sloCtx.AvgTPOTSLO = 50 + sloCtx.IncomingModelName = "incoming-model" + router.setSLOContextForRequest(request, sloCtx) + + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Verify context was deleted + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) + + // Verify request was removed from queue + assert.Equal(t, 0, queue.Len()) +} + +func TestSLOAwareRouter_ResponseComplete_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Context should still exist (deletion happens only with predictor) + _, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) +} + +func TestSLOAwareRouter_ResponseComplete_NoPod(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = 
+
+	ctx := context.Background()
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+
+	sloCtx := NewSLORequestContext(request)
+	router.setSLOContextForRequest(request, sloCtx)
+
+	// Should not panic with nil pod
+	router.ResponseComplete(ctx, request, response, nil)
+
+	// Context should still exist (deletion happens only with a valid pod)
+	_, err := router.getSLOContextForRequest(request)
+	assert.NoError(t, err)
+}
+
+func TestSLOAwareRouter_ResponseComplete_NoContext(t *testing.T) {
+	router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+
+	// Don't set SLO context - should handle gracefully
+	router.ResponseComplete(ctx, request, response, pod.GetPod())
+
+	// Should not panic
+
+}
+
+func TestSLOAwareRouter_ResponseComplete_WithMetrics(t *testing.T) {
+	router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+
+	// Create queue
+	queue := NewRequestPriorityQueue()
+	router.runningRequestLists[pod.GetPod().NamespacedName] = queue
+	queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0)
+
+	sloCtx := NewSLORequestContext(request)
+	sloCtx.TTFT = 80
+	sloCtx.AvgTPOT = 30
+	sloCtx.PredictedTTFT = 85
+	sloCtx.AvgPredictedTPOT = 32
+	sloCtx.TTFTSLO = 100
+	sloCtx.AvgTPOTSLO = 50
+	sloCtx.IncomingModelName = "incoming-model"
+	router.setSLOContextForRequest(request, sloCtx)
+
+	// Should record metrics without panicking
+	router.ResponseComplete(ctx, request, response, pod.GetPod())
+
+	// Verify cleanup
+	_, err := router.getSLOContextForRequest(request)
+	assert.Error(t, err)
+}
+
+func TestSLOAwareRouter_ResponseComplete_NoSLOs(t *testing.T) {
+	router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test-id", 0, 0, true) // No SLOs
+	response := &requestcontrol.Response{}
+
+	// Create queue
+	queue := NewRequestPriorityQueue()
+	router.runningRequestLists[pod.GetPod().NamespacedName] = queue
+	queue.Add(request.Headers[requtil.RequestIdHeaderKey], 0)
+
+	sloCtx := NewSLORequestContext(request)
+	sloCtx.TTFT = 80
+	sloCtx.AvgTPOT = 30
+	sloCtx.IncomingModelName = "test-model"
+	router.setSLOContextForRequest(request, sloCtx)
+
+	// Should handle missing SLOs gracefully
+	router.ResponseComplete(ctx, request, response, pod.GetPod())
+
+	// Verify cleanup
+	_, err := router.getSLOContextForRequest(request)
+	assert.Error(t, err)
+}
+
+func TestSLOAwareRouter_CheckPredictor_NilPod(t *testing.T) {
+	router := createTestRouter()
+	logger := logr.Discard()
+
+	result := router.CheckPredictor(logger, nil)
+
+	assert.False(t, result)
+}
+
+func TestSLOAwareRouter_CheckPredictor_NilPredictor(t *testing.T) {
+	router := createTestRouter()
+	router.latencypredictor = nil
+	logger := logr.Discard()
+	pod := createTestPod("test-pod", 1, 1, 1)
+
+	result := router.CheckPredictor(logger, pod.GetPod())
+
+	assert.False(t, result)
+}
+
+func TestSLOAwareRouter_CheckPredictor_Success(t *testing.T) {
+	router := 
createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + logger := logr.Discard() + pod := createTestPod("test-pod", 1, 1, 1) + + result := router.CheckPredictor(logger, pod.GetPod()) + + assert.True(t, result) +} + +func TestSLORequestContext_Fields(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := NewSLORequestContext(request) + + // Test all field initialization + assert.NotNil(t, ctx.LastSeenMetrics) + assert.NotNil(t, ctx.PrefixCacheScoresForPods) + assert.NotNil(t, ctx.PredictedTTFTForScheduling) + assert.NotNil(t, ctx.PredictedTPOTForScheduling) + assert.Empty(t, ctx.TPOTObservations) + assert.Empty(t, ctx.PredictedTPOTObservations) + assert.Zero(t, ctx.GeneratedTokenCount) + assert.Zero(t, ctx.TTFT) + assert.Zero(t, ctx.AvgTPOT) + assert.Nil(t, ctx.TargetPod) + assert.Nil(t, ctx.SchedulingResult) + assert.Nil(t, ctx.TokenSampler) +} + +func TestSLORequestContext_UpdateMetrics(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := NewSLORequestContext(request) + + // Add some metrics + metricsState := &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 3, + } + ctx.LastSeenMetrics["test-pod"] = metricsState + + assert.Len(t, ctx.LastSeenMetrics, 1) + assert.Equal(t, 0.5, ctx.LastSeenMetrics["test-pod"].KVCacheUsagePercent) + assert.Equal(t, 3, ctx.LastSeenMetrics["test-pod"].WaitingQueueSize) +} + +func TestSLORequestContext_PredictionData(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := NewSLORequestContext(request) + + // Set prediction data + ctx.PredictedTTFTForScheduling["pod1"] = 80.0 + ctx.PredictedTPOTForScheduling["pod1"] = 30.0 + ctx.PredictedTTFTForScheduling["pod2"] = 90.0 + ctx.PredictedTPOTForScheduling["pod2"] = 35.0 + + assert.Len(t, ctx.PredictedTTFTForScheduling, 2) + assert.Len(t, ctx.PredictedTPOTForScheduling, 2) + assert.Equal(t, 80.0, ctx.PredictedTTFTForScheduling["pod1"]) + assert.Equal(t, 30.0, ctx.PredictedTPOTForScheduling["pod1"]) +} + +func TestSLORequestContext_PrefixCacheScores(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := NewSLORequestContext(request) + + // Set prefix cache scores + ctx.PrefixCacheScoresForPods["pod1"] = 0.8 + ctx.PrefixCacheScoresForPods["pod2"] = 0.6 + ctx.PrefixCacheScoresForPods["pod3"] = 0.9 + + assert.Len(t, ctx.PrefixCacheScoresForPods, 3) + assert.Equal(t, 0.8, ctx.PrefixCacheScoresForPods["pod1"]) + assert.Equal(t, 0.9, ctx.PrefixCacheScoresForPods["pod3"]) +} + +func TestSLOAwareRouter_ConcurrentContextAccess(t *testing.T) { + router := createTestRouter() + + // Test concurrent access to context store + var wg sync.WaitGroup + numGoroutines := 100 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + requestID := uuid.New().String() + request := createTestLLMRequest(requestID, 100, 50, true) + sloCtx := NewSLORequestContext(request) + + // Set context + router.setSLOContextForRequest(request, sloCtx) + + // Get context + retrievedCtx, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) + assert.NotNil(t, retrievedCtx) + + // Delete context + router.deleteSLOContextForRequest(request) + }(i) + } + + wg.Wait() +} + +func TestSLOAwareRouter_MultipleRequests_SamePod(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + + 
request1 := createTestLLMRequest("test-id-1", 100, 50, true)
+	request2 := createTestLLMRequest("test-id-2", 100, 50, true)
+	request3 := createTestLLMRequest("test-id-3", 100, 50, true)
+
+	schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1)
+
+	// Create and set SLO contexts
+	for _, req := range []*schedulingtypes.LLMRequest{request1, request2, request3} {
+		sloCtx := NewSLORequestContext(req)
+		sloCtx.AvgTPOTSLO = 50
+		router.setSLOContextForRequest(req, sloCtx)
+	}
+
+	// Add all requests
+	router.PreRequest(ctx, request1, schedulingResult)
+	router.PreRequest(ctx, request2, schedulingResult)
+	router.PreRequest(ctx, request3, schedulingResult)
+
+	// Verify a running-request queue was created for the pod
+	queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName]
+	assert.True(t, exists)
+	assert.NotNil(t, queue)
+}
+
+func TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) {
+	router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+	schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1)
+
+	// Create initial context
+	sloCtx := NewSLORequestContext(request)
+	sloCtx.AvgTPOTSLO = 50
+	sloCtx.IncomingModelName = "test-model"
+	router.setSLOContextForRequest(request, sloCtx)
+
+	// 1. PreRequest
+	router.PreRequest(ctx, request, schedulingResult)
+
+	// Verify context exists
+	retrievedCtx, err := router.getSLOContextForRequest(request)
+	require.NoError(t, err)
+	assert.NotNil(t, retrievedCtx.TargetPod)
+
+	// 2. ResponseReceived
+	router.ResponseReceived(ctx, request, response, pod.GetPod())
+
+	// 3. ResponseStreaming (first token)
+	router.ResponseStreaming(ctx, request, response, pod.GetPod())
+
+	// 4. ResponseStreaming (subsequent tokens)
+	retrievedCtx, _ = router.getSLOContextForRequest(request)
+	retrievedCtx.TTFT = 100 // Mark first token received
+	router.setSLOContextForRequest(request, retrievedCtx)
+	router.ResponseStreaming(ctx, request, response, pod.GetPod())
+
+	// 5. 
ResponseComplete + retrievedCtx, _ = router.getSLOContextForRequest(request) + retrievedCtx.TTFT = 80 + retrievedCtx.AvgTPOT = 30 + router.setSLOContextForRequest(request, retrievedCtx) + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Verify context was cleaned up + _, err = router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + + pod1 := createTestPod("test-pod-1", 1, 1, 1) + pod2 := createTestPod("test-pod-2", 1, 1, 1) + + request1 := createTestLLMRequest("test-id-1", 100, 50, true) + request2 := createTestLLMRequest("test-id-2", 100, 50, true) + + schedulingResult1 := createTestSchedulingResult(pod1.GetPod(), 1, 1, 1) + schedulingResult2 := createTestSchedulingResult(pod2.GetPod(), 1, 1, 1) + + // Create and set SLO contexts + sloCtx1 := NewSLORequestContext(request1) + sloCtx1.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request1, sloCtx1) + + sloCtx2 := NewSLORequestContext(request2) + sloCtx2.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request2, sloCtx2) + + // Add requests to different pods + router.PreRequest(ctx, request1, schedulingResult1) + router.PreRequest(ctx, request2, schedulingResult2) + + // Verify separate queues were created + queue1, exists1 := router.runningRequestLists[pod1.GetPod().NamespacedName] + queue2, exists2 := router.runningRequestLists[pod2.GetPod().NamespacedName] + + assert.True(t, exists1) + assert.True(t, exists2) + assert.NotNil(t, queue1) + assert.NotNil(t, queue2) + assert.NotEqual(t, queue1, queue2) +} + +func TestSLORequestContext_SLOValidation(t *testing.T) { + tests := []struct { + name string + ttftSLO float64 + tpotSLO float64 + expectSLOs bool + }{ + { + name: "Both SLOs set", + ttftSLO: 100, + tpotSLO: 50, + expectSLOs: true, + }, + { + name: "No SLOs", + ttftSLO: 0, + tpotSLO: 0, + expectSLOs: false, + }, + { + name: "Only TTFT SLO", + ttftSLO: 100, + tpotSLO: 0, + expectSLOs: false, + }, + { + name: "Only TPOT SLO", + ttftSLO: 0, + tpotSLO: 50, + expectSLOs: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := createTestLLMRequest("test-id", tt.ttftSLO, tt.tpotSLO, true) + ctx := NewSLORequestContext(request) + ctx.TTFTSLO = tt.ttftSLO + ctx.AvgTPOTSLO = tt.tpotSLO + + hasBothSLOs := ctx.TTFTSLO > 0 && ctx.AvgTPOTSLO > 0 + assert.Equal(t, tt.expectSLOs, hasBothSLOs) + }) + } +} + +// Benchmark tests + +func BenchmarkSLOAwareRouter_PreRequest(b *testing.B) { + router := createTestRouter() + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + requestID := uuid.New().String() + request := createTestLLMRequest(requestID, 100, 50, true) + sloCtx := NewSLORequestContext(request) + sloCtx.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request, sloCtx) + router.PreRequest(ctx, request, schedulingResult) + } +} + +func BenchmarkSLOAwareRouter_ContextOperations(b *testing.B) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + sloCtx := NewSLORequestContext(request) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + router.setSLOContextForRequest(request, sloCtx) + _, _ = router.getSLOContextForRequest(request) + router.deleteSLOContextForRequest(request) + } +} + +func 
BenchmarkSLORequestContext_Creation(b *testing.B) { + request := createTestLLMRequest("test", 100, 50, true) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = NewSLORequestContext(request) + } +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go new file mode 100644 index 000000000..ce1e997b0 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go @@ -0,0 +1,243 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "container/heap" + "fmt" + "sort" + "strings" + "sync" +) + +// Request represents an element in the priority queue. +// The index is needed by heap.Remove and is maintained by the heap.Interface methods. +type Request struct { + ID string // Unique identifier + TPOT float64 // The priority value (lower is higher priority) + index int +} + +// RequestPriorityQueue implements a priority queue with item removal by ID. +type RequestPriorityQueue struct { + items []*Request + lookup map[string]*Request + mutex sync.RWMutex +} + +// NewRequestPriorityQueue initializes and returns a new PriorityQueue. +func NewRequestPriorityQueue() *RequestPriorityQueue { + return &RequestPriorityQueue{ + lookup: make(map[string]*Request), + items: []*Request{}, + } +} + +// Clone creates a deep copy of the priority queue. +// The new queue is completely independent of the original. +func (pq *RequestPriorityQueue) Clone() *RequestPriorityQueue { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + + // Initialize a new priority queue with pre-allocated capacity. + clonedPq := &RequestPriorityQueue{ + items: make([]*Request, len(pq.items)), + lookup: make(map[string]*Request, len(pq.lookup)), + } + + // Iterate through the original items to create deep copies. + for i, oldItem := range pq.items { + // Create a new Request struct, copying all values. + newItem := &Request{ + ID: oldItem.ID, + TPOT: oldItem.TPOT, + index: oldItem.index, + } + + // Assign the new item to the cloned queue's items slice. + clonedPq.items[i] = newItem + // Update the lookup map in the cloned queue to point to the new item. + clonedPq.lookup[newItem.ID] = newItem + } + + return clonedPq +} + +// Len is the number of items in the queue. +func (pq *RequestPriorityQueue) Len() int { return len(pq.items) } + +// Less reports whether the item with index i should sort before the item with index j. +func (pq *RequestPriorityQueue) Less(i, j int) bool { + return pq.items[i].TPOT < pq.items[j].TPOT +} + +// Swap swaps the items with indexes i and j. +func (pq *RequestPriorityQueue) Swap(i, j int) { + pq.items[i], pq.items[j] = pq.items[j], pq.items[i] + pq.items[i].index = i + pq.items[j].index = j +} + +// Push adds an item to the heap. 
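+// It is heap.Interface plumbing and should not be called directly; callers
+// go through Add, which locks the queue and keeps the ID lookup map in sync.
+// A minimal usage sketch of the public API (hypothetical IDs and TPOT values):
+//
+//	pq := NewRequestPriorityQueue()
+//	pq.Add("req-a", 0.05)
+//	pq.Add("req-b", 0.03)
+//	fastest := pq.Peek() // req-b: the lowest TPOT sorts first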
+func (pq *RequestPriorityQueue) Push(x any) {
+	item := x.(*Request)
+	item.index = len(pq.items)
+	pq.items = append(pq.items, item)
+}
+
+// Pop removes and returns the minimum item from the heap (it is invoked via heap.Pop).
+func (pq *RequestPriorityQueue) Pop() any {
+	n := len(pq.items)
+	item := pq.items[n-1]
+	pq.items[n-1] = nil // avoid memory leak
+	item.index = -1 // for safety
+	pq.items = pq.items[0 : n-1]
+	return item
+}
+
+// Add adds a new item to the queue.
+// Returns true if the item was added; false for a duplicate ID, an empty ID, or a negative TPOT.
+func (pq *RequestPriorityQueue) Add(id string, tpot float64) bool {
+	pq.mutex.Lock()
+	defer pq.mutex.Unlock()
+
+	// Validate input
+	if id == "" {
+		return false
+	}
+	if tpot < 0 {
+		return false
+	}
+
+	// If item already exists, do not add
+	if _, exists := pq.lookup[id]; exists {
+		return false
+	}
+
+	item := &Request{
+		ID:   id,
+		TPOT: tpot,
+	}
+	pq.lookup[id] = item
+	heap.Push(pq, item)
+	return true
+}
+
+// Update modifies the TPOT value of an existing item in the queue.
+// Returns false if the item does not exist or if tpot is negative.
+func (pq *RequestPriorityQueue) Update(id string, tpot float64) bool {
+	pq.mutex.Lock()
+	defer pq.mutex.Unlock()
+
+	// Validate input
+	if tpot < 0 {
+		return false
+	}
+
+	item, exists := pq.lookup[id]
+	if !exists {
+		return false
+	}
+
+	item.TPOT = tpot
+	heap.Fix(pq, item.index)
+	return true
+}
+
+// Remove removes an item from the queue by its ID.
+func (pq *RequestPriorityQueue) Remove(id string) (*Request, bool) {
+	pq.mutex.Lock()
+	defer pq.mutex.Unlock()
+
+	item, ok := pq.lookup[id]
+	if !ok {
+		return nil, false
+	}
+	removed := heap.Remove(pq, item.index).(*Request)
+	delete(pq.lookup, id)
+	return removed, true
+}
+
+// Peek returns the item with the lowest value without removing it.
+func (pq *RequestPriorityQueue) Peek() *Request {
+	pq.mutex.RLock()
+	defer pq.mutex.RUnlock()
+
+	if len(pq.items) == 0 {
+		return nil
+	}
+	return pq.items[0]
+}
+
+// GetSize returns the current number of items in the queue.
+func (pq *RequestPriorityQueue) GetSize() int {
+	pq.mutex.RLock()
+	defer pq.mutex.RUnlock()
+	return len(pq.items)
+}
+
+// Contains checks if an item with the given ID exists in the queue.
+func (pq *RequestPriorityQueue) Contains(id string) bool {
+	pq.mutex.RLock()
+	defer pq.mutex.RUnlock()
+	_, exists := pq.lookup[id]
+	return exists
+}
+
+// ToSlice returns a copy of all items in the queue, sorted by ID for stable comparison.
+// This is primarily intended for testing and validation.
+func (pq *RequestPriorityQueue) ToSlice() []*Request {
+	pq.mutex.RLock()
+	defer pq.mutex.RUnlock()
+
+	// Create a copy to avoid returning a reference to the internal slice.
+	itemsCopy := make([]*Request, len(pq.items))
+	copy(itemsCopy, pq.items)
+
+	// Sort by ID to have a deterministic order for comparison in tests.
+	sort.Slice(itemsCopy, func(i, j int) bool {
+		return itemsCopy[i].ID < itemsCopy[j].ID
+	})
+
+	return itemsCopy
+}
+
+// String returns a string representation of the queue for debugging. 
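+// Sample output for a queue holding two requests (hypothetical IDs):
+//
+//	RequestPriorityQueue: [req1(1.50), req2(2.25)]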
+func (pq *RequestPriorityQueue) String() string { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + + if len(pq.items) == 0 { + return "RequestPriorityQueue: []" + } + + var builder strings.Builder + builder.WriteString("RequestPriorityQueue: [") + + for i, item := range pq.items { + if i > 0 { + builder.WriteString(", ") + } + builder.WriteString(item.ID) + builder.WriteString("(") + builder.WriteString(fmt.Sprintf("%.2f", item.TPOT)) + builder.WriteString(")") + } + + builder.WriteString("]") + return builder.String() +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go new file mode 100644 index 000000000..a8eba5fe1 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go @@ -0,0 +1,391 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "fmt" + "sync" + "testing" + "time" +) + +func TestNewRequestPriorityQueue(t *testing.T) { + pq := NewRequestPriorityQueue() + + if pq == nil { + t.Fatal("NewRequestPriorityQueue returned nil") + } + + if pq.GetSize() != 0 { + t.Errorf("Expected empty queue, got size %d", pq.GetSize()) + } + + if pq.Peek() != nil { + t.Error("Expected nil from Peek on empty queue") + } +} + +func TestAdd(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test successful add + if !pq.Add("req1", 2.5) { + t.Error("Expected Add to return true for new item") + } + + if pq.GetSize() != 1 { + t.Errorf("Expected size 1, got %d", pq.GetSize()) + } + + // Test duplicate add + if pq.Add("req1", 3.0) { + t.Error("Expected Add to return false for duplicate ID") + } + + if pq.GetSize() != 1 { + t.Errorf("Expected size 1 after duplicate add, got %d", pq.GetSize()) + } + + // Test validation + if pq.Add("", 1.0) { + t.Error("Expected Add to return false for empty ID") + } + + if pq.Add("req2", -1.0) { + t.Error("Expected Add to return false for negative TPOT") + } +} + +func TestPriorityOrdering(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Add items with different priorities + pq.Add("high", 1.0) // highest priority (lowest TPOT) + pq.Add("medium", 5.0) // medium priority + pq.Add("low", 10.0) // lowest priority (highest TPOT) + + // Check that highest priority item is at the top + peek := pq.Peek() + if peek == nil || peek.ID != "high" || peek.TPOT != 1.0 { + t.Errorf("Expected high priority item at top, got %+v", peek) + } + + // Test removal order + expected := []struct { + id string + tpot float64 + }{ + {"high", 1.0}, + {"medium", 5.0}, + {"low", 10.0}, + } + + for _, exp := range expected { + item := pq.Peek() + if item.ID != exp.id || item.TPOT != exp.tpot { + t.Errorf("Expected %s(%.1f), got %s(%.1f)", exp.id, exp.tpot, item.ID, item.TPOT) + } + + removed, ok := pq.Remove(item.ID) + if !ok || removed.ID != exp.id { + t.Errorf("Failed to remove %s", exp.id) + } + } +} + +func TestRemove(t *testing.T) { 
+ pq := NewRequestPriorityQueue() + + // Test remove from empty queue + if _, ok := pq.Remove("nonexistent"); ok { + t.Error("Expected Remove to return false for empty queue") + } + + // Add some items + pq.Add("req1", 1.0) + pq.Add("req2", 2.0) + pq.Add("req3", 3.0) + + // Test successful remove + removed, ok := pq.Remove("req2") + if !ok || removed.ID != "req2" || removed.TPOT != 2.0 { + t.Errorf("Expected to remove req2(2.0), got %+v, ok=%v", removed, ok) + } + + if pq.GetSize() != 2 { + t.Errorf("Expected size 2 after removal, got %d", pq.GetSize()) + } + + // Test remove nonexistent + if _, ok := pq.Remove("req2"); ok { + t.Error("Expected Remove to return false for already removed item") + } + + // Verify remaining items are still in correct order + if peek := pq.Peek(); peek.ID != "req1" { + t.Errorf("Expected req1 at top, got %s", peek.ID) + } +} + +func TestUpdate(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test update nonexistent item + if pq.Update("nonexistent", 1.0) { + t.Error("Expected Update to return false for nonexistent item") + } + + // Add items + pq.Add("req1", 1.0) + pq.Add("req2", 2.0) + pq.Add("req3", 3.0) + + // Update to make req3 highest priority + if !pq.Update("req3", 0.5) { + t.Error("Expected Update to return true for existing item") + } + + // Check that req3 is now at the top + if peek := pq.Peek(); peek.ID != "req3" || peek.TPOT != 0.5 { + t.Errorf("Expected req3(0.5) at top, got %s(%.1f)", peek.ID, peek.TPOT) + } + + // Test validation + if pq.Update("req1", -1.0) { + t.Error("Expected Update to return false for negative TPOT") + } +} + +func TestContains(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test empty queue + if pq.Contains("req1") { + t.Error("Expected Contains to return false for empty queue") + } + + // Add item + pq.Add("req1", 1.0) + + // Test existing item + if !pq.Contains("req1") { + t.Error("Expected Contains to return true for existing item") + } + + // Test nonexistent item + if pq.Contains("req2") { + t.Error("Expected Contains to return false for nonexistent item") + } + + // Test after removal + pq.Remove("req1") + if pq.Contains("req1") { + t.Error("Expected Contains to return false after removal") + } +} + +func TestClone(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test clone of empty queue + clone := pq.Clone() + if clone.GetSize() != 0 { + t.Error("Expected cloned empty queue to be empty") + } + + // Add items to original + pq.Add("req1", 1.0) + pq.Add("req2", 2.0) + pq.Add("req3", 3.0) + + // Clone with items + clone = pq.Clone() + + // Verify clone has same items + if clone.GetSize() != pq.GetSize() { + t.Errorf("Expected clone size %d, got %d", pq.GetSize(), clone.GetSize()) + } + + // Verify independence - modify original + pq.Add("req4", 4.0) + if clone.GetSize() == pq.GetSize() { + t.Error("Clone should be independent of original") + } + + // Verify independence - modify clone + clone.Remove("req1") + if !pq.Contains("req1") { + t.Error("Original should not be affected by clone modifications") + } + + // Verify deep copy - items should be different instances + origPeek := pq.Peek() + clonePeek := clone.Peek() + if origPeek == clonePeek { + t.Error("Clone should create new Request instances, not share pointers") + } +} + +func TestString(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test empty queue + str := pq.String() + expected := "RequestPriorityQueue: []" + if str != expected { + t.Errorf("Expected %q, got %q", expected, str) + } + + // Test with items + pq.Add("req1", 
1.5) + pq.Add("req2", 2.25) + + str = pq.String() + // Should contain both items in priority order + if !contains(str, "req1(1.50)") || !contains(str, "req2(2.25)") { + t.Errorf("String output missing expected items: %s", str) + } +} + +func TestConcurrency(t *testing.T) { + pq := NewRequestPriorityQueue() + const numWorkers = 10 + const itemsPerWorker = 100 + + var wg sync.WaitGroup + + // Launch workers that add items + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func(workerID int) { + defer wg.Done() + for j := 0; j < itemsPerWorker; j++ { + id := fmt.Sprintf("worker%d-item%d", workerID, j) + tpot := float64(j) + float64(workerID)*0.1 + pq.Add(id, tpot) + } + }(i) + } + + // Launch workers that read from the queue + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < itemsPerWorker/2; j++ { + pq.Peek() + pq.GetSize() + time.Sleep(time.Microsecond) + } + }() + } + + wg.Wait() + + // Verify final state + expectedSize := numWorkers * itemsPerWorker + if pq.GetSize() != expectedSize { + t.Errorf("Expected final size %d, got %d", expectedSize, pq.GetSize()) + } +} + +func TestLargeQueue(t *testing.T) { + pq := NewRequestPriorityQueue() + const numItems = 10000 + + // Add many items + for i := 0; i < numItems; i++ { + id := fmt.Sprintf("item%d", i) + tpot := float64(numItems - i) // Reverse order so item0 has highest priority + pq.Add(id, tpot) + } + + if pq.GetSize() != numItems { + t.Errorf("Expected size %d, got %d", numItems, pq.GetSize()) + } + + // Verify priority ordering by removing items + lastTPOT := -1.0 + for i := 0; i < numItems; i++ { + item := pq.Peek() + if item.TPOT < lastTPOT { + t.Errorf("Priority order violated: %.1f < %.1f", item.TPOT, lastTPOT) + } + lastTPOT = item.TPOT + pq.Remove(item.ID) + } + + if pq.GetSize() != 0 { + t.Errorf("Expected empty queue after removing all items, got size %d", pq.GetSize()) + } +} + +func BenchmarkAdd(b *testing.B) { + pq := NewRequestPriorityQueue() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := fmt.Sprintf("item%d", i) + pq.Add(id, float64(i)) + } +} + +func BenchmarkPeek(b *testing.B) { + pq := NewRequestPriorityQueue() + + // Pre-populate queue + for i := 0; i < 1000; i++ { + pq.Add(fmt.Sprintf("item%d", i), float64(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pq.Peek() + } +} + +func BenchmarkRemove(b *testing.B) { + pq := NewRequestPriorityQueue() + + // Pre-populate queue + for i := 0; i < b.N; i++ { + pq.Add(fmt.Sprintf("item%d", i), float64(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pq.Remove(fmt.Sprintf("item%d", i)) + } +} + +// Helper function to check if a string contains a substring +func contains(s, substr string) bool { + return len(s) >= len(substr) && + (s == substr || + s[:len(substr)] == substr || + s[len(s)-len(substr):] == substr || + containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go new file mode 100644 index 000000000..bdeca3037 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go @@ -0,0 +1,136 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "hash/fnv" + "math" + "math/rand" + "time" +) + +// TokenSampler handles Poisson-distributed sampling for predictions only +// Training happens on every token regardless of sampling +type TokenSampler struct { + rng *rand.Rand + nextSampleToken int + samplingMean float64 + maxSamples int + sampleCount int +} + +// SetSamplingMean sets the sampling mean (lambda) for the Poisson distribution +func (ts *TokenSampler) SetSamplingMean(mean float64) { + ts.samplingMean = mean +} + +// SetMaxSamples sets the maximum number of samples +func (ts *TokenSampler) SetMaxSamples(max int) { + ts.maxSamples = max +} + +// SetSampleCount sets the current number of predictions made +func (ts *TokenSampler) SetSampleCount(count int) { + ts.sampleCount = count +} + +func NewTokenSampler(requestID string, samplingMean float64, maxSamples int) *TokenSampler { + // Use request ID hash as seed for reproducibility + seed := int64(0) + if requestID != "" { + hash := fnv.New64a() + hash.Write([]byte(requestID)) + seed = int64(hash.Sum64()) + } + if seed == 0 { + seed = time.Now().UnixNano() + } + + sampler := &TokenSampler{ + rng: rand.New(rand.NewSource(seed)), + samplingMean: samplingMean, + maxSamples: maxSamples, + } + + // Set first sample token (skip token 1 since that's TTFT) + sampler.nextSampleToken = 2 + sampler.poissonNext() + + return sampler +} + +// poissonNext generates the next interval using Poisson distribution +func (ts *TokenSampler) poissonNext() int { + lambda := ts.samplingMean + if lambda <= 0 { + return 1 + } + + // For small lambda, use Knuth's algorithm + if lambda < 30 { + l := math.Exp(-lambda) + k := 0 + p := 1.0 + + for p > l { + k++ + p *= ts.rng.Float64() + } + return k - 1 + } + + // For larger lambda, use normal approximation + normal := ts.rng.NormFloat64() + interval := int(math.Round(lambda + math.Sqrt(lambda)*normal)) + if interval < 1 { + return 1 + } + return interval +} + +// ShouldPredict determines if we should make a prediction for the current token +func (ts *TokenSampler) ShouldPredict(currentToken int) bool { + return currentToken == ts.nextSampleToken && ts.sampleCount < ts.maxSamples +} + +// RecordPrediction records that a prediction was made and calculates the next sample token +func (ts *TokenSampler) RecordPrediction(currentToken int) { + if ts.sampleCount >= ts.maxSamples { + return + } + + ts.sampleCount++ + + if ts.sampleCount < ts.maxSamples { + interval := ts.poissonNext() + ts.nextSampleToken = currentToken + interval + } +} + +// GetNextSampleToken returns the next token to predict for +func (ts *TokenSampler) GetNextSampleToken() int { + return ts.nextSampleToken +} + +// SetNextSampleToken sets the next token to predict for +func (ts *TokenSampler) SetNextSampleToken(token int) { + ts.nextSampleToken = token +} + +// GetSampleCount returns the current number of predictions made +func (ts *TokenSampler) GetSampleCount() int { + return ts.sampleCount +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go new file 
mode 100644 index 000000000..b476579b5 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go @@ -0,0 +1,325 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "fmt" + "math/rand" + "sync" + "time" + + "sigs.k8s.io/controller-runtime/pkg/log" + + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +type SLOAwareRouter struct { + tn plugins.TypedName + latencypredictor latencypredictor.PredictorInterface + runningRequestLists map[types.NamespacedName]*RequestPriorityQueue + sloContextStore sync.Map // map[string]*SLORequestContext + headroomStrategy HeadroomStrategy +} + +var _ framework.Scorer = &SLOAwareRouter{} + +func NewSLOAwareRouter(latencypredictor latencypredictor.PredictorInterface, strategy HeadroomStrategy) *SLOAwareRouter { + return &SLOAwareRouter{ + tn: plugins.TypedName{Type: SLOAwareRouterPluginType, Name: SLOAwareRouterPluginType}, + latencypredictor: latencypredictor, + runningRequestLists: make(map[types.NamespacedName]*RequestPriorityQueue), + sloContextStore: sync.Map{}, + headroomStrategy: strategy, + } +} + +func (s *SLOAwareRouter) TypedName() plugins.TypedName { + return s.tn +} + +func (s *SLOAwareRouter) WithName(name string) *SLOAwareRouter { + s.tn.Name = name + return s +} + +// SetHeadroomStrategy allows runtime configuration of headroom selection strategy +func (s *SLOAwareRouter) SetHeadroomStrategy(strategy HeadroomStrategy) { + s.headroomStrategy = strategy +} + +// GetHeadroomStrategy returns the current headroom selection strategy +func (s *SLOAwareRouter) GetHeadroomStrategy() HeadroomStrategy { + return s.headroomStrategy +} + +func (s *SLOAwareRouter) epsilonGreedyAffinityGate( + ctx context.Context, + candidates []PodPredictionResult, + r *rand.Rand, + label string, // e.g. "positive" or "negative" + prefixStickyThreshold float64, +) ([]PodPredictionResult, bool) { + logger := log.FromContext(ctx) + + eligible := make([]PodPredictionResult, 0, len(candidates)) + for _, p := range candidates { + if p.PrefixCacheScore >= prefixStickyThreshold { + eligible = append(eligible, p) + } + } + + // No eligible sticky pods? Explore (no gating). 
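+	// (An empty eligible set means there is no affinity signal to exploit, so the
+	// gate falls through to the full candidate set rather than filtering.)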
+	if len(eligible) == 0 {
+		return candidates, false
+	}
+
+	// ε-exploration branch
+	if r.Float64() < EpsilonExploreSticky {
+		logger.V(logutil.DEBUG).Info("ε-greedy: exploring (ignoring affinity gate)",
+			"path", label, "epsilon", EpsilonExploreSticky, "eligibleCount", len(eligible))
+		return candidates, false
+	}
+
+	logger.V(logutil.DEBUG).Info("ε-greedy: exploiting (apply affinity gate)",
+		"path", label, "threshold", prefixStickyThreshold, "eligibleCount", len(eligible), "total", len(candidates))
+	return eligible, true
+}
+
+// scoreWithoutPredictions provides fallback scoring based only on prefix cache scores
+// when latency predictions are unavailable.
+func (s *SLOAwareRouter) scoreWithoutPredictions(
+	ctx context.Context,
+	state *schedulingtypes.CycleState,
+	pods []schedulingtypes.Pod,
+	r *rand.Rand,
+) map[schedulingtypes.Pod]float64 {
+	logger := log.FromContext(ctx)
+	logger.V(logutil.TRACE).Info("Using composite-only scoring without predictions")
+
+	scores := make(map[schedulingtypes.Pod]float64, len(pods))
+	for _, pod := range pods {
+		scores[pod] = 0
+	}
+
+	if len(pods) == 0 {
+		return scores
+	}
+
+	// Build prediction results with only prefix cache scores
+	podResults := make([]PodPredictionResult, 0, len(pods))
+	for _, pod := range pods {
+		prefixScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
+		podResults = append(podResults, PodPredictionResult{
+			Pod:              pod,
+			PrefixCacheScore: prefixScore,
+			IsValid:          true, // All pods are valid when we don't check predictions
+		})
+	}
+
+	// Select based on composite scores (prefix cache + other non-prediction metrics)
+	selectedPod := s.selectFromCompositeScores(ctx, podResults, r, HeadroomStrategyCompositeOnly)
+
+	if selectedPod != nil {
+		scores[selectedPod] = 1
+		logger.V(logutil.TRACE).Info("Selected pod using composite-only scoring", "pod", selectedPod.GetPod().String())
+	}
+
+	return scores
+}
+
+func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) map[schedulingtypes.Pod]float64 {
+	logger := log.FromContext(ctx)
+	if s.latencypredictor == nil {
+		logger.V(logutil.DEBUG).Info("SLOAwareRouter: no predictor configured, returning nil scores")
+		return nil
+	}
+
+	sloCtx := s.getOrMakeSLORequestContext(request)
+
+	var err error
+	// Parse the request's SLO targets (TTFT and average TPOT)
+	// from the request headers.
+	sloCtx.TTFTSLO, _, err = parseFloatHeader(*request, TTFTSLOHeaderKey)
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", TTFTSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TTFT SLO from header")
+	}
+
+	sloCtx.AvgTPOTSLO, _, err = parseFloatHeader(*request, TPOTSLOHeaderKey)
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", TPOTSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TPOT SLO from header")
+	}
+	sloCtx.PredictorBasedScheduling, err = parseBoolHeader(*request, "x-prediction-based-scheduling")
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-prediction-based-scheduling must be a bool: %v", err)}, "SLOAwareRouter: Error parsing PredictorBasedScheduling from header")
+	}
+
+	// Honor the per-request opt-in for prediction-based scheduling
+	if !sloCtx.PredictorBasedScheduling {
+		logger.V(logutil.DEBUG).Info("PredictorBasedScheduling turned off, skipping prediction-based filtering")
+		s.setSLOContextForRequest(request, sloCtx)
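+		// (The context is stored above so later request-control hooks can still read the parsed SLOs.)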
+ return nil + } + + // Initialize scores map with all pods having score 0 + scores := make(map[schedulingtypes.Pod]float64, len(pods)) + for _, pod := range pods { + scores[pod] = 0 + } + + source := rand.NewSource(time.Now().UnixNano()) + r := rand.New(source) + predictions, err := s.generatePredictions(ctx, state, request, sloCtx, pods) + if err != nil { + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Error generating predictions, falling back to composite-only scoring") + // Fall back to composite-only scoring using prefix cache scores + s.setSLOContextForRequest(request, sloCtx) + return s.scoreWithoutPredictions(ctx, state, pods, r) + } + s.updateRequestContextWithPredictions(sloCtx, predictions) + + allPreds := append([]PodPredictionResult(nil), predictions...) + allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal) + + // Check if all pods are invalid and all have running requests + allPodsInvalid := true + allPodsHaveRunningRequests := true + + for _, pred := range allPreds { + if pred.IsValid { + allPodsInvalid = false + } + + runningRequestCount := s.getPodRunningRequestCount(pred.Pod) + if runningRequestCount == 0 { + allPodsHaveRunningRequests = false + } + } + + // Set HasValidPod to false if all pods are invalid and all have running requests + if allPodsInvalid && allPodsHaveRunningRequests && !sticky { + sloCtx.HasValidPod = false + logger.V(logutil.DEBUG).Info("All pods are invalid and have running requests, setting HasValidPod to false") + } + + // 2) Tiered selection: positive headroom pods get 99% probability, negative get 1% + var posHeadroomPods, negHeadroomPods []PodPredictionResult + for _, p := range allPreds { + // A pod has positive headroom only if BOTH TTFT and TPOT have positive headroom + if p.Headroom > 0 && p.TTFTHeadroom > 0 { + posHeadroomPods = append(posHeadroomPods, p) + } else { + // A pod has negative headroom if EITHER TTFT or TPOT has negative/zero headroom + negHeadroomPods = append(negHeadroomPods, p) + } + } + + logger.V(logutil.DEBUG).Info("Pod headroom distribution", + "positivePods", len(posHeadroomPods), + "negativePods", len(negHeadroomPods)) + + var selectedPod schedulingtypes.Pod + + if s.headroomStrategy == HeadroomStrategyCompositeOnly { + logger.V(logutil.DEBUG).Info("Selecting from composite scores only") + selectedPod = s.selectFromCompositeScores(ctx, allPreds, r, HeadroomStrategyCompositeOnly) + } else if len(posHeadroomPods) > 0 && len(negHeadroomPods) > 0 { + // 99% chance to select from positive headroom pods, 1% from negative + if r.Float64() < EpsilonExploreNeg { + logger.V(logutil.DEBUG).Info("Selecting from negative headroom pods (1% chance)") + selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r) + } else { + logger.V(logutil.DEBUG).Info("Selecting from positive headroom pods (99% chance)") + selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r) + } + } else if len(posHeadroomPods) > 0 { + // If only positive headroom pods exist, select from them + logger.V(logutil.DEBUG).Info("Only positive headroom pods available") + selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r) + } else if len(negHeadroomPods) > 0 { + // If only negative headroom pods exist, select from them + logger.V(logutil.DEBUG).Info("Only negative headroom pods available") + selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r) + } else if len(allPreds) > 0 { + // fallback - select randomly from valid pods + 
logger.V(logutil.DEBUG).Info("No headroom pods available, selecting randomly from valid pods") + selectedPod = allPreds[r.Intn(len(allPreds))].Pod + } else { + // No valid pods - return all zeros + logger.V(logutil.DEBUG).Info("No valid pods available, returning all zero scores") + return scores + } + + // Set score = 1 for selected pod, 0 for all others + if selectedPod != nil { + scores[selectedPod] = 1 + logger.V(logutil.DEBUG).Info("Selected pod for scheduling", "pod", selectedPod.GetPod().String()) + } + + s.setSLOContextForRequest(request, sloCtx) + + return scores +} + +func (t *SLOAwareRouter) getOrMakeSLORequestContext(request *schedulingtypes.LLMRequest) *SLORequestContext { + sloCtx, err := t.getSLOContextForRequest(request) + if err != nil { + sloCtx = NewSLORequestContext(request) + } + return sloCtx +} + +func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleState *schedulingtypes.CycleState, pod schedulingtypes.Pod) float64 { + log.FromContext(ctx).V(logutil.DEBUG).Info("Running getPrefixCacheScoreForPod, getting prefix cache score for pod", "pod", pod.GetPod().String()) + plugintype := prefix.PrefixCachePluginType + pluginname := prefix.PrefixCachePluginType + cycleStateKey := (plugins.TypedName{Type: plugintype, Name: pluginname}).String() + stateData, err := cycleState.Read(plugins.StateKey(cycleStateKey)) + + log.FromContext(ctx).V(logutil.DEBUG).Info("Reading prefix cache state from cycle state", "stateKey", cycleStateKey) + + if err != nil { + // The prefix cache plugin might not be enabled, which is a valid scenario. + log.FromContext(ctx).V(logutil.DEBUG).Info("Prefix cache state not found in cycle state, returning prefix cache score of 0.0", "pod", pod.GetPod().String()) + return 0.0 + } + + prefixCacheState, ok := stateData.(*prefix.SchedulingContextState) + if !ok { + // This should not happen if the plugin is configured correctly. + log.FromContext(ctx).Error(fmt.Errorf("unexpected state type: %T", stateData), "failed to read prefix cache state") + return 0.0 + } + + total := len(prefixCacheState.PrefixHashes) + if total == 0 { + // if the request has no prefixes, return 0.0 + log.FromContext(ctx).V(logutil.DEBUG).Info("No prefixes found in request, returning prefix cache score of 0.0") + return 0.0 + } + + matchLen := prefixCacheState.PrefixCacheServers[prefix.ServerID(pod.GetPod().NamespacedName)] + log.FromContext(ctx).V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "matchLen", matchLen, "totalPrefixes", total) + return float64(matchLen) / float64(total) +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go new file mode 100644 index 000000000..da073ff65 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go @@ -0,0 +1,527 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package slo_aware_router + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/types" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +// mockPredictor implements PredictorInterface for testing +type mockPredictor struct { + predictions map[string]*latencypredictor.PredictionResponse + err error +} + +func (m *mockPredictor) Predict(ctx context.Context, request latencypredictor.PredictionRequest) (*latencypredictor.PredictionResponse, error) { + if m.err != nil { + return nil, m.err + } + // Generate a key based on KV cache percentage to return different predictions for different pods + key := fmt.Sprintf("%.1f", request.KVCachePercentage) + if pred, ok := m.predictions[key]; ok { + return pred, nil + } + // Default prediction + return &latencypredictor.PredictionResponse{TTFT: 0.5, TPOT: 0.03}, nil +} + +func (m *mockPredictor) PredictBulk(ctx context.Context, requests []latencypredictor.PredictionRequest) (*latencypredictor.BulkPredictionResponse, error) { + if m.err != nil { + return nil, m.err + } + // Generate a key based on KV cache percentage to return different predictions for different pods + responses := make([]latencypredictor.PredictionResponse, 0, len(requests)) + for _, request := range requests { + key := fmt.Sprintf("%.1f", request.KVCachePercentage) + if pred, ok := m.predictions[key]; ok { + responses = append(responses, *pred) + } else { + return nil, fmt.Errorf("no prediction for key %s", key) + } + } + return &latencypredictor.BulkPredictionResponse{Predictions: responses}, nil +} + +func (m *mockPredictor) PredictBulkStrict(ctx context.Context, requests []latencypredictor.PredictionRequest) (*latencypredictor.BulkPredictionResponse, error) { + if m.err != nil { + return nil, m.err + } + // Generate a key based on KV cache percentage to return different predictions for different pods + responses := make([]latencypredictor.PredictionResponse, 0, len(requests)) + for _, request := range requests { + key := fmt.Sprintf("%.1f", request.KVCachePercentage) + if pred, ok := m.predictions[key]; ok { + responses = append(responses, *pred) + } else { + return nil, fmt.Errorf("no prediction for key %s", key) + } + } + return &latencypredictor.BulkPredictionResponse{Predictions: responses}, nil +} + +func (m *mockPredictor) AddTrainingDataBulk(data []latencypredictor.TrainingEntry) error { + return nil +} + +func (m *mockPredictor) AddTrainingData(data latencypredictor.TrainingEntry) error { + return nil +} + +func (m *mockPredictor) HealthCheck() error { + return nil +} + +func (m *mockPredictor) GetServerStatus(ctx context.Context) (*latencypredictor.ServerStatusResponse, error) { + return &latencypredictor.ServerStatusResponse{}, nil +} + +func createTestPod(name string, kvCacheUsage float64, runningQueueSize, waitingQueueSize int) schedulingtypes.Pod { + return &schedulingtypes.PodMetrics{ + Pod: &backend.Pod{ + NamespacedName: types.NamespacedName{ + Name: name, + Namespace: "default", + }, + }, + MetricsState: &backendmetrics.MetricsState{ + KVCacheUsagePercent: kvCacheUsage, + RunningQueueSize: runningQueueSize, + WaitingQueueSize: 
+			waitingQueueSize,
+		},
+	}
+}
+
+func createTestLLMRequest(reqID string, ttftSLO, tpotSLO float64, predictionBased bool) *schedulingtypes.LLMRequest {
+	headers := make(map[string]string)
+	headers[requtil.RequestIdHeaderKey] = reqID
+	if ttftSLO > 0 {
+		headers["x-ttft-slo"] = fmt.Sprintf("%f", ttftSLO)
+	}
+	if tpotSLO > 0 {
+		headers["x-avg-tpot-slo"] = fmt.Sprintf("%f", tpotSLO)
+	}
+	headers["x-prediction-based-scheduling"] = fmt.Sprintf("%t", predictionBased)
+
+	return &schedulingtypes.LLMRequest{
+		Headers: headers,
+		Body: &schedulingtypes.LLMRequestBody{
+			Completions: &schedulingtypes.CompletionsRequest{
+				Prompt: "test prompt",
+			},
+		},
+	}
+}
+
+func TestSLOAwareRouter_Score(t *testing.T) {
+	tests := []struct {
+		name           string
+		predictor      *mockPredictor
+		strategy       HeadroomStrategy
+		request        *schedulingtypes.LLMRequest
+		pods           []schedulingtypes.Pod
+		expectedScores map[string]float64 // Map of pod name to expected score
+		expectNil      bool
+	}{
+		{
+			name:      "Prediction-based scheduling disabled",
+			predictor: &mockPredictor{},
+			strategy:  HeadroomStrategyLeast,
+			request:   createTestLLMRequest("test", 1.0, 0.05, false), // predictionBased = false
+			pods: []schedulingtypes.Pod{
+				createTestPod("pod1", 0.5, 2, 1), // 50% KV cache, 2 running, 1 waiting
+				createTestPod("pod2", 0.7, 3, 2), // 70% KV cache, 3 running, 2 waiting
+			},
+			expectNil: true,
+		},
+		{
+			name:      "No predictor configured",
+			predictor: nil,
+			strategy:  HeadroomStrategyLeast,
+			request:   createTestLLMRequest("test", 1.0, 0.05, true),
+			pods: []schedulingtypes.Pod{
+				createTestPod("pod1", 0.5, 2, 1),
+			},
+			expectNil: true,
+		},
+		{
+			name: "All pods have positive headroom",
+			predictor: &mockPredictor{
+				predictions: map[string]*latencypredictor.PredictionResponse{
+					"0.5": {TTFT: 0.5, TPOT: 0.03}, // 50% KV cache
+					"0.6": {TTFT: 0.6, TPOT: 0.04}, // 60% KV cache
+					"0.3": {TTFT: 0.4, TPOT: 0.02}, // 30% KV cache
+				},
+			},
+			strategy: HeadroomStrategyLeast,
+			request:  createTestLLMRequest("test", 1.0, 0.05, true),
+			pods: []schedulingtypes.Pod{
+				createTestPod("pod1", 0.5, 2, 1), // 50% KV cache
+				createTestPod("pod2", 0.6, 3, 2), // 60% KV cache
+				createTestPod("pod3", 0.3, 1, 0), // 30% KV cache
+			},
+			// One pod should be selected with score 1, others 0
+			expectedScores: map[string]float64{
+				// We can't predict which one due to randomness, but exactly one should be 1
+			},
+		},
+		{
+			name: "All pods have negative headroom",
+			predictor: &mockPredictor{
+				predictions: map[string]*latencypredictor.PredictionResponse{
+					"0.8": {TTFT: 1.5, TPOT: 0.08}, // 80% KV cache - high load
+					"0.9": {TTFT: 1.8, TPOT: 0.09}, // 90% KV cache - very high load
+				},
+			},
+			strategy: HeadroomStrategyLeast,
+			request:  createTestLLMRequest("test", 1.0, 0.05, true),
+			pods: []schedulingtypes.Pod{
+				createTestPod("pod1", 0.8, 5, 3), // 80% KV cache, high load
+				createTestPod("pod2", 0.9, 6, 4), // 90% KV cache, very high load
+			},
+			// One pod should still be selected even with negative headroom
+			expectedScores: map[string]float64{},
+		},
+		{
+			name: "Mixed positive and negative headroom",
+			predictor: &mockPredictor{
+				predictions: map[string]*latencypredictor.PredictionResponse{
+					"0.3": {TTFT: 0.5, TPOT: 0.03}, // 30% KV cache - Positive headroom
+					"0.9": {TTFT: 1.5, TPOT: 0.08}, // 90% KV cache - Negative headroom
+				},
+			},
+			strategy: HeadroomStrategyLeast,
+			request:  createTestLLMRequest("test", 1.0, 0.05, true),
+			pods: []schedulingtypes.Pod{
+				createTestPod("pod-positive", 0.3, 1, 0), // Low KV cache, positive headroom
+				createTestPod("pod-negative", 0.9, 6, 4), // High KV cache, negative headroom
+			},
+			// With 99% probability, positive headroom pod should be selected
+			expectedScores: map[string]float64{},
+		},
+		{
+			name: "Prediction errors - fallback to composite scoring",
+			predictor: &mockPredictor{
+				err: fmt.Errorf("prediction failed"),
+			},
+			strategy: HeadroomStrategyLeast,
+			request:  createTestLLMRequest("test", 1.0, 0.05, true),
+			pods: []schedulingtypes.Pod{
+				createTestPod("pod1", 0.5, 2, 1),
+				createTestPod("pod2", 0.6, 3, 2),
+			},
+			// Should fall back to composite-only scoring and select one pod
+			expectedScores: map[string]float64{
+				// One pod should be selected with score 1, verified in general validation below
+			},
+		},
+		{
+			name:      "Empty pod list",
+			predictor: &mockPredictor{},
+			strategy:  HeadroomStrategyLeast,
+			request:   createTestLLMRequest("test", 1.0, 0.05, true),
+			pods:      []schedulingtypes.Pod{},
+			// Should return empty scores map
+			expectedScores: map[string]float64{},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var router *SLOAwareRouter
+			if tt.predictor == nil {
+				router = NewSLOAwareRouter(nil, tt.strategy)
+			} else {
+				router = NewSLOAwareRouter(tt.predictor, tt.strategy)
+			}
+
+			scores := router.Score(context.Background(), schedulingtypes.NewCycleState(), tt.request, tt.pods)
+
+			if tt.expectNil {
+				assert.Nil(t, scores, "Expected nil scores")
+				return
+			}
+
+			assert.NotNil(t, scores, "Expected non-nil scores")
+
+			// If we have specific expected scores, verify them
+			if len(tt.expectedScores) > 0 {
+				for _, pod := range tt.pods {
+					podName := pod.GetPod().NamespacedName.Name
+					if expectedScore, ok := tt.expectedScores[podName]; ok {
+						assert.InDelta(t, expectedScore, scores[pod], 0.0001, "Pod %s should have score %f", podName, expectedScore)
+					}
+				}
+			}
+
+			// General validation: exactly one pod should have score 1 (selected), others should have score 0.
+			// This applies even when predictions fail because we fall back to composite scoring.
+			if !tt.expectNil && len(tt.pods) > 0 && tt.predictor != nil {
+				selectedCount := 0
+				for _, score := range scores {
+					if score == 1.0 {
+						selectedCount++
+					} else {
+						assert.InDelta(t, 0.0, score, 0.0001, "Non-selected pods should have score 0")
+					}
+				}
+				assert.Equal(t, 1, selectedCount, "Exactly one pod should be selected with score 1")
+			}
+		})
+	}
+}
+
+func TestSLOAwareRouter_Strategies(t *testing.T) {
+	tests := []struct {
+		name     string
+		strategy HeadroomStrategy
+	}{
+		{
+			name:     "HeadroomStrategyLeast",
+			strategy: HeadroomStrategyLeast,
+		},
+		{
+			name:     "HeadroomStrategyMost",
+			strategy: HeadroomStrategyMost,
+		},
+		{
+			name:     "HeadroomStrategyCompositeMost",
+			strategy: HeadroomStrategyCompositeMost,
+		},
+		{
+			name:     "HeadroomStrategyCompositeLeast",
+			strategy: HeadroomStrategyCompositeLeast,
+		},
+		{
+			name:     "HeadroomStrategyCompositeOnly",
+			strategy: HeadroomStrategyCompositeOnly,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			predictor := &mockPredictor{
+				predictions: map[string]*latencypredictor.PredictionResponse{
+					"0.5": {TTFT: 0.5, TPOT: 0.03},
+					"0.6": {TTFT: 0.6, TPOT: 0.04},
+					"0.3": {TTFT: 0.4, TPOT: 0.02},
+				},
+			}
+			router := NewSLOAwareRouter(predictor, tt.strategy)
+
+			request := createTestLLMRequest("test", 1.0, 0.05, true)
+			pods := []schedulingtypes.Pod{
+				createTestPod("pod1", 0.5, 2, 1),
+				createTestPod("pod2", 0.6, 3, 2),
+				createTestPod("pod3", 0.3, 1, 0),
+			}
+
+			scores := router.Score(context.Background(), schedulingtypes.NewCycleState(), request, pods)
+
+			assert.NotNil(t, scores, "Expected non-nil scores for strategy %s", tt.strategy)
+
+			// Verify exactly one pod is selected
+			selectedCount := 0
+			for _, score := range scores {
+				if score == 1.0 {
+					selectedCount++
+				}
+			}
+			assert.Equal(t, 1, selectedCount, "Strategy %s should select exactly one pod", tt.strategy)
+		})
+	}
+}
+
+func TestSLOAwareRouter_SetHeadroomStrategy(t *testing.T) {
+	predictor := &mockPredictor{}
+	router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast)
+
+	assert.Equal(t, HeadroomStrategyLeast, router.GetHeadroomStrategy(), "Initial strategy should be Least")
+
+	router.SetHeadroomStrategy(HeadroomStrategyMost)
+	assert.Equal(t, HeadroomStrategyMost, router.GetHeadroomStrategy(), "Strategy should be updated to Most")
+
+	router.SetHeadroomStrategy(HeadroomStrategyCompositeOnly)
+	assert.Equal(t, HeadroomStrategyCompositeOnly, router.GetHeadroomStrategy(), "Strategy should be updated to CompositeOnly")
+}
+
+func TestSLOAwareRouter_TypedName(t *testing.T) {
+	predictor := &mockPredictor{}
+	router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast)
+
+	tn := router.TypedName()
+	assert.Equal(t, "slo-aware-routing", tn.Type, "Type should be slo-aware-routing")
+	assert.Equal(t, "slo-aware-routing", tn.Name, "Default name should be slo-aware-routing")
+}
+
+func TestSLOAwareRouter_WithName(t *testing.T) {
+	predictor := &mockPredictor{}
+	router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast)
+
+	customName := "custom-router"
+	router = router.WithName(customName)
+
+	tn := router.TypedName()
+	assert.Equal(t, "slo-aware-routing", tn.Type, "Type should remain slo-aware-routing")
+	assert.Equal(t, customName, tn.Name, "Name should be updated to custom name")
+}
+
+func TestSLOAwareRouter_GetPodRunningRequestCount(t *testing.T) {
+	tests := []struct {
+		name          string
+		setupRequests func(*SLOAwareRouter, schedulingtypes.Pod)
+		expectedCount int
+	}{
+		{
+			name:          "No running requests",
+			setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {},
+			expectedCount: 0,
+		},
+		{
+			name: "One running request",
+			setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {
+				podName := types.NamespacedName{
+					Name:      p.GetPod().NamespacedName.Name,
+					Namespace: p.GetPod().NamespacedName.Namespace,
+				}
+				r.runningRequestLists[podName] = NewRequestPriorityQueue()
+				r.runningRequestLists[podName].Add("req1", 0.04)
+			},
+			expectedCount: 1,
+		},
+		{
+			name: "Multiple running requests",
+			setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {
+				podName := types.NamespacedName{
+					Name:      p.GetPod().NamespacedName.Name,
+					Namespace: p.GetPod().NamespacedName.Namespace,
+				}
+				r.runningRequestLists[podName] = NewRequestPriorityQueue()
+				r.runningRequestLists[podName].Add("req1", 0.04)
+				r.runningRequestLists[podName].Add("req2", 0.03)
+				r.runningRequestLists[podName].Add("req3", 0.05)
+			},
+			expectedCount: 3,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			predictor := &mockPredictor{}
+			router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast)
+			pod := createTestPod("test-pod", 0.5, 2, 1)
+
+			tt.setupRequests(router, pod)
+
+			count := router.getPodRunningRequestCount(pod)
+			assert.Equal(t, tt.expectedCount, count, "Running request count should match expected")
+		})
+	}
+}
+
+func TestSLOAwareRouter_GetPodMinTPOTSLO(t *testing.T) {
+	tests := []struct {
+		name          string
+		setupRequests func(*SLOAwareRouter, schedulingtypes.Pod)
+		expectedSLO   float64
+	}{
+		{
+			name:          "No running requests",
+			setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {},
+			expectedSLO:   0.0,
+		},
+		{
+			name: "One running request",
+			setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {
+				podName := types.NamespacedName{
+					Name:      p.GetPod().NamespacedName.Name,
+					Namespace: p.GetPod().NamespacedName.Namespace,
+				}
+				r.runningRequestLists[podName] = NewRequestPriorityQueue()
+				r.runningRequestLists[podName].Add("req1", 0.04)
+			},
+			expectedSLO: 0.04,
+		},
+		{
+			name: "Multiple running requests - should return minimum",
+			setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {
+				podName := types.NamespacedName{
+					Name:      p.GetPod().NamespacedName.Name,
+					Namespace: p.GetPod().NamespacedName.Namespace,
+				}
+				r.runningRequestLists[podName] = NewRequestPriorityQueue()
+				// Add in any order - heap will maintain minimum at top
+				r.runningRequestLists[podName].Add("req1", 0.05)
+				r.runningRequestLists[podName].Add("req2", 0.03) // This is the minimum
+				r.runningRequestLists[podName].Add("req3", 0.04)
+			},
+			expectedSLO: 0.03, // Minimum TPOT (heap guarantees this is at items[0])
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			predictor := &mockPredictor{}
+			router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast)
+			pod := createTestPod("test-pod", 0.5, 2, 1)
+
+			tt.setupRequests(router, pod)
+
+			minSLO := router.getPodMinTPOTSLO(pod)
+			assert.InDelta(t, tt.expectedSLO, minSLO, 0.0001, "Min TPOT SLO should match expected")
+		})
+	}
+}
+
+func TestSLOAwareRouter_GetPrefixCacheScoreForPod(t *testing.T) {
+	tests := []struct {
+		name          string
+		setupState    func(*schedulingtypes.CycleState)
+		expectedScore float64
+	}{
+		{
+			name:          "No prefix cache state",
+			setupState:    func(s *schedulingtypes.CycleState) {},
+			expectedScore: 0.0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			predictor := &mockPredictor{}
+			router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast)
+
+			state := schedulingtypes.NewCycleState()
+			tt.setupState(state)
+
+			pod := createTestPod("test-pod", 0.5, 2, 1)
+
+			score := router.getPrefixCacheScoreForPod(context.Background(), state, pod)
+			assert.InDelta(t, tt.expectedScore, score, 0.0001, "Prefix cache score should match expected")
+		})
+	}
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go
new file mode 100644
index 000000000..eeab50433
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go
@@ -0,0 +1,385 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
+package slo_aware_router
+
+import (
+	"context"
+	"math"
+	"math/rand"
+
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+// selectFromPositiveHeadroomPods selects a pod from the positive-headroom pods using the headroom strategy.
+// Incorporates TTFTHeadroom with a configurable blend against TPOT headroom.
+func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, posHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	logger := log.FromContext(ctx)
+
+	if len(posHeadroomPods) == 1 {
+		return posHeadroomPods[0].Pod
+	}
+
+	// Apply perfect stickiness (with exploration)
+	candidates, sticky := s.epsilonGreedyAffinityGate(ctx, posHeadroomPods, r, "positive", AffinityGateTau)
+
+	// If perfect stickiness collapsed us to a single pod, short-circuit
+	if sticky && len(candidates) == 1 {
+		return candidates[0].Pod
+	}
+
+	switch s.headroomStrategy {
+	case HeadroomStrategyCompositeMost:
+		return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeMost)
+	case HeadroomStrategyCompositeLeast:
+		return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeLeast)
+	}
+
+	// Find min/max for TPOT (Headroom) and TTFTHeadroom across positive pods to normalize to [0,1]
+	minTPOTH, maxTPOTH := math.MaxFloat64, -math.MaxFloat64
+	minTTFTH, maxTTFTH := math.MaxFloat64, -math.MaxFloat64
+
+	for _, p := range candidates {
+		if p.Headroom < minTPOTH {
+			minTPOTH = p.Headroom
+		}
+		if p.Headroom > maxTPOTH {
+			maxTPOTH = p.Headroom
+		}
+		if p.TTFTHeadroom < minTTFTH {
+			minTTFTH = p.TTFTHeadroom
+		}
+		if p.TTFTHeadroom > maxTTFTH {
+			maxTTFTH = p.TTFTHeadroom
+		}
+	}
+
+	tpotRange := maxTPOTH - minTPOTH
+	ttftRange := maxTTFTH - minTTFTH
+
+	// Precompute blend weights (renormalize if user sets both to 0)
+	alpha := HeadroomTTFTWeight
+	beta := HeadroomTPOTWeight
+	if alpha+beta <= 0 {
+		alpha = 1.0
+		beta = 0.0
+	}
+	sum := alpha + beta
+	alpha /= sum
+	beta /= sum
+
+	logger.V(logutil.DEBUG).Info("Positive headroom normalization ranges",
+		"minTPOTHeadroom", minTPOTH, "maxTPOTHeadroom", maxTPOTH,
+		"minTTFTHeadroom", minTTFTH, "maxTTFTHeadroom", maxTTFTH,
+		"alphaTTFT", alpha, "betaTPOT", beta, "strategy", s.headroomStrategy)
+
+	// Calculate weights for weighted random selection
+	weightedChoices := make([]Choice, 0, len(candidates))
+	total := 0
+
+	for _, p := range candidates {
+		// Normalize to [0,1] within the cohort
+		nTPOTH := 0.5
+		if tpotRange > eps {
+			nTPOTH = (p.Headroom - minTPOTH) / (tpotRange + eps)
+		}
+		nTTFTH := 0.5
+		if ttftRange > eps {
+			nTTFTH = (p.TTFTHeadroom - minTTFTH) / (ttftRange + eps)
+		}
+
+		// Blend: larger combined -> "safer"; smaller -> "tighter packing"
+		combined := alpha*nTTFTH + beta*nTPOTH
+
+		// Map to integer weights
+		var w int
+		switch s.headroomStrategy {
+		case HeadroomStrategyLeast:
+			// prefer smaller combined headroom (pack closer to limits)
+			w = int((1.0-combined)*float64(Wmax-minWeight)) + minWeight + 1
+		case HeadroomStrategyMost:
+			// prefer larger combined headroom (more conservative / spread)
+			w = int(combined*float64(Wmax-minWeight)) + minWeight + 1
+		default:
+			// Fallback to least
+			w = int((1.0-combined)*float64(Wmax-minWeight)) + minWeight + 1
+		}
+
+		weightedChoices = append(weightedChoices, Choice{PodName: p.Pod, Weight: w})
+		total += w
+
+		logger.V(logutil.TRACE).Info("Positive headroom blended weight",
+			"pod", p.Pod.GetPod().String(),
+			"ttftHeadroom", p.TTFTHeadroom, "normTTFTHeadroom", nTTFTH,
+			"tpotHeadroom", p.Headroom, "normTPOTHeadroom", nTPOTH,
+			"combined", combined, "weight", w)
+	}
+
+	return s.performWeightedRandomSelection(weightedChoices, total, candidates, r)
+}
+
+// selectFromNegativeHeadroomPods selects a pod from the negative-headroom pods using hierarchical TTFT/TPOT logic.
+// It strictly prefers pods with 0 running requests.
+func (s *SLOAwareRouter) selectFromNegativeHeadroomPods(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	logger := log.FromContext(ctx)
+
+	if len(negHeadroomPods) == 1 {
+		return negHeadroomPods[0].Pod
+	}
+
+	// First, separate pods by running request count
+	var zeroRunningRequestPods, nonZeroRunningRequestPods []PodPredictionResult
+
+	for _, p := range negHeadroomPods {
+		runningRequestCount := s.getPodRunningRequestCount(p.Pod)
+		if runningRequestCount == 0 {
+			zeroRunningRequestPods = append(zeroRunningRequestPods, p)
+		} else {
+			nonZeroRunningRequestPods = append(nonZeroRunningRequestPods, p)
+		}
+	}
+
+	logger.V(logutil.DEBUG).Info("Negative headroom pods by running request count",
+		"zeroRunningRequests", len(zeroRunningRequestPods),
+		"nonZeroRunningRequests", len(nonZeroRunningRequestPods))
+
+	// If we have pods with 0 running requests, strictly prefer them
+	if len(zeroRunningRequestPods) > 0 {
+		logger.V(logutil.DEBUG).Info("Selecting from pods with zero running requests")
+		return s.selectFromNegativeHeadroomPodsInternal(ctx, zeroRunningRequestPods, r)
+	}
+
+	// Otherwise, fall back to pods with running requests
+	logger.V(logutil.DEBUG).Info("No pods with zero running requests, selecting from pods with running requests")
+	return s.selectFromNegativeHeadroomPodsInternal(ctx, nonZeroRunningRequestPods, r)
+}
+
+// selectFromNegativeHeadroomPodsInternal handles the actual selection logic for negative headroom pods
+func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	if len(negHeadroomPods) == 1 {
+		return negHeadroomPods[0].Pod
+	}
+
+	// Apply perfect stickiness (with exploration)
+	candidates, sticky := s.epsilonGreedyAffinityGate(ctx, negHeadroomPods, r, "negative", AffinityGateTau)
+
+	// If perfect stickiness collapsed us to a single pod, short-circuit
+	if sticky && len(candidates) == 1 {
+		return candidates[0].Pod
+	}
+
+	switch s.headroomStrategy {
+	case HeadroomStrategyCompositeMost:
+		return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeMost)
+	case HeadroomStrategyCompositeLeast:
+		return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeLeast)
+	}
+
+	// Build weighted choices for selection
+	weightedChoices := make([]Choice, 0, len(candidates))
+	total := 0
+
+	s.handleNegativeHeadroomPodsHierarchical(ctx, candidates, &weightedChoices, &total, minWeight)
+
+	// Perform weighted random selection
+	return s.performWeightedRandomSelection(weightedChoices, total, candidates, r)
+}
+
+// weightPodsByBlendedDeficit applies blended weighting using TTFT and TPOT deficits.
+// Lower blended deficit => higher weight.
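+// For example (illustrative numbers): with normalized deficits nTTFT=0.25 and
+// nTPOT=0.25 and blend weights alpha=beta=0.5, the blended badness is 0.25 and
+// the weight below works out to int(0.75*80)+minWeight+1 = 62; a pod with
+// blended badness 0.75 gets int(0.25*80)+minWeight+1 = 22, so it is drawn
+// roughly 3x less often.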
+func (s *SLOAwareRouter) weightPodsByBlendedDeficit(
+	ctx context.Context,
+	pods []PodPredictionResult,
+	choices *[]Choice,
+	total *int,
+	minWeight int,
+	alpha, beta float64, // weights for TTFT and TPOT deficits
+	category string,
+) {
+	logger := log.FromContext(ctx)
+	if len(pods) == 0 {
+		return
+	}
+
+	const Wrange = 80
+	const eps = 1e-9
+
+	// Compute raw deficits (only when headroom is negative)
+	type deficits struct {
+		pod     PodPredictionResult
+		ttftDef float64
+		tpotDef float64
+	}
+	defs := make([]deficits, 0, len(pods))
+
+	minTTFT, maxTTFT := math.MaxFloat64, -math.MaxFloat64
+	minTPOT, maxTPOT := math.MaxFloat64, -math.MaxFloat64
+
+	for _, p := range pods {
+		ttftDef := 0.0
+		if p.TTFTHeadroom < 0 {
+			ttftDef = -p.TTFTHeadroom
+		}
+		tpotDef := 0.0
+		if p.Headroom < 0 {
+			tpotDef = -p.Headroom
+		}
+		defs = append(defs, deficits{pod: p, ttftDef: ttftDef, tpotDef: tpotDef})
+
+		if ttftDef < minTTFT {
+			minTTFT = ttftDef
+		}
+		if ttftDef > maxTTFT {
+			maxTTFT = ttftDef
+		}
+		if tpotDef < minTPOT {
+			minTPOT = tpotDef
+		}
+		if tpotDef > maxTPOT {
+			maxTPOT = tpotDef
+		}
+	}
+
+	ttftRange := maxTTFT - minTTFT
+	tpotRange := maxTPOT - minTPOT
+
+	// Normalize alpha/beta
+	if alpha+beta <= 0 {
+		alpha, beta = 1.0, 0.0
+	} else {
+		sum := alpha + beta
+		alpha /= sum
+		beta /= sum
+	}
+
+	logger.V(logutil.DEBUG).Info("Negative headroom blended deficits",
+		"category", category,
+		"minTTFTDef", minTTFT, "maxTTFTDef", maxTTFT,
+		"minTPOTDef", minTPOT, "maxTPOTDef", maxTPOT,
+		"alphaTTFT", alpha, "betaTPOT", beta, "podCount", len(pods))
+
+	for _, d := range defs {
+		// Normalize deficits to [0,1] within this bucket (0 = best / least violation)
+		nTTFT := 0.0
+		if ttftRange > eps {
+			nTTFT = (d.ttftDef - minTTFT) / (ttftRange + eps)
+		}
+		nTPOT := 0.0
+		if tpotRange > eps {
+			nTPOT = (d.tpotDef - minTPOT) / (tpotRange + eps)
+		}
+
+		// Blended "badness": higher = worse violation
+		blended := alpha*nTTFT + beta*nTPOT
+
+		// Convert to selection weight: lower badness -> higher weight.
+		// Ensure a floor so no pod is completely excluded within the bucket.
+		w := int((1.0-blended)*float64(Wrange)) + minWeight + 1
+
+		*choices = append(*choices, Choice{PodName: d.pod.Pod, Weight: w})
+		*total += w
+
+		logger.V(logutil.TRACE).Info("Negative bucket blended weighting",
+			"pod", d.pod.Pod.GetPod().String(),
+			"ttftDef", d.ttftDef, "tpotDef", d.tpotDef,
+			"normTTFT", nTTFT, "normTPOT", nTPOT,
+			"blendedBadness", blended, "weight", w)
+	}
+}
+
+func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical(
+	ctx context.Context,
+	negHeadroomPods []PodPredictionResult,
+	choices *[]Choice,
+	total *int,
+	minWeightForNegative int,
+) {
+	logger := log.FromContext(ctx)
+
+	// Categorize pods by their headroom status
+	var negTTFTNegTPOT, negTTFTNonNegTPOT, nonNegTTFTNegTPOT, nonNegTTFTNonNegTPOT []PodPredictionResult
+
+	for _, p := range negHeadroomPods {
+		if p.TTFTHeadroom < 0 && p.Headroom < 0 {
+			negTTFTNegTPOT = append(negTTFTNegTPOT, p)
+		} else if p.TTFTHeadroom < 0 && p.Headroom >= 0 {
+			negTTFTNonNegTPOT = append(negTTFTNonNegTPOT, p)
+		} else if p.TTFTHeadroom >= 0 && p.Headroom < 0 {
+			nonNegTTFTNegTPOT = append(nonNegTTFTNegTPOT, p)
+		} else {
+			nonNegTTFTNonNegTPOT = append(nonNegTTFTNonNegTPOT, p)
+		}
+	}
+
+	logger.V(logutil.DEBUG).Info("Hierarchical negative headroom pod distribution",
+		"totalNegative", len(negHeadroomPods),
+		"negTTFT_negTPOT", len(negTTFTNegTPOT),
+		"negTTFT_nonNegTPOT", len(negTTFTNonNegTPOT),
+		"nonNegTTFT_negTPOT", len(nonNegTTFTNegTPOT),
+		"nonNegTTFT_nonNegTPOT", len(nonNegTTFTNonNegTPOT))
+
+	// Priority 1: both TTFT and TPOT negative -> blended deficits (both active)
+	if len(negTTFTNegTPOT) > 0 {
+		s.weightPodsByBlendedDeficit(ctx, negTTFTNegTPOT, choices, total, minWeightForNegative,
+			NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "both_negative")
+	}
+
+	// Priority 2: TTFT negative, TPOT non-negative -> blended still works (TPOT deficit=0)
+	if len(negTTFTNonNegTPOT) > 0 {
+		s.weightPodsByBlendedDeficit(ctx, negTTFTNonNegTPOT, choices, total, minWeightForNegative,
+			NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "ttft_negative")
+	}
+
+	// Priority 3: TTFT non-negative, TPOT negative -> blended (TTFT deficit=0)
+	if len(nonNegTTFTNegTPOT) > 0 {
+		s.weightPodsByBlendedDeficit(ctx, nonNegTTFTNegTPOT, choices, total, minWeightForNegative,
+			NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "tpot_negative")
+	}
+
+	// Priority 4: edge-case bucket -> minimal weight
+	for _, p := range nonNegTTFTNonNegTPOT {
+		*choices = append(*choices, Choice{PodName: p.Pod, Weight: minWeightForNegative})
+		*total += minWeightForNegative
+	}
+}
+
+func (s *SLOAwareRouter) getPodMinTPOTSLO(pod schedulingtypes.Pod) float64 {
+	podName := types.NamespacedName{
+		Name:      pod.GetPod().NamespacedName.Name,
+		Namespace: pod.GetPod().NamespacedName.Namespace,
+	}
+	if runningReqs, ok := s.runningRequestLists[podName]; ok && runningReqs.GetSize() > 0 {
+		if topReq := runningReqs.Peek(); topReq != nil {
+			return topReq.TPOT
+		}
+	}
+	return 0 // no running requests or no TPOT SLOs
+}
+
+func (s *SLOAwareRouter) getPodRunningRequestCount(pod schedulingtypes.Pod) int {
+	podName := types.NamespacedName{
+		Name:      pod.GetPod().NamespacedName.Name,
+		Namespace: pod.GetPod().NamespacedName.Namespace,
+	}
+	if runningReqs, ok := s.runningRequestLists[podName]; ok {
+		return runningReqs.GetSize()
+	}
+	return 0 // no running requests
+}
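Editor's note (not part of the patch): performWeightedRandomSelection is defined elsewhere in the package and not shown in this diff. A minimal sketch of the draw that the []Choice/total bookkeeping above sets up, assuming the standard cumulative-weight technique (pickWeighted is a hypothetical stand-in name):

// Draws a pod with probability proportional to its Weight.
func pickWeighted(choices []Choice, total int, r *rand.Rand) schedulingtypes.Pod {
	n := r.Intn(total) // uniform draw in [0, total)
	for _, c := range choices {
		n -= c.Weight
		if n < 0 {
			return c.PodName // first pod whose cumulative weight covers the draw
		}
	}
	return choices[len(choices)-1].PodName // defensive fallback; unreachable if total == sum of weights
}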
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go
new file mode 100644
index 000000000..8030866d8
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go
@@ -0,0 +1,57 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
+package slo_aware_router
+
+import schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+
+type HeadroomStrategy string
+
+type Choice struct {
+	PodName schedulingtypes.Pod
+	Weight  int
+}
+
+const (
+	// HeadroomStrategyLeast prioritizes pods with the least positive headroom (better packing)
+	HeadroomStrategyLeast HeadroomStrategy = "least"
+	// HeadroomStrategyMost prioritizes pods with the most positive headroom (more conservative)
+	HeadroomStrategyMost HeadroomStrategy = "most"
+
+	HeadroomStrategyCompositeLeast HeadroomStrategy = "composite-least"
+	HeadroomStrategyCompositeMost  HeadroomStrategy = "composite-most"
+	HeadroomStrategyCompositeOnly  HeadroomStrategy = "composite-only"
+
+	// TTFT SLO header key
+	TTFTSLOHeaderKey = "x-slo-ttft-ms"
+	// TPOT SLO header key
+	TPOTSLOHeaderKey = "x-slo-tpot-ms"
+)
+
+const (
+	SLOAwareRouterPluginType = "slo-aware-routing"
+	eps                      = 1e-9
+	Wmax                     = 100
+	minWeight                = 1
+)
+
+type PodSelectionMode string
+
+const (
+	PodSelectionLinear PodSelectionMode = "linear" // weighted-random (current behavior)
+	PodSelectionMax    PodSelectionMode = "max"    // pick argmax weight
+)
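Editor's note (not part of the patch): a sketch of how a caller might map a flag or config string onto these strategy constants; strategyFromString is a hypothetical helper, and falling back to "least" mirrors the default branch in the selection code above.

func strategyFromString(v string) HeadroomStrategy {
	switch s := HeadroomStrategy(v); s {
	case HeadroomStrategyLeast, HeadroomStrategyMost,
		HeadroomStrategyCompositeLeast, HeadroomStrategyCompositeMost, HeadroomStrategyCompositeOnly:
		return s
	default:
		return HeadroomStrategyLeast // unrecognized values fall back to the packing-friendly default
	}
}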
diff --git a/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go b/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go
new file mode 100644
index 000000000..900335c9e
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go
@@ -0,0 +1,154 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package profile
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"strconv"
+
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+const (
+	SLOAwareProfileHandlerType = "slo-aware-profile-handler"
+	DefaultProfileName         = "default"
+	PrefixProfileName          = "prefix"
+	SLOProfileName             = "slo"
+
+	// Boolean header key controlling whether to use predictor-based scheduling
+	PredictionBasedSchedulingHeaderKey = "x-prediction-based-scheduling"
+)
+
+// compile-time type assertion
+var _ framework.ProfileHandler = &SLOAwareProfileHandler{}
+
+// SLOAwareProfileHandlerFactory defines the factory function for SLOAwareProfileHandler.
+func SLOAwareProfileHandlerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+	return NewSLOAwareProfileHandler().WithName(name), nil
+}
+
+// NewSLOAwareProfileHandler initializes a new SLOAwareProfileHandler and returns its pointer.
+func NewSLOAwareProfileHandler() *SLOAwareProfileHandler {
+	return &SLOAwareProfileHandler{
+		typedName: plugins.TypedName{Type: SLOAwareProfileHandlerType, Name: SLOAwareProfileHandlerType},
+	}
+}
+
+// SLOAwareProfileHandler handles two profiles: the default profile and the SLO profile.
+// When the request has PredictorBasedScheduling=true, it uses the SLO profile result to select
+// the destination pod. Otherwise, it uses the default profile result.
+type SLOAwareProfileHandler struct {
+	typedName     plugins.TypedName
+	prefixProfile string // the profile that should be executed first
+}
+
+// TypedName returns the type and name tuple of this plugin instance.
+func (h *SLOAwareProfileHandler) TypedName() plugins.TypedName {
+	return h.typedName
+}
+
+// WithName sets the name of the profile handler.
+func (h *SLOAwareProfileHandler) WithName(name string) *SLOAwareProfileHandler {
+	h.typedName.Name = name
+	return h
+}
+
+// Pick selects the SchedulingProfiles to run from the list of candidate profiles, taking into
+// consideration the request properties and the previously executed cycles along with their results.
+func (h *SLOAwareProfileHandler) Pick(_ context.Context, _ *types.CycleState, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile,
+	profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile {
+	if len(profiles) == len(profileResults) { // all profiles have been executed already in a previous call
+		return map[string]*framework.SchedulerProfile{}
+	}
+
+	if _, executed := profileResults[PrefixProfileName]; !executed {
+		// if the prefix profile was not executed yet, let the scheduler run it first
+		return map[string]*framework.SchedulerProfile{
+			PrefixProfileName: profiles[PrefixProfileName],
+		}
+	}
+	// otherwise, prefix was already executed.
+
+	// return all profiles except prefix.
+	profilesToRun := make(map[string]*framework.SchedulerProfile)
+	for name, profile := range profiles {
+		if name != PrefixProfileName {
+			profilesToRun[name] = profile
+		}
+	}
+	return profilesToRun
+}
+
+// ProcessResults handles the outcome of the profile runs after all profiles ran.
+// It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the
+// SchedulingResult the key of the primary profile that should be used to get the request's
+// selected destination.
+// When a profile run fails, its result in the profileResults map is nil.
+func (h *SLOAwareProfileHandler) ProcessResults(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
+	if len(profileResults) < 2 {
+		return nil, errors.New("SLOAwareProfileHandler requires at least two profiles to operate")
+	}
+
+	predictorBasedScheduling, err := parseBoolHeader(*request, PredictionBasedSchedulingHeaderKey)
+	if err != nil {
+		return nil, fmt.Errorf("failed to choose a scheduling profile: the %s header must be a bool: %v", PredictionBasedSchedulingHeaderKey, err)
+	}
+
+	if predictorBasedScheduling { // TODO grab header directly from request.Headers instead of request field
+		if profileResults[SLOProfileName] == nil { // there was an error while running the SLO profile
+			return nil, fmt.Errorf("failed to run scheduler profile '%s'", SLOProfileName)
+		}
+		return &types.SchedulingResult{
+			ProfileResults:     profileResults,
+			PrimaryProfileName: SLOProfileName,
+		}, nil
+	}
+
+	if profileResults[DefaultProfileName] == nil { // there was an error while running the default profile
+		return nil, fmt.Errorf("failed to run scheduler profile '%s'", DefaultProfileName)
+	}
+
+	return &types.SchedulingResult{
+		ProfileResults:     profileResults,
+		PrimaryProfileName: DefaultProfileName,
+	}, nil
+}
+
+// parseBoolHeader retrieves a header by name, parses it as a bool, and returns the value,
+// or an error if the header is invalid. A missing header is not an error and parses as false.
+func parseBoolHeader(request types.LLMRequest, headerName string) (bool, error) {
+	// 1. Get the header value from the map
+	headerValue, ok := request.Headers[headerName]
+	if !ok {
+		return false, nil // header not found, default to false
+	}
+
+	// 2. Parse the header value as a bool
+	parsedBool, err := strconv.ParseBool(headerValue)
+	if err != nil {
+		return false, fmt.Errorf("header %q must be a bool: %v", headerName, err)
+	}
+
+	// 3. Return the successfully parsed value
+	return parsedBool, nil
+}
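Editor's note (not part of the patch): a sketch of a client opting a request into predictor-based scheduling via the headers introduced in this PR. The endpoint URL and SLO values are illustrative; the millisecond units follow the header key names.

package main

import (
	"bytes"
	"net/http"
)

func main() {
	body := bytes.NewBufferString(`{"prompt": "hello"}`)
	req, _ := http.NewRequest(http.MethodPost, "http://gateway.example/v1/completions", body)
	req.Header.Set("x-prediction-based-scheduling", "true") // PredictionBasedSchedulingHeaderKey above; routes via the "slo" profile
	req.Header.Set("x-slo-ttft-ms", "1000")                 // TTFTSLOHeaderKey from types.go
	req.Header.Set("x-slo-tpot-ms", "50")                   // TPOTSLOHeaderKey from types.go
	_, _ = http.DefaultClient.Do(req)
}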
diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go
index c3037175e..014c51d82 100644
--- a/pkg/epp/server/runserver.go
+++ b/pkg/epp/server/runserver.go
@@ -43,6 +43,7 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector"
+	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
 )
 
 // ExtProcServerRunner provides methods to manage an external process server.
@@ -59,6 +60,7 @@ type ExtProcServerRunner struct {
 	Director                   *requestcontrol.Director
 	SaturationDetector         *saturationdetector.Detector
 	UseExperimentalDatalayerV2 bool // Pluggable data layer feature flag
+	LatencyPredictor           latencypredictor.PredictorInterface
 
 	// This should only be used in tests. We won't need this once we do not inject metrics in the tests.
 	// TODO:(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/432) Cleanup
@@ -78,6 +80,7 @@ const (
 	DefaultHealthChecking               = false                        // default for --health-checking
 	DefaultEnablePprof                  = true                         // default for --enable-pprof
 	DefaultTotalQueuedRequestsMetric    = "vllm:num_requests_waiting"  // default for --total-queued-requests-metric
+	DefaultTotalRunningRequestsMetric   = "vllm:num_requests_running"  // default for --total-running-requests-metric
 	DefaultKvCacheUsagePercentageMetric = "vllm:gpu_cache_usage_perc"  // default for --kv-cache-usage-percentage-metric
 	DefaultLoraInfoMetric               = "vllm:lora_requests_info"    // default for --lora-info-metric
 	DefaultCacheInfoMetric              = "vllm:cache_config_info"     // default for --cache-info-metric
diff --git a/sidecars/latencypredictorasync/types.go b/sidecars/latencypredictorasync/types.go
index 4b4a1ca0b..c8eadefe2 100644
--- a/sidecars/latencypredictorasync/types.go
+++ b/sidecars/latencypredictorasync/types.go
@@ -120,7 +120,6 @@ type PredictorInterface interface {
 	PredictBulk(ctx context.Context, requests []PredictionRequest) (*BulkPredictionResponse, error)
 	PredictBulkStrict(ctx context.Context, requests []PredictionRequest) (*BulkPredictionResponse, error)
 	AddTrainingDataBulk(entry []TrainingEntry) error
-	GetServerStatus(ctx context.Context) (*ServerStatusResponse, error)
 }
 
 // --- Data Models ---
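Editor's note (not part of the patch): dropping GetServerStatus from PredictorInterface shrinks what a test double has to implement. A sketch of a minimal stub covering only the methods visible in this hunk (the interface begins at line 120, so it may declare further methods above that a real stub would also need):

type stubPredictor struct{}

func (stubPredictor) PredictBulk(ctx context.Context, requests []latencypredictor.PredictionRequest) (*latencypredictor.BulkPredictionResponse, error) {
	return &latencypredictor.BulkPredictionResponse{}, nil // canned empty response
}

func (stubPredictor) PredictBulkStrict(ctx context.Context, requests []latencypredictor.PredictionRequest) (*latencypredictor.BulkPredictionResponse, error) {
	return &latencypredictor.BulkPredictionResponse{}, nil
}

func (stubPredictor) AddTrainingDataBulk(entry []latencypredictor.TrainingEntry) error {
	return nil // training data is discarded in tests
}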