From 61b0996ed278fffd270d06fb5c2e329422e43390 Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Thu, 6 Nov 2025 22:23:02 +0000 Subject: [PATCH 1/6] Add latency predictor plugins, deployment, and runner.go integration --- cmd/epp/runner/runner.go | 69 ++- .../manifests/inferencepool-resources-lp.yaml | 450 ++++++++++++++++++ go.mod | 2 +- pkg/epp/backend/metrics/metrics.go | 9 + pkg/epp/backend/metrics/metrics_spec.go | 24 +- pkg/epp/datalayer/metrics/extractor.go | 13 +- pkg/epp/datalayer/metrics/mapping.go | 24 +- pkg/epp/datastore/datastore.go | 2 +- pkg/epp/metrics/metrics.go | 366 ++++++++++++++ pkg/epp/metrics/metrics_test.go | 2 + .../testdata/request_tpot_seconds_metric | 80 ++++ .../testdata/request_ttft_seconds_metric | 116 +++++ .../plugins/multi/slo_aware_router/config.go | 169 +++++++ .../plugins/multi/slo_aware_router/headers.go | 66 +++ .../plugins/multi/slo_aware_router/helpers.go | 145 ++++++ .../latencypredictor_helper.go | 449 +++++++++++++++++ .../multi/slo_aware_router/prediction.go | 122 +++++ .../slo_aware_router/requestcontrol_hooks.go | 246 ++++++++++ .../slo_aware_router/running_request_queue.go | 227 +++++++++ .../running_request_queue_test.go | 391 +++++++++++++++ .../plugins/multi/slo_aware_router/sampler.go | 122 +++++ .../plugins/multi/slo_aware_router/scorer.go | 290 +++++++++++ .../multi/slo_aware_router/selection.go | 381 +++++++++++++++ .../plugins/multi/slo_aware_router/types.go | 53 +++ .../profile/slo_aware_profile_handler.go | 154 ++++++ pkg/epp/server/runserver.go | 3 + sidecars/latencypredictorasync/types.go | 1 - 27 files changed, 3949 insertions(+), 27 deletions(-) create mode 100644 config/manifests/inferencepool-resources-lp.yaml create mode 100644 pkg/epp/metrics/testdata/request_tpot_seconds_metric create mode 100644 pkg/epp/metrics/testdata/request_ttft_seconds_metric create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go create mode 100644 pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index da81fdf46..79d90f24c 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -19,6 +19,7 @@ package runner import ( "context" "crypto/tls" + "encoding/json" "errors" "flag" "fmt" @@ -61,6 +62,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/scorer" @@ -68,6 +70,7 @@ import ( runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" "sigs.k8s.io/gateway-api-inference-extension/version" ) @@ -108,6 +111,7 @@ var ( "then a self-signed certificate is used.") // metric flags totalQueuedRequestsMetric = flag.String("total-queued-requests-metric", runserver.DefaultTotalQueuedRequestsMetric, "Prometheus metric for the number of queued requests.") + totalRunningRequestsMetric = flag.String("total-running-requests-metric", runserver.DefaultTotalRunningRequestsMetric, "Prometheus metric for the number of running requests.") kvCacheUsagePercentageMetric = flag.String("kv-cache-usage-percentage-metric", runserver.DefaultKvCacheUsagePercentageMetric, "Prometheus metric for the fraction of KV-cache blocks currently in use (from 0 to 1).") // LoRA metrics loraInfoMetric = flag.String("lora-info-metric", runserver.DefaultLoraInfoMetric, "Prometheus metric for the LoRA info metrics (must be in vLLM label format).") @@ -127,7 +131,10 @@ var ( modelServerMetricsScheme = flag.String("model-server-metrics-scheme", "http", "Scheme to scrape metrics from pods") modelServerMetricsHttpsInsecureSkipVerify = flag.Bool("model-server-metrics-https-insecure-skip-verify", true, "When using 'https' scheme for 'model-server-metrics-scheme', configure 'InsecureSkipVerify' (default to true)") haEnableLeaderElection = flag.Bool("ha-enable-leader-election", false, "Enables leader election for high availability. When enabled, readiness probes will only pass on the leader.") - tracing = flag.Bool("tracing", true, "Enables emitting traces") + + // Latency Predictor Flag + enableLatencyPredictor = flag.Bool("enable-latency-predictor", false, "Enable the regression-based latency predictor and scheduler scorer.") + tracing = flag.Bool("tracing", true, "Enables emitting traces") setupLog = ctrl.Log.WithName("setup") ) @@ -297,9 +304,29 @@ func (r *Runner) Run(ctx context.Context) error { runtime.SetBlockProfileRate(1) } - err = r.parsePluginsConfiguration(ctx, datastore) + // =================================================================== + // == Latency Predictor Integration + // =================================================================== + var predictor latencypredictor.PredictorInterface // Use the interface type + if *enableLatencyPredictor { + setupLog.Info("Latency predictor is enabled. 
Initializing...") + predictor = latencypredictor.New(latencypredictor.ConfigFromEnv(), ctrl.Log.WithName("latency-predictor")) + + // For the runnable, you'll need to type assert back to the concrete type + concretePredictor := predictor.(*latencypredictor.Predictor) + if err := mgr.Add(runnable.NoLeaderElection(&predictorRunnable{predictor: concretePredictor})); err != nil { + setupLog.Error(err, "Failed to register latency predictor runnable") + return err + } + } else { + setupLog.Info("Latency predictor is disabled.") + predictor = nil // This will be a true nil interface + } + // =================================================================== + + err = r.parsePluginsConfiguration(ctx, predictor, datastore) if err != nil { - setupLog.Error(err, "Failed to parse plugins configuration") + setupLog.Error(err, "Failed to parse the configuration") return err } @@ -368,6 +395,7 @@ func (r *Runner) Run(ctx context.Context) error { Director: director, SaturationDetector: saturationDetector, UseExperimentalDatalayerV2: useDatalayerV2, // pluggable data layer feature flag + LatencyPredictor: predictor, } if err := serverRunner.SetupWithManager(ctx, mgr); err != nil { setupLog.Error(err, "Failed to setup EPP controllers") @@ -410,7 +438,14 @@ func (r *Runner) registerInTreePlugins() { plugins.Register(testfilter.HeaderBasedTestingFilterType, testfilter.HeaderBasedTestingFilterFactory) } -func (r *Runner) parsePluginsConfiguration(ctx context.Context, ds datastore.Datastore) error { +func (r *Runner) registerLatencyPredictorPlugins(predictor latencypredictor.PredictorInterface) { + plugins.Register(slo_aware_router.SLOAwareRouterPluginType, func(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + return slo_aware_router.NewSLOAwareRouter(predictor, slo_aware_router.HeadroomSelectionStrategy).WithName(name), nil + }) + plugins.Register(profile.SLOAwareProfileHandlerType, profile.SLOAwareProfileHandlerFactory) +} + +func (r *Runner) parsePluginsConfiguration(ctx context.Context, predictor latencypredictor.PredictorInterface, ds datastore.Datastore) error { if *configText == "" && *configFile == "" { return nil // configuring through code, not through file } @@ -429,6 +464,12 @@ func (r *Runner) parsePluginsConfiguration(ctx context.Context, ds datastore.Dat } r.registerInTreePlugins() + // If we have a latency predictor enabled and predictor and datastore are not nil, + // register the latency predictor plugins (currently just the SLO scorer). 
+ if *enableLatencyPredictor && predictor != nil { + setupLog.Info("Registering latency predictor plugins") + r.registerLatencyPredictorPlugins(predictor) + } handle := plugins.NewEppHandle(ctx, ds.PodList) config, err := loader.LoadConfig(configBytes, handle, logger) @@ -459,6 +500,7 @@ func (r *Runner) setupMetricsCollection(setupLog logr.Logger, useExperimentalDat func setupMetricsV1(setupLog logr.Logger) (datalayer.EndpointFactory, error) { mapping, err := backendmetrics.NewMetricMapping( *totalQueuedRequestsMetric, + *totalRunningRequestsMetric, *kvCacheUsagePercentageMetric, *loraInfoMetric, *cacheInfoMetric, @@ -502,6 +544,7 @@ func setupDatalayer() (datalayer.EndpointFactory, error) { *modelServerMetricsHttpsInsecureSkipVerify, nil) extractor, err := dlmetrics.NewExtractor(*totalQueuedRequestsMetric, + *totalRunningRequestsMetric, *kvCacheUsagePercentageMetric, *loraInfoMetric, *cacheInfoMetric) @@ -613,3 +656,21 @@ func setupPprofHandlers(mgr ctrl.Manager) error { } return nil } + +// =================================================================== +// == Latency Predictor Plugin and Helpers +// =================================================================== + +// predictorRunnable implements controller-runtime's Runnable interface to manage the predictor's lifecycle. +type predictorRunnable struct { + predictor *latencypredictor.Predictor +} + +func (p *predictorRunnable) Start(ctx context.Context) error { + setupLog.Info("Starting latency predictor...") + p.predictor.Start(ctx) + <-ctx.Done() + setupLog.Info("Stopping latency predictor...") + p.predictor.Stop() + return nil +} diff --git a/config/manifests/inferencepool-resources-lp.yaml b/config/manifests/inferencepool-resources-lp.yaml new file mode 100644 index 000000000..c2a3528de --- /dev/null +++ b/config/manifests/inferencepool-resources-lp.yaml @@ -0,0 +1,450 @@ +# Note: If you change this file, please also change the file used for e2e tests! 
+#
+# https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/test/testdata/inferencepool-e2e.yaml
+
+# --- ConfigMaps ---
+apiVersion: v1
+kind: ConfigMap
+metadata:
+  name: latency-predictor-config
+  namespace: default
+data:
+  LATENCY_RETRAINING_INTERVAL_SEC: "1"
+  LATENCY_MIN_SAMPLES_FOR_RETRAIN: "100"
+  LATENCY_TTFT_MODEL_PATH: "/models/ttft.joblib"
+  LATENCY_TPOT_MODEL_PATH: "/models/tpot.joblib"
+  LATENCY_TTFT_SCALER_PATH: "/models/ttft_scaler.joblib"
+  LATENCY_TPOT_SCALER_PATH: "/models/tpot_scaler.joblib"
+  LATENCY_MODEL_TYPE: "xgboost"
+  LATENCY_MAX_TRAINING_DATA_SIZE_PER_BUCKET: "5000"
+---
+apiVersion: v1
+kind: ConfigMap
+metadata:
+  name: prediction-server-config
+  namespace: default
+data:
+  LATENCY_MODEL_TYPE: "xgboost"
+  PREDICT_HOST: "0.0.0.0"
+  LOCAL_TTFT_MODEL_PATH: "/server_models/ttft.joblib" # Use individual storage
+  LOCAL_TPOT_MODEL_PATH: "/server_models/tpot.joblib"
+  LOCAL_TTFT_SCALER_PATH: "/server_models/ttft_scaler.joblib"
+  LOCAL_TPOT_SCALER_PATH: "/server_models/tpot_scaler.joblib"
+---
+# --- InferencePool ---
+apiVersion: inference.networking.x-k8s.io/v1alpha2
+kind: InferencePool
+metadata:
+  name: vllm-llama3-8b-instruct
+spec:
+  targetPortNumber: 8000
+  selector:
+    app: vllm-llama3-8b-instruct
+  extensionRef:
+    name: vllm-llama3-8b-instruct-epp
+---
+# --- EPP Service ---
+apiVersion: v1
+kind: Service
+metadata:
+  name: vllm-llama3-8b-instruct-epp
+  namespace: default
+spec:
+  selector:
+    app: vllm-llama3-8b-instruct-epp
+  ports:
+  - name: epp-grpc
+    protocol: TCP
+    port: 9002
+    targetPort: 9002
+    appProtocol: http2
+  - name: latency-predictor-training
+    protocol: TCP
+    port: 8000
+    targetPort: 8000
+  - name: latency-predictor-1
+    protocol: TCP
+    port: 8001
+    targetPort: 8001
+  - name: latency-predictor-2
+    protocol: TCP
+    port: 8002
+    targetPort: 8002
+  - name: latency-predictor-3
+    protocol: TCP
+    port: 8003
+    targetPort: 8003
+  - name: prometheus
+    protocol: TCP
+    port: 9090
+    targetPort: 9090
+  type: LoadBalancer
+---
+apiVersion: v1
+kind: ServiceAccount
+metadata:
+  name: vllm-llama3-8b-instruct-epp
+  namespace: default
+---
+# --- EPP Deployment with Individual Container Volumes ---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: vllm-llama3-8b-instruct-epp
+  namespace: default
+  labels:
+    app: vllm-llama3-8b-instruct-epp
+spec:
+  replicas: 1 # Increase to run multiple EPP pods for scaling
+  selector:
+    matchLabels:
+      app: vllm-llama3-8b-instruct-epp
+  template:
+    metadata:
+      labels:
+        app: vllm-llama3-8b-instruct-epp
+    spec:
+      serviceAccountName: vllm-llama3-8b-instruct-epp
+      # Conservatively, this timeout should mirror the longest grace period of the pods within the pool
+      terminationGracePeriodSeconds: 130
+      containers:
+      # EPP Container
+      - name: epp
+        image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/slo-routing-epp-exp
+        imagePullPolicy: Always
+        args:
+        - -pool-name
+        - "vllm-llama3-8b-instruct"
+        - "-pool-namespace"
+        - "default"
+        - --pool-group
+        - "inference.networking.x-k8s.io"
+        - -v
+        - "4"
+        - --zap-encoder
+        - "json"
+        - -grpc-port
+        - "9002"
+        - -grpc-health-port
+        - "9003"
+        - "--config-file"
+        - "/config/default-plugins.yaml"
+        - "-enable-latency-predictor"
+        env:
+        - name: PREDICTION_SERVER_URL
+          value: "http://localhost:8001,http://localhost:8002,http://localhost:8003" # Multiple prediction servers
+        - name: TRAINING_SERVER_URL
+          value: "http://localhost:8000" # Single training server for sending training data
+        - name: LATENCY_MAX_SAMPLE_SIZE
+          value: "10000" # Maximum sample size for
latency prediction + - name: NEG_HEADROOM_TPOT_WEIGHT + value: "0.2" # Weight for TPOT in negative headroom calculation + - name: NEG_HEADROOM_TTFT_WEIGHT + value: "0.8" # Weight for TTFT in negative headroom calculation + ports: + - containerPort: 9002 + - containerPort: 9003 + - name: metrics + containerPort: 9090 + livenessProbe: + grpc: + port: 9003 + service: inference-extension + initialDelaySeconds: 5 + periodSeconds: 10 + readinessProbe: + grpc: + port: 9003 + service: inference-extension + initialDelaySeconds: 5 + periodSeconds: 10 + volumeMounts: + - name: plugins-config-volume + mountPath: "/config" + # Training Server Sidecar Container + - name: training-server + image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_training:latest + imagePullPolicy: Always + ports: + - containerPort: 8000 + name: training-port + livenessProbe: + httpGet: + path: /healthz + port: 8000 + initialDelaySeconds: 30 + periodSeconds: 20 + readinessProbe: + httpGet: + path: /readyz + port: 8000 + initialDelaySeconds: 45 + periodSeconds: 10 + resources: + requests: + cpu: "2000m" + memory: "4Gi" + limits: + cpu: "4000m" + memory: "8Gi" + envFrom: + - configMapRef: + name: latency-predictor-config + env: + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "training" + volumeMounts: + - name: training-server-storage + mountPath: /models + # Prediction Server Sidecar Container 1 + - name: prediction-server-1 + image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest + imagePullPolicy: Always + command: ["uvicorn"] + args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8001"] + ports: + - containerPort: 8001 + name: predict-port-1 + livenessProbe: + httpGet: + path: /healthz + port: 8001 + initialDelaySeconds: 15 + periodSeconds: 15 + readinessProbe: + httpGet: + path: /readyz + port: 8001 + initialDelaySeconds: 10 + periodSeconds: 5 + failureThreshold: 10 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + envFrom: + - configMapRef: + name: prediction-server-config + env: + - name: PREDICT_PORT + value: "8001" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "prediction-1" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + volumeMounts: + - name: prediction-server-1-storage + mountPath: /server_models + # Prediction Server Sidecar Container 2 + - name: prediction-server-2 + image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest + imagePullPolicy: Always + command: ["uvicorn"] + args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8002"] + ports: + - containerPort: 8002 + name: predict-port-2 + livenessProbe: + httpGet: + path: /healthz + port: 8002 + initialDelaySeconds: 15 + periodSeconds: 15 + readinessProbe: + httpGet: + path: /readyz + port: 8002 + initialDelaySeconds: 10 + periodSeconds: 5 + failureThreshold: 10 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + envFrom: + - configMapRef: + name: prediction-server-config + env: + - name: PREDICT_PORT + value: "8002" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "prediction-2" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + volumeMounts: + - name: prediction-server-2-storage + mountPath: /server_models + # Prediction Server Sidecar Container 3 + - name: 
prediction-server-3 + image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest + imagePullPolicy: Always + command: ["uvicorn"] + args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8003"] + ports: + - containerPort: 8003 + name: predict-port-3 + livenessProbe: + httpGet: + path: /healthz + port: 8003 + initialDelaySeconds: 15 + periodSeconds: 15 + readinessProbe: + httpGet: + path: /readyz + port: 8003 + initialDelaySeconds: 10 + periodSeconds: 5 + failureThreshold: 10 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + envFrom: + - configMapRef: + name: prediction-server-config + env: + - name: PREDICT_PORT + value: "8003" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "prediction-3" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + volumeMounts: + - name: prediction-server-3-storage + mountPath: /server_models + volumes: + - name: training-server-storage + emptyDir: + sizeLimit: "20Gi" # Dedicated volume for training server + - name: prediction-server-1-storage + emptyDir: + sizeLimit: "10Gi" # Dedicated volume for prediction server 1 + - name: prediction-server-2-storage + emptyDir: + sizeLimit: "10Gi" # Dedicated volume for prediction server 2 + - name: prediction-server-3-storage + emptyDir: + sizeLimit: "10Gi" # Dedicated volume for prediction server 3 + - name: plugins-config-volume + configMap: + name: plugins-config +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: plugins-config + namespace: default +data: + default-plugins.yaml: | + apiVersion: inference.networking.x-k8s.io/v1alpha1 + kind: EndpointPickerConfig + plugins: + - type: queue-scorer + - type: kv-cache-utilization-scorer + - type: prefix-cache-scorer + - type: slo-aware-routing + - type: slo-aware-profile-handler + - type: max-score-picker + schedulingProfiles: + - name: prefix + plugins: + - pluginRef: prefix-cache-scorer + - name: default + plugins: + - pluginRef: slo-aware-routing + weight: 0 + - pluginRef: queue-scorer + - pluginRef: kv-cache-utilization-scorer + - pluginRef: max-score-picker + - name: slo + plugins: + - pluginRef: slo-aware-routing + - pluginRef: max-score-picker +--- +# --- RBAC --- +kind: Role +apiVersion: rbac.authorization.k8s.io/v1 +metadata: + name: pod-read + namespace: default +rules: +- apiGroups: [ "inference.networking.x-k8s.io" ] + resources: [ "inferenceobjectives", "inferencepools" ] + verbs: [ "get", "watch", "list" ] +- apiGroups: [ "inference.networking.k8s.io" ] + resources: [ "inferencepools" ] + verbs: [ "get", "watch", "list" ] +- apiGroups: [ "" ] + resources: [ "pods" ] + verbs: [ "get", "watch", "list" ] +--- +kind: RoleBinding +apiVersion: rbac.authorization.k8s.io/v1 +metadata: + name: pod-read-binding + namespace: default +subjects: +- kind: ServiceAccount + name: vllm-llama3-8b-instruct-epp + namespace: default +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: pod-read +--- +kind: ClusterRole +apiVersion: rbac.authorization.k8s.io/v1 +metadata: + name: auth-reviewer +rules: +- apiGroups: + - authentication.k8s.io + resources: + - tokenreviews + verbs: + - create +- apiGroups: + - authorization.k8s.io + resources: + - subjectaccessreviews + verbs: + - create +--- +kind: ClusterRoleBinding +apiVersion: rbac.authorization.k8s.io/v1 +metadata: + name: auth-reviewer-binding +subjects: +- kind: ServiceAccount + name: vllm-llama3-8b-instruct-epp + namespace: default +roleRef: + apiGroup: 
rbac.authorization.k8s.io + kind: ClusterRole + name: auth-reviewer \ No newline at end of file diff --git a/go.mod b/go.mod index ab54e64f3..c32d4b5f8 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/elastic/crd-ref-docs v0.2.0 github.com/envoyproxy/go-control-plane/envoy v1.35.0 github.com/go-logr/logr v1.4.3 + github.com/go-logr/zapr v1.3.0 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/hashicorp/golang-lru/v2 v2.0.7 @@ -61,7 +62,6 @@ require ( github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/go-logr/zapr v1.3.0 // indirect github.com/go-openapi/jsonpointer v0.21.2 // indirect github.com/go-openapi/jsonreference v0.21.0 // indirect github.com/go-openapi/swag v0.23.1 // indirect diff --git a/pkg/epp/backend/metrics/metrics.go b/pkg/epp/backend/metrics/metrics.go index 5d2a85e96..d2d9b60f9 100644 --- a/pkg/epp/backend/metrics/metrics.go +++ b/pkg/epp/backend/metrics/metrics.go @@ -97,6 +97,15 @@ func (p *PodMetricsClientImpl) promToPodMetrics( } } + if p.MetricMapping.TotalRunningRequests != nil { + running, err := p.getMetric(metricFamilies, *p.MetricMapping.TotalRunningRequests) + if err == nil { + updated.RunningQueueSize = int(running.GetGauge().GetValue()) + } else { + errs = multierr.Append(errs, err) + } + } + if p.MetricMapping.KVCacheUtilization != nil { usage, err := p.getMetric(metricFamilies, *p.MetricMapping.KVCacheUtilization) if err == nil { diff --git a/pkg/epp/backend/metrics/metrics_spec.go b/pkg/epp/backend/metrics/metrics_spec.go index 7407f4ed7..b3c26db2c 100644 --- a/pkg/epp/backend/metrics/metrics_spec.go +++ b/pkg/epp/backend/metrics/metrics_spec.go @@ -29,10 +29,11 @@ type MetricSpec struct { // MetricMapping holds named MetricSpecs. type MetricMapping struct { - TotalQueuedRequests *MetricSpec - KVCacheUtilization *MetricSpec - LoraRequestInfo *MetricSpec - CacheConfigInfo *MetricSpec + TotalQueuedRequests *MetricSpec + TotalRunningRequests *MetricSpec + KVCacheUtilization *MetricSpec + LoraRequestInfo *MetricSpec + CacheConfigInfo *MetricSpec } // stringToMetricSpec converts a string to a MetricSpec. @@ -94,11 +95,15 @@ func stringToMetricSpec(specStr string) (*MetricSpec, error) { } // NewMetricMapping creates a MetricMapping from string values. 
-func NewMetricMapping(queuedStr, kvUsageStr, loraReqInfoStr, cacheInfoMetric string) (*MetricMapping, error) { +func NewMetricMapping(queuedStr, runningStr, kvUsageStr, loraReqInfoStr, cacheInfoMetric string) (*MetricMapping, error) { queuedSpec, err := stringToMetricSpec(queuedStr) if err != nil { return nil, fmt.Errorf("error parsing WaitingRequests: %w", err) } + runningSpec, err := stringToMetricSpec(runningStr) + if err != nil { + return nil, fmt.Errorf("error parsing RunningRequests: %w", err) + } kvUsageSpec, err := stringToMetricSpec(kvUsageStr) if err != nil { return nil, fmt.Errorf("error parsing KVCacheUsage: %w", err) @@ -114,10 +119,11 @@ func NewMetricMapping(queuedStr, kvUsageStr, loraReqInfoStr, cacheInfoMetric str } mapping := &MetricMapping{ - TotalQueuedRequests: queuedSpec, - KVCacheUtilization: kvUsageSpec, - LoraRequestInfo: loraReqInfoSpec, - CacheConfigInfo: cacheInfoSpec, + TotalQueuedRequests: queuedSpec, + TotalRunningRequests: runningSpec, + KVCacheUtilization: kvUsageSpec, + LoraRequestInfo: loraReqInfoSpec, + CacheConfigInfo: cacheInfoSpec, } return mapping, nil diff --git a/pkg/epp/datalayer/metrics/extractor.go b/pkg/epp/datalayer/metrics/extractor.go index f6142c494..95baec1b9 100644 --- a/pkg/epp/datalayer/metrics/extractor.go +++ b/pkg/epp/datalayer/metrics/extractor.go @@ -64,8 +64,8 @@ func Produces() map[string]any { // configured with the given metrics' specifications. // These are mandatory metrics per the MSP specification, and are used // as the basis for the built-in scheduling plugins. -func NewExtractor(queueSpec, kvusageSpec, loraSpec, cacheInfoSpec string) (*Extractor, error) { - mapping, err := NewMapping(queueSpec, kvusageSpec, loraSpec, cacheInfoSpec) +func NewExtractor(queueSpec, runningSpec, kvusageSpec, loraSpec, cacheInfoSpec string) (*Extractor, error) { + mapping, err := NewMapping(queueSpec, runningSpec, kvusageSpec, loraSpec, cacheInfoSpec) if err != nil { return nil, fmt.Errorf("failed to create extractor metrics Mapping - %w", err) } @@ -107,6 +107,15 @@ func (ext *Extractor) Extract(ctx context.Context, data any, ep datalayer.Endpoi } } + if spec := ext.mapping.TotalRunningRequests; spec != nil { // extract running requests + if metric, err := spec.getLatestMetric(families); err != nil { + errs = append(errs, err) + } else { + clone.RunningQueueSize = int(extractValue(metric)) + updated = true + } + } + if spec := ext.mapping.KVCacheUtilization; spec != nil { // extract KV cache usage if metric, err := spec.getLatestMetric(families); err != nil { errs = append(errs, err) diff --git a/pkg/epp/datalayer/metrics/mapping.go b/pkg/epp/datalayer/metrics/mapping.go index fab6cf75f..7b1fed9c1 100644 --- a/pkg/epp/datalayer/metrics/mapping.go +++ b/pkg/epp/datalayer/metrics/mapping.go @@ -23,20 +23,25 @@ import ( // Mapping holds specifications for the well-known metrics defined // in the Model Server Protocol. type Mapping struct { - TotalQueuedRequests *Spec - KVCacheUtilization *Spec - LoraRequestInfo *LoRASpec - CacheInfo *Spec + TotalQueuedRequests *Spec + TotalRunningRequests *Spec + KVCacheUtilization *Spec + LoraRequestInfo *LoRASpec + CacheInfo *Spec } // NewMapping creates a metrics.Mapping from the input specification strings. 
-func NewMapping(queue, kvusage, lora, cacheInfo string) (*Mapping, error) { +func NewMapping(queue, running, kvusage, lora, cacheInfo string) (*Mapping, error) { var errs []error queueSpec, err := parseStringToSpec(queue) if err != nil { errs = append(errs, err) } + runningSpec, err := parseStringToSpec(running) + if err != nil { + errs = append(errs, err) + } kvusageSpec, err := parseStringToSpec(kvusage) if err != nil { errs = append(errs, err) @@ -55,9 +60,10 @@ func NewMapping(queue, kvusage, lora, cacheInfo string) (*Mapping, error) { return nil, errors.Join(errs...) } return &Mapping{ - TotalQueuedRequests: queueSpec, - KVCacheUtilization: kvusageSpec, - LoraRequestInfo: loraSpec, - CacheInfo: cacheInfoSpec, + TotalQueuedRequests: queueSpec, + TotalRunningRequests: runningSpec, + KVCacheUtilization: kvusageSpec, + LoraRequestInfo: loraSpec, + CacheInfo: cacheInfoSpec, }, nil } diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index dade69469..5dcd0f4a0 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -63,7 +63,7 @@ type Datastore interface { // PodList lists pods matching the given predicate. PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool - PodDelete(podNAme string) + PodDelete(podName string) // Clears the store state, happens when the pool gets deleted. Clear() diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go index 59c8976cd..af422f2b5 100644 --- a/pkg/epp/metrics/metrics.go +++ b/pkg/epp/metrics/metrics.go @@ -63,6 +63,193 @@ var ( []string{"model_name", "target_model_name", "error_code"}, ) + requestTTFT = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_seconds", + Help: metricsutil.HelpMsgWithStability("Inference model TTFT distribution in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.005, 0.025, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 2, 3, + 4, 5, 6, 8, 10, 15, 20, 30, 45, 60, 120, 180, 240, 300, 360, 480, 600, 900, 1200, 1800, 2700, 3600, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestTTFTGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_seconds_gauge", + Help: metricsutil.HelpMsgWithStability("Inference model TTFT gauge in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestPredictedTTFT = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_predicted_ttft_seconds", + Help: metricsutil.HelpMsgWithStability("Inference model Predicted TTFT distribution in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.005, 0.025, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 2, 3, + 4, 5, 6, 8, 10, 15, 20, 30, 45, 60, 120, 180, 240, 300, 360, 480, 600, 900, 1200, 1800, 2700, 3600, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestPredictedTTFTGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_predicted_ttft_seconds_gauge", + Help: metricsutil.HelpMsgWithStability("Inference model Predicted TTFT gauge in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + // New 
metrics for TTFT prediction duration
+ requestTTFTPredictionDuration = prometheus.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Subsystem: InferenceObjectiveComponent,
+ Name: "request_ttft_prediction_duration_seconds",
+ Help: metricsutil.HelpMsgWithStability("Duration taken to generate TTFT predictions in seconds for each model and target model.", compbasemetrics.ALPHA),
+ Buckets: []float64{
+ 0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0,
+ },
+ },
+ []string{"model_name", "target_model_name"},
+ )
+
+ requestTTFTPredictionDurationGauge = prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Subsystem: InferenceObjectiveComponent,
+ Name: "request_ttft_prediction_duration_seconds_gauge",
+ Help: metricsutil.HelpMsgWithStability("Latest duration taken to generate TTFT predictions in seconds for each model and target model.", compbasemetrics.ALPHA),
+ },
+ []string{"model_name", "target_model_name"},
+ )
+
+ requestTPOT = prometheus.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Subsystem: InferenceObjectiveComponent,
+ Name: "request_tpot_seconds",
+ Help: metricsutil.HelpMsgWithStability("Inference model TPOT distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
+ Buckets: []float64{
+ 0.0005, 0.0025, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.125, 0.15, 0.2, 0.3,
+ 0.4, 0.5, 0.6, 0.8, 1, 1.5, 2, 3, 4.5, 6, 12, 18, 24, 30, 36, 48, 60, 90, 120, 180, 270, 360,
+ },
+ },
+ []string{"model_name", "target_model_name"},
+ )
+
+ requestTPOTGauge = prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Subsystem: InferenceObjectiveComponent,
+ Name: "request_tpot_seconds_gauge",
+ Help: metricsutil.HelpMsgWithStability("Inference model TPOT gauge in seconds for each model and target model.", compbasemetrics.ALPHA),
+ },
+ []string{"model_name", "target_model_name"},
+ )
+ requestPredictedTPOT = prometheus.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Subsystem: InferenceObjectiveComponent,
+ Name: "request_predicted_tpot_seconds",
+ Help: metricsutil.HelpMsgWithStability("Inference model Predicted TPOT distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
+ Buckets: []float64{
+ 0.0005, 0.0025, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.125, 0.15, 0.2, 0.3,
+ 0.4, 0.5, 0.6, 0.8, 1, 1.5, 2, 3, 4.5, 6, 12, 18, 24, 30, 36, 48, 60, 90, 120, 180, 270, 360,
+ },
+ },
+ []string{"model_name", "target_model_name"},
+ )
+
+ requestPredictedTPOTGauge = prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Subsystem: InferenceObjectiveComponent,
+ Name: "request_predicted_tpot_seconds_gauge",
+ Help: metricsutil.HelpMsgWithStability("Inference model Predicted TPOT gauge in seconds for each model and target model.", compbasemetrics.ALPHA),
+ },
+ []string{"model_name", "target_model_name"},
+ )
+
+ // New metrics for TPOT prediction duration
+ requestTPOTPredictionDuration = prometheus.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Subsystem: InferenceObjectiveComponent,
+ Name: "request_tpot_prediction_duration_seconds",
+ Help: metricsutil.HelpMsgWithStability("Duration taken to generate TPOT predictions in seconds for each model and target model.", compbasemetrics.ALPHA),
+ Buckets: []float64{
+ 0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0,
+ },
+ },
+ []string{"model_name", "target_model_name"},
+ )
+
+ requestTPOTPredictionDurationGauge = prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Subsystem: InferenceObjectiveComponent,
+ Name:
"request_tpot_prediction_duration_seconds_gauge", + Help: metricsutil.HelpMsgWithStability("Latest duration taken to generate TPOT predictions in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + // SLO Violation Metrics + requestTTFTSLOViolation = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_slo_violation", + Help: metricsutil.HelpMsgWithStability("Boolean indicator (0 or 1) of whether the last TTFT measurement violated the SLO threshold for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestTTFTSLOViolationCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_slo_violation_total", + Help: metricsutil.HelpMsgWithStability("Counter of TTFT SLO violations for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOTSLOViolation = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_slo_violation", + Help: metricsutil.HelpMsgWithStability("Boolean indicator (0 or 1) of whether the last TPOT measurement violated the SLO threshold for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOTSLOViolationCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_slo_violation_total", + Help: metricsutil.HelpMsgWithStability("Counter of TPOT SLO violations for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + // SLO threshold gauges (for dynamic threshold management) + requestTTFTSLOThreshold = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_slo_threshold_seconds", + Help: metricsutil.HelpMsgWithStability("Current TTFT SLO threshold in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOTSLOThreshold = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_slo_threshold_seconds", + Help: metricsutil.HelpMsgWithStability("Current TPOT SLO threshold in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + requestLatencies = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Subsystem: InferenceObjectiveComponent, @@ -282,6 +469,32 @@ var registerMetrics sync.Once // Register all metrics. 
func Register(customCollectors ...prometheus.Collector) { registerMetrics.Do(func() { + metrics.Registry.MustRegister(requestTPOT) + metrics.Registry.MustRegister(requestTTFT) + + metrics.Registry.MustRegister(requestTPOTGauge) + metrics.Registry.MustRegister(requestTTFTGauge) + + metrics.Registry.MustRegister(requestPredictedTPOT) + metrics.Registry.MustRegister(requestPredictedTTFT) + + metrics.Registry.MustRegister(requestPredictedTPOTGauge) + metrics.Registry.MustRegister(requestPredictedTTFTGauge) + + // Register new prediction duration metrics + metrics.Registry.MustRegister(requestTPOTPredictionDuration) + metrics.Registry.MustRegister(requestTPOTPredictionDurationGauge) + metrics.Registry.MustRegister(requestTTFTPredictionDuration) + metrics.Registry.MustRegister(requestTTFTPredictionDurationGauge) + + // Register SLO violation metrics + metrics.Registry.MustRegister(requestTTFTSLOViolation) + metrics.Registry.MustRegister(requestTTFTSLOViolationCounter) + metrics.Registry.MustRegister(requestTPOTSLOViolation) + metrics.Registry.MustRegister(requestTPOTSLOViolationCounter) + metrics.Registry.MustRegister(requestTTFTSLOThreshold) + metrics.Registry.MustRegister(requestTPOTSLOThreshold) + metrics.Registry.MustRegister(requestCounter) metrics.Registry.MustRegister(requestErrCounter) metrics.Registry.MustRegister(requestLatencies) @@ -332,6 +545,30 @@ func Reset() { PrefixCacheHitLength.Reset() flowControlRequestQueueDuration.Reset() flowControlQueueSize.Reset() + + requestTPOT.Reset() + requestTTFT.Reset() + requestTPOTGauge.Reset() + requestTTFTGauge.Reset() + + requestPredictedTPOT.Reset() + requestPredictedTTFT.Reset() + requestPredictedTPOTGauge.Reset() + requestPredictedTTFTGauge.Reset() + + // Reset new prediction duration metrics + requestTPOTPredictionDuration.Reset() + requestTPOTPredictionDurationGauge.Reset() + requestTTFTPredictionDuration.Reset() + requestTTFTPredictionDurationGauge.Reset() + + // Reset SLO violation metrics + requestTTFTSLOViolation.Reset() + requestTTFTSLOViolationCounter.Reset() + requestTPOTSLOViolation.Reset() + requestTPOTSLOViolationCounter.Reset() + requestTTFTSLOThreshold.Reset() + requestTPOTSLOThreshold.Reset() } // RecordRequstCounter records the number of requests. @@ -363,6 +600,123 @@ func RecordRequestLatencies(ctx context.Context, modelName, targetModelName stri return true } +func RecordRequestTPOT(ctx context.Context, modelName, targetModelName string, tpot float64) bool { + if tpot < 0 { + log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TPOT value must be non-negative", + "modelName", modelName, "targetModelName", targetModelName, "tpot", tpot) + return false + } + requestTPOT.WithLabelValues(modelName, targetModelName).Observe(tpot) + requestTPOTGauge.WithLabelValues(modelName, targetModelName).Set(tpot) + return true +} + +// RecordRequestTPOTWithSLO records TPOT and checks for SLO violation. +// If tpot exceeds the threshold, it records a violation (sets gauge to 1 and increments counter). +// If tpot is within limits, it sets gauge to 0. 
+func RecordRequestTPOTWithSLO(ctx context.Context, modelName, targetModelName string, tpot float64, sloThreshold float64) bool {
+ if tpot < 0 {
+ log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TPOT value must be non-negative",
+ "modelName", modelName, "targetModelName", targetModelName, "tpot", tpot)
+ return false
+ }
+
+ // Check for SLO violation (tpot exceeds threshold)
+ if tpot > sloThreshold {
+ requestTPOTSLOViolation.WithLabelValues(modelName, targetModelName).Set(1)
+ requestTPOTSLOViolationCounter.WithLabelValues(modelName, targetModelName).Inc()
+ log.FromContext(ctx).V(logutil.DEFAULT).Info("TPOT SLO violation detected",
+ "modelName", modelName, "targetModelName", targetModelName, "tpot", tpot, "threshold", sloThreshold)
+ } else {
+ requestTPOTSLOViolation.WithLabelValues(modelName, targetModelName).Set(0)
+ }
+
+ return true
+}
+
+// RecordRequestPredictedTPOT records the predicted TPOT of a request.
+func RecordRequestPredictedTPOT(ctx context.Context, modelName, targetModelName string, predictedTPOT float64) bool {
+ if predictedTPOT < 0 {
+ log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "Predicted TPOT value must be non-negative",
+ "modelName", modelName, "targetModelName", targetModelName, "tpot", predictedTPOT)
+ return false
+ }
+ requestPredictedTPOT.WithLabelValues(modelName, targetModelName).Observe(predictedTPOT)
+ requestPredictedTPOTGauge.WithLabelValues(modelName, targetModelName).Set(predictedTPOT)
+ return true
+}
+
+// RecordRequestTPOTPredictionDuration records the duration taken to generate TPOT predictions.
+func RecordRequestTPOTPredictionDuration(ctx context.Context, modelName, targetModelName string, duration float64) bool {
+ if duration < 0 {
+ log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TPOT prediction duration must be non-negative",
+ "modelName", modelName, "targetModelName", targetModelName, "duration", duration)
+ return false
+ }
+ requestTPOTPredictionDuration.WithLabelValues(modelName, targetModelName).Observe(duration)
+ requestTPOTPredictionDurationGauge.WithLabelValues(modelName, targetModelName).Set(duration)
+ return true
+}
+
+// RecordRequestTTFT records the TTFT (time to first token) of a request.
+func RecordRequestTTFT(ctx context.Context, modelName, targetModelName string, ttft float64) bool {
+ if ttft < 0 {
+ log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TTFT value must be non-negative",
+ "modelName", modelName, "targetModelName", targetModelName, "ttft", ttft)
+ return false
+ }
+ requestTTFT.WithLabelValues(modelName, targetModelName).Observe(ttft)
+ requestTTFTGauge.WithLabelValues(modelName, targetModelName).Set(ttft)
+ return true
+}
+
+// RecordRequestTTFTWithSLO records TTFT and checks for SLO violation.
+// If ttft exceeds the threshold, it records a violation (sets gauge to 1 and increments counter).
+// If ttft is within limits, it sets gauge to 0.
+func RecordRequestTTFTWithSLO(ctx context.Context, modelName, targetModelName string, ttft float64, sloThreshold float64) bool {
+ if ttft < 0 {
+ log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TTFT value must be non-negative",
+ "modelName", modelName, "targetModelName", targetModelName, "ttft", ttft)
+ return false
+ }
+
+ // Check for SLO violation (ttft exceeds threshold)
+ if ttft > sloThreshold {
+ requestTTFTSLOViolation.WithLabelValues(modelName, targetModelName).Set(1)
+ requestTTFTSLOViolationCounter.WithLabelValues(modelName, targetModelName).Inc()
+ log.FromContext(ctx).V(logutil.DEFAULT).Info("TTFT SLO violation detected",
+ "modelName", modelName, "targetModelName", targetModelName, "ttft", ttft, "threshold", sloThreshold)
+ } else {
+ requestTTFTSLOViolation.WithLabelValues(modelName, targetModelName).Set(0)
+ }
+
+ return true
+}
+
+// RecordRequestPredictedTTFT records the predicted TTFT of a request.
+func RecordRequestPredictedTTFT(ctx context.Context, modelName, targetModelName string, predictedTTFT float64) bool {
+ if predictedTTFT < 0 {
+ log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "Predicted TTFT value must be non-negative",
+ "modelName", modelName, "targetModelName", targetModelName, "ttft", predictedTTFT)
+ return false
+ }
+ requestPredictedTTFT.WithLabelValues(modelName, targetModelName).Observe(predictedTTFT)
+ requestPredictedTTFTGauge.WithLabelValues(modelName, targetModelName).Set(predictedTTFT)
+ return true
+}
+
+// RecordRequestTTFTPredictionDuration records the duration taken to generate TTFT predictions.
+func RecordRequestTTFTPredictionDuration(ctx context.Context, modelName, targetModelName string, duration float64) bool {
+ if duration < 0 {
+ log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TTFT prediction duration must be non-negative",
+ "modelName", modelName, "targetModelName", targetModelName, "duration", duration)
+ return false
+ }
+ requestTTFTPredictionDuration.WithLabelValues(modelName, targetModelName).Observe(duration)
+ requestTTFTPredictionDurationGauge.WithLabelValues(modelName, targetModelName).Set(duration)
+ return true
+}
+
 // RecordResponseSizes records the response sizes.
 func RecordResponseSizes(modelName, targetModelName string, size int) {
 responseSizes.WithLabelValues(modelName, targetModelName).Observe(float64(size))
@@ -480,3 +834,15 @@ func IncFlowControlQueueSize(fairnessID, priority string) {
 func DecFlowControlQueueSize(fairnessID, priority string) {
 flowControlQueueSize.WithLabelValues(fairnessID, priority).Dec()
 }
+
+// SetTTFTSLOThreshold sets the TTFT SLO threshold for a model.
+// This allows dynamic threshold management and makes the threshold visible in metrics.
+func SetTTFTSLOThreshold(modelName, targetModelName string, threshold float64) {
+ requestTTFTSLOThreshold.WithLabelValues(modelName, targetModelName).Set(threshold)
+}
+
+// SetTPOTSLOThreshold sets the TPOT SLO threshold for a model.
+// This allows dynamic threshold management and makes the threshold visible in metrics.
+func SetTPOTSLOThreshold(modelName, targetModelName string, threshold float64) {
+ requestTPOTSLOThreshold.WithLabelValues(modelName, targetModelName).Set(threshold)
+}
diff --git a/pkg/epp/metrics/metrics_test.go b/pkg/epp/metrics/metrics_test.go
index 7d4168183..754d6d294 100644
--- a/pkg/epp/metrics/metrics_test.go
+++ b/pkg/epp/metrics/metrics_test.go
@@ -46,6 +46,8 @@ const (
 KVCacheAvgUsageMetric = InferencePoolComponent + "_average_kv_cache_utilization"
 QueueAvgSizeMetric = InferencePoolComponent + "_average_queue_size"
 PerPodQueueSizeMetrics = InferencePoolComponent + "_per_pod_queue_size"
+ RequestTTFTSecondsMetric = InferenceObjectiveComponent + "_request_ttft_seconds"
+ RequestTPOTSecondsMetric = InferenceObjectiveComponent + "_request_tpot_seconds"
 )
 func TestMain(m *testing.M) {
diff --git a/pkg/epp/metrics/testdata/request_tpot_seconds_metric b/pkg/epp/metrics/testdata/request_tpot_seconds_metric
new file mode 100644
index 000000000..beee50271
--- /dev/null
+++ b/pkg/epp/metrics/testdata/request_tpot_seconds_metric
@@ -0,0 +1,40 @@
+# HELP inference_model_request_tpot_seconds [ALPHA] Inference model TPOT distribution in seconds for each model and target model.
+# TYPE inference_model_request_tpot_seconds histogram
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.0005"} 0
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.0025"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.005"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.01"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.02"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.04"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.06"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.08"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.1"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.125"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.15"} 1
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.2"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.3"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.4"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.5"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.6"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.8"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="1"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="1.5"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="2"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="3"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="4.5"} 2
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="6"} 2
target_model_name="t10", le="6"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="12"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="18"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="24"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="30"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="36"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="48"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="60"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="90"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="120"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="180"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="270"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="360"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="Inf"} 2 +inference_model_request_tpot_seconds_sum{model_name="m20", target_model_name="t10"} 0.161 +inference_model_request_tpot_seconds_count{model_name="m20", target_model_name="t10"} 2 + + +iinference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.0005"} 0 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.0025"} 1 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.005"} 1 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.01"} 1 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.02"} 1 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.04"} 1 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.06"} 1 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.08"} 1 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.1"} 1 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.125"} 1 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.15"} 1 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.2"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.3"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.4"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.5"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.6"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="0.8"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="1"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="1.5"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="2"} 2 
+inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="3"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="4.5"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="6"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="12"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="18"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="24"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="30"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="36"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="48"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="60"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="90"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="120"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="180"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="270"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="360"} 2 +inference_model_request_tpot_seconds_bucket{model_name="m20", target_model_name="t10", le="Inf"} 2 +inference_model_request_tpot_seconds_sum{model_name="m20", target_model_name="t10"} 0.161 +inference_model_request_tpot_seconds_count{model_name="m20", target_model_name="t10"} 2 \ No newline at end of file diff --git a/pkg/epp/metrics/testdata/request_ttft_seconds_metric b/pkg/epp/metrics/testdata/request_ttft_seconds_metric new file mode 100644 index 000000000..315490727 --- /dev/null +++ b/pkg/epp/metrics/testdata/request_ttft_seconds_metric @@ -0,0 +1,116 @@ +# HELP inference_model_request_ttft_seconds [ALPHA] Inference model response latency distribution in seconds for each model and target model. 
+# TYPE inference_model_request_ttft_seconds histogram +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="0.005"} 0 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="0.025"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="0.05"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="0.1"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="0.2"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="0.4"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="0.6"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="0.8"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="1.0"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="1.25"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="1.5"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="2"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="3"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="4"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="5"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="6"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="8"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="10"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="15"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="20"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="30"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="45"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="60"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="120"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="180"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="240"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="300"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="360"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="480"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="600"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="900"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="1200"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="1800"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="2700"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="3600"} 2 
+inference_model_request_ttft_seconds_bucket{model_name="m10", target_model_name="t10", le="Inf"} 2 +inference_model_request_ttft_seconds_sum{model_name="m10", target_model_name="t10"} 1.61 +inference_model_request_ttft_seconds_count{model_name="m10", target_model_name="t10"} 2 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.005"} 0 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.025"} 0 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.05"} 0 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.1"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.2"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.4"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.6"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="0.8"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="1"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="1.25"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="1.5"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="2"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="3"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="4"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="5"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="6"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="8"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="10"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="15"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="20"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="30"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="45"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="60"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="120"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="180"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="240"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="300"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="360"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="480"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="600"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="900"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="1200"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="1800"} 1 
+inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="2700"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="3600"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m10",target_model_name="t11",le="+Inf"} 1 +inference_model_request_ttft_seconds_sum{model_name="m10",target_model_name="t11"} 0.06 +inference_model_request_ttft_seconds_count{model_name="m10",target_model_name="t11"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.005"} 0 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.025"} 0 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.05"} 0 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.1"} 0 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.2"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.4"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.6"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="0.8"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="1"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="1.25"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="1.5"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="2"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="3"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="4"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="5"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="6"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="8"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="10"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="15"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="20"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="30"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="45"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="60"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="120"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="180"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="240"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="300"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="360"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="480"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="600"} 1 +inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="900"} 1 
+inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="1200"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="1800"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="2700"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="3600"} 1
+inference_model_request_ttft_seconds_bucket{model_name="m20",target_model_name="t20",le="+Inf"} 1
+inference_model_request_ttft_seconds_sum{model_name="m20",target_model_name="t20"} 0.12
+inference_model_request_ttft_seconds_count{model_name="m20",target_model_name="t20"} 1
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go
new file mode 100644
index 000000000..cf9d8ee33
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go
@@ -0,0 +1,169 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains the SLO-aware routing plugin and the
+// helpers that decouple its latency-predictor logic.
+package slo_aware_router
+
+import (
+	"os"
+	"strconv"
+	"strings"
+)
+
+var SLOBufferFactor = func() float64 {
+	if value, exists := os.LookupEnv("SLO_BUFFER_FACTOR"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil {
+			return parsedValue
+		}
+	}
+	return 1.0 // default value
+}()
+
+var NegHeadroomTTFTWeight = func() float64 {
+	if value, exists := os.LookupEnv("NEG_HEADROOM_TTFT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.8 // default: TTFT dominates when violating SLOs
+}()
+
+var NegHeadroomTPOTWeight = func() float64 {
+	if value, exists := os.LookupEnv("NEG_HEADROOM_TPOT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.2 // default: TPOT matters less in small-output scenarios
+}()
+
+var HeadroomTTFTWeight = func() float64 {
+	if value, exists := os.LookupEnv("HEADROOM_TTFT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.8 // default
+}()
+
+var HeadroomTPOTWeight = func() float64 {
+	if value, exists := os.LookupEnv("HEADROOM_TPOT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.2 // default
+}()
+
+var HeadroomSelectionStrategy = func() HeadroomStrategy {
+	if value, exists := os.LookupEnv("HEADROOM_SELECTION_STRATEGY"); exists {
+		switch strings.ToLower(value) {
+		case "least":
+			return HeadroomStrategyLeast
+		case "most":
+			return HeadroomStrategyMost
+		case "composite-least":
+			return HeadroomStrategyCompositeLeast
+		case "composite-most":
+			return HeadroomStrategyCompositeMost
+		case "composite-only":
+			return HeadroomStrategyCompositeOnly
+		}
+	}
+	return HeadroomStrategyLeast // default to least (better packing)
+}()
+
+// If using composite headroom, weights for each component. Not used by default.
+var CompositeKVWeight = func() float64 {
+	if v, ok := os.LookupEnv("COMPOSITE_KV_WEIGHT"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 {
+			return f
+		}
+	}
+	return 1
+}()
+
+var CompositeQueueWeight = func() float64 {
+	if v, ok := os.LookupEnv("COMPOSITE_QUEUE_WEIGHT"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 {
+			return f
+		}
+	}
+	return 1
+}()
+
+var CompositePrefixWeight = func() float64 {
+	if v, ok := os.LookupEnv("COMPOSITE_PREFIX_WEIGHT"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 {
+			return f
+		}
+	}
+	return 1
+}()
+
+// With probability ε, explore (ignore affinity gate); otherwise exploit.
+var EpsilonExploreSticky = func() float64 {
+	if v, ok := os.LookupEnv("STICKY_EPSILON"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.01 // default 1% exploration
+}()
+
+var EpsilonExploreNeg = func() float64 {
+	if v, ok := os.LookupEnv("NEG_HEADROOM_EPSILON"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.01 // default 1% exploration
+}()
+
+// τ for per-path affinity gate (aka "stickiness" threshold).
+var AffinityGateTau = func() float64 {
+	if v, ok := os.LookupEnv("AFFINITY_GATE_TAU"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.80
+}()
+
+// Global τ for the overall candidate set (previously "overall stickiness").
+var AffinityGateTauGlobal = func() float64 {
+	if v, ok := os.LookupEnv("AFFINITY_GATE_TAU_GLOBAL"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.99
+}()
+
+// Read once at init. Values: "linear" (default) or "max".
+var SelectionMode = func() PodSelectionMode {
+	if v, ok := os.LookupEnv("POD_SELECTION_MODE"); ok {
+		switch strings.ToLower(v) {
+		case "max":
+			return PodSelectionMax
+		case "linear":
+			fallthrough
+		default:
+			return PodSelectionLinear
+		}
+	}
+	return PodSelectionLinear
+}()
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go
new file mode 100644
index 000000000..2588b3104
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go
@@ -0,0 +1,66 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains the SLO-aware routing plugin and the
+// helpers that decouple its latency-predictor logic.
+package slo_aware_router
+
+import (
+	"fmt"
+	"strconv"
+
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
+)
+
+// parseFloatHeader retrieves a header by name and parses it as a float64.
+// It returns the parsed value, whether the header was present, and an error
+// if the header was present but not a valid float.
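+// For example, a request carrying a hypothetical header "x-slo-ttft-ms: 250"
+// returns (250, true, nil); if the header is absent this returns (0, false, nil).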
+func parseFloatHeader(request schedulingtypes.LLMRequest, headerName string) (float64, bool, error) {
+	// 1. Get header value from the map
+	headerValue, ok := request.Headers[headerName]
+	if !ok {
+		return 0, false, nil // Header not found: return zero value and false
+	}
+
+	// 2. Parse the header value to a float64
+	parsedFloat, err := strconv.ParseFloat(headerValue, 64)
+	if err != nil {
+		return 0, false, errutil.Error{
+			Code: errutil.BadRequest,
+			Msg:  fmt.Sprintf("%s must be a float", headerName),
+		}
+	}
+
+	// 3. Return the successfully parsed value
+	return parsedFloat, true, nil
+}
+
+// parseBoolHeader retrieves a header by name and parses it as a bool.
+// It returns the parsed value, or an error if the header was present but not
+// a valid bool; a missing header yields false with no error.
+func parseBoolHeader(request schedulingtypes.LLMRequest, headerName string) (bool, error) {
+	// 1. Get header value from the map
+	headerValue, ok := request.Headers[headerName]
+	if !ok {
+		return false, nil // Header not found: return false
+	}
+
+	// 2. Parse the header value to a bool
+	parsedBool, err := strconv.ParseBool(headerValue)
+	if err != nil {
+		return false, errutil.Error{
+			Code: errutil.BadRequest,
+			Msg:  fmt.Sprintf("%s must be a bool", headerName),
+		}
+	}
+
+	// 3. Return the successfully parsed value
+	return parsedBool, nil
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go
new file mode 100644
index 000000000..3f02d5e52
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go
@@ -0,0 +1,145 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package slo_aware_router + +import ( + "context" + "math" + "math/rand" + + "sigs.k8s.io/controller-runtime/pkg/log" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds []PodPredictionResult, r *rand.Rand, strategy HeadroomStrategy) schedulingtypes.Pod { + total := 0 + choices := s.buildCompositeChoices( + ctx, allPreds, CompositeKVWeight, CompositeQueueWeight, CompositePrefixWeight, &total, + ) + if strategy == HeadroomStrategyCompositeLeast { + // Invert weights for "least" strategy + for i := range choices { + choices[i].Weight = minWeight + Wmax - choices[i].Weight + } + } + selectedPod := s.performWeightedRandomSelection(choices, total, allPreds, r) + return selectedPod +} +func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []Choice, total int, candidates []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod { + if total == 0 { + return nil + } + logger := log.FromContext(context.Background()) + // Check if MAX_SCORE_SELECTION env variable is set + if SelectionMode == PodSelectionMax { + + logger.V(logutil.DEBUG).Info("Pod selection mode: MAX - selecting pod with highest weight") + maxWeight := 0 + var selectedPod schedulingtypes.Pod + for _, c := range weightedChoices { + if c.Weight > maxWeight { + maxWeight = c.Weight + selectedPod = c.PodName + } + } + if selectedPod != nil { + return selectedPod + } + // Fallback to first pod if no selection made + return candidates[0].Pod + } + + // Original weighted random selection logic + logger.V(logutil.DEBUG).Info("Pod selection mode: LINEAR - performing weighted random selection") + idx := r.Intn(total) + var selectedPod schedulingtypes.Pod + + for _, c := range weightedChoices { + if idx < c.Weight { + selectedPod = c.PodName + break + } + idx -= c.Weight + } + + // If no pod was selected (shouldn't happen), fallback to first pod + if selectedPod == nil { + selectedPod = candidates[0].Pod + } + + return selectedPod +} +func (s *SLOAwareRouter) buildCompositeChoices( + ctx context.Context, + candidates []PodPredictionResult, + wkv, wq, wpref float64, + total *int, +) []Choice { + + // Normalize weights + sumw := wkv + wq + wpref + if sumw <= 0 { + wkv, wq, wpref = 1, 0, 0 + } else { + wkv /= sumw + wq /= sumw + wpref /= sumw + } + + // Precompute queue stats + minQ, maxQ := math.MaxInt32, -1 + queueCounts := make(map[string]int, len(candidates)) + for _, p := range candidates { + q := p.Pod.GetMetrics().WaitingQueueSize + queueCounts[p.Pod.GetPod().String()] = q + if q < minQ { + minQ = q + } + if q > maxQ { + maxQ = q + } + } + den := float64(maxQ - minQ) + + choices := make([]Choice, 0, len(candidates)) + for _, p := range candidates { + q := queueCounts[p.Pod.GetPod().String()] + relQueue := 1.0 + if den > 0 { + relQueue = (float64(maxQ-q) / den) + } + + kvUsage := p.Pod.GetMetrics().KVCacheUsagePercent + kvFree := (1.0 - kvUsage) + prefix := (p.PrefixCacheScore) + + composite := wkv*kvFree + wq*relQueue + wpref*prefix + w := int(math.Round(float64(minWeight) + (float64(Wmax-minWeight) * composite))) + *total += w + choices = append(choices, Choice{PodName: p.Pod, Weight: w}) + + log.FromContext(ctx).V(logutil.DEBUG).Info("Composite (neg/pos) score", + "pod", p.Pod.GetPod().String(), + "kvUsage", kvUsage, "kvFree", kvFree, + "queue", q, "relQueue", relQueue, + "prefix", prefix, + "wkv", wkv, "wq", wq, 
"wprefix", wpref, + "composite", composite, "weight", w) + } + return choices +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go new file mode 100644 index 000000000..e9a1b128c --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go @@ -0,0 +1,449 @@ +/* +© 2025 The Kubernetes Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License. +*/ + +// Package requestcontrol contains helpers to decouple latency-predictor logic. +package slo_aware_router + +import ( + "context" + "fmt" + "strings" + "time" + + "sigs.k8s.io/controller-runtime/pkg/log" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +const ( + // Poisson sampling parameters for predictions + defaultSamplingMean = 100 // Mean interval between prediction samples (tokens) + maxSampledTokens = 20 // Maximum number of prediction samples per request +) + +// RefreshLastSeenMetrics updates sloCtx.LastSeenMetrics from the latest scheduling result. +func RefreshLastSeenMetrics(ctx context.Context, sloCtx *SLORequestContext) { + if sr := sloCtx.SchedulingResult; sr != nil { + if pr := sr.ProfileResults[sr.PrimaryProfileName]; pr != nil && pr.TargetPods != nil { + for profileName, profileResult := range sr.ProfileResults { + if profileResult != nil && profileResult.TargetPods != nil && len(profileResult.TargetPods) > 0 { + sloCtx.LastSeenMetrics[profileName] = profileResult.TargetPods[0].GetMetrics().Clone() + } + } + } + } else { + log.FromContext(ctx).V(logutil.DEBUG).Info("No scheduling result found, skipping metrics refresh") + } +} + +// GetMetricsForPrediction retrieves the latest metrics for prediction from sloCtx.LastSeenMetrics. +func GetLatestMetricsForProfile(ctx context.Context, sloCtx *SLORequestContext, profileName string) (*backendmetrics.MetricsState, error) { + if len(sloCtx.LastSeenMetrics) == 0 { + return nil, fmt.Errorf("no last seen metrics available for prediction") + } + + // Use the primary profile's metrics for prediction + if metrics, exists := sloCtx.LastSeenMetrics[profileName]; exists { + return metrics, nil + } + + log.FromContext(ctx).V(logutil.DEBUG).Info("No metrics found for profile, trying primary profile", "profile_name", profileName) + + primaryProfileName := sloCtx.SchedulingResult.PrimaryProfileName + if metrics, exists := sloCtx.LastSeenMetrics[primaryProfileName]; exists { + return metrics, nil + } + + return nil, fmt.Errorf("no metrics found for primary profile %s", primaryProfileName) +} + +// ProcessHeader refreshes metrics, applies TTFT prediction, updates sloCtx.PredictedTTFT and timestamp. 
+func ProcessHeaderForLatencyPrediction(
+	ctx context.Context,
+	predictor latencypredictor.PredictorInterface,
+	sloCtx *SLORequestContext,
+) error {
+	logger := log.FromContext(ctx)
+
+	// Build the prediction request from the prefill profile's metrics,
+	// falling back to the primary profile if no prefill profile is set.
+	m, err := GetLatestMetricsForProfile(ctx, sloCtx, "prefill")
+	if err != nil {
+		logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err)
+		return err
+	}
+
+	targetPod := sloCtx.TargetPod
+	prefix_cache_score := sloCtx.PrefixCacheScoresForPods[targetPod.String()]
+
+	in := latencypredictor.PredictionRequest{
+		KVCachePercentage:  m.KVCacheUsagePercent,
+		InputTokenLength:   len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
+		NumRequestWaiting:  m.WaitingQueueSize,
+		NumRequestRunning:  m.RunningQueueSize,
+		NumTokensGenerated: 0,
+		PrefixCacheScore:   prefix_cache_score,
+	}
+
+	// Predict TTFT
+	start := time.Now()
+	p, err := predictor.Predict(ctx, in)
+	dur := time.Since(start)
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(err, "header TTFT predict failed", "duration_ms", dur.Milliseconds())
+		sloCtx.PredictedTTFT = 0
+	} else if p == nil {
+		logger.V(logutil.DEBUG).Info("header TTFT predict nil", "duration_ms", dur.Milliseconds())
+		sloCtx.PredictedTTFT = 0
+	} else {
+		logger.V(logutil.DEBUG).Info("header TTFT succeeded", "value_ms", p.TTFT, "duration_ms", dur.Milliseconds())
+		metrics.RecordRequestTTFTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds())
+
+		sloCtx.PredictedTTFT = p.TTFT
+	}
+
+	// Advance timestamp for first token reference
+	sloCtx.LastTokenTimestamp = time.Now()
+	RefreshLastSeenMetrics(ctx, sloCtx)
+	return err
+}
+
+// ProcessFirstTokenForLatencyPrediction records the actual TTFT, trains, predicts the first TPOT, updates sloCtx, and advances the timestamp.
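+// Note the ordering: the observed TTFT is fed back as a training sample (against
+// the prefill metrics) before the first TPOT prediction is made against the
+// primary profile's metrics.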
+func ProcessFirstTokenForLatencyPrediction( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + sloCtx *SLORequestContext, + now time.Time, +) { + logger := log.FromContext(ctx) + + // Initialize sampler + if sloCtx.TokenSampler == nil { + requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey] + sloCtx.TokenSampler = NewTokenSampler(requestID, defaultSamplingMean, maxSampledTokens) + logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken()) + } + + // Actual TTFT + sloCtx.TTFT = float64(now.Sub(sloCtx.RequestReceivedTimestamp).Milliseconds()) + sloCtx.GeneratedTokenCount = 1 + m, err := GetLatestMetricsForProfile(ctx, sloCtx, "prefill") + if err != nil { + logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err) + return + } + targetPod := sloCtx.TargetPod + prefix_cache_score := sloCtx.PrefixCacheScoresForPods[targetPod.String()] + + // Train TTFT + entry := latencypredictor.TrainingEntry{ + KVCachePercentage: m.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)), + ActualTTFT: sloCtx.TTFT, + ActualTPOT: 0, + Timestamp: now, + NumRequestWaiting: m.WaitingQueueSize, + NumRequestRunning: m.RunningQueueSize, + NumTokensGenerated: 0, + PrefixCacheScore: prefix_cache_score, + } + if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil { + logger.V(logutil.DEBUG).Error(err, "record TTFT training failed") + } + m, err = GetLatestMetricsForProfile(ctx, sloCtx, sloCtx.SchedulingResult.PrimaryProfileName) + if err != nil { + logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics", + "error", err) + return + } + + // Predict first TPOT + in := latencypredictor.PredictionRequest{ + KVCachePercentage: m.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)), + NumRequestWaiting: m.WaitingQueueSize, + NumRequestRunning: m.RunningQueueSize, + NumTokensGenerated: sloCtx.GeneratedTokenCount, + PrefixCacheScore: 0, + } + start := time.Now() + p, err := predictor.Predict(ctx, in) + dur := time.Since(start) + if err != nil || p == nil { + logger.V(logutil.DEBUG).Error(err, "first TPOT predict failed", "duration_ms", dur.Milliseconds()) + sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, 0) + sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, 0, len(sloCtx.PredictedTPOTObservations)) + } else { + logger.V(logutil.DEBUG).Info("first TPOT succeeded", "value_ms", p.TPOT, "duration_ms", dur.Milliseconds()) + sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, p.TPOT) + sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, p.TPOT, len(sloCtx.PredictedTPOTObservations)) + } + metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds()) + + // Advance timestamp + sloCtx.LastTokenTimestamp = now + // Refresh metrics + RefreshLastSeenMetrics(ctx, sloCtx) +} + +// ProcessToken records actual inter-token latency, trains, predicts sampled TPOT, updates sloCtx, and advances timestamp. 
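+// Every token contributes a training entry with its observed inter-token latency,
+// but only Poisson-sampled tokens (at most maxSampledTokens per request) trigger
+// a TPOT prediction.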
+func ProcessTokenForLatencyPrediction(
+	ctx context.Context,
+	predictor latencypredictor.PredictorInterface,
+	sloCtx *SLORequestContext,
+	now time.Time,
+) {
+	logger := log.FromContext(ctx)
+
+	// Initialize sampler if not yet initialized
+	if sloCtx.TokenSampler == nil {
+		requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey]
+		sloCtx.TokenSampler = NewTokenSampler(requestID, defaultSamplingMean, maxSampledTokens)
+		logger.V(logutil.DEBUG).Info("Initialized token sampler for subsequent tokens", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken())
+	}
+
+	// Inter-token latency
+	latencyMs := float64(now.Sub(sloCtx.LastTokenTimestamp).Milliseconds())
+	sloCtx.GeneratedTokenCount++
+
+	// Record the observed inter-token latency for tokens that have a matching
+	// prediction sample; the next sample token is always one ahead of the
+	// current token, so token 2 is included explicitly.
+	if sloCtx.GeneratedTokenCount == 2 || sloCtx.TokenSampler.ShouldPredict(sloCtx.GeneratedTokenCount) {
+		sloCtx.TPOTObservations = append(sloCtx.TPOTObservations, latencyMs)
+		sloCtx.AvgTPOT = calculateRunningAverage(sloCtx.AvgTPOT, latencyMs, len(sloCtx.TPOTObservations))
+	}
+
+	m, err := GetLatestMetricsForProfile(ctx, sloCtx, sloCtx.SchedulingResult.PrimaryProfileName)
+	if err != nil {
+		logger.V(logutil.DEBUG).Info("Skipping TPOT training and prediction due to missing metrics",
+			"error", err)
+		return
+	}
+
+	// Record actual TPOT
+	entry := latencypredictor.TrainingEntry{
+		KVCachePercentage:  m.KVCacheUsagePercent,
+		InputTokenLength:   len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
+		ActualTTFT:         0,
+		ActualTPOT:         latencyMs,
+		Timestamp:          now,
+		NumRequestWaiting:  m.WaitingQueueSize,
+		NumRequestRunning:  m.RunningQueueSize,
+		NumTokensGenerated: sloCtx.GeneratedTokenCount - 1,
+		PrefixCacheScore:   0, // TPOT does not use prefix cache score
+	}
+	if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
+		logger.V(logutil.DEBUG).Error(err, "record TPOT training failed")
+	}
+
+	// Sampled predict
+	if sloCtx.TokenSampler.ShouldPredict(sloCtx.GeneratedTokenCount) {
+		in := latencypredictor.PredictionRequest{
+			KVCachePercentage:  m.KVCacheUsagePercent,
+			InputTokenLength:   len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
+			NumRequestWaiting:  m.WaitingQueueSize,
+			NumRequestRunning:  m.RunningQueueSize,
+			NumTokensGenerated: sloCtx.GeneratedTokenCount,
+			PrefixCacheScore:   0, // TPOT does not use prefix cache score
+		}
+		start := time.Now()
+		p, err := predictor.Predict(ctx, in)
+		dur := time.Since(start)
+		if err != nil || p == nil {
+			logger.V(logutil.DEBUG).Error(err, "TPOT predict failed", "duration_ms", dur.Milliseconds())
+			sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, 0)
+			sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, 0, len(sloCtx.PredictedTPOTObservations))
+		} else {
+			logger.V(logutil.DEBUG).Info("TPOT predict succeeded", "value_ms", p.TPOT, "duration_ms", dur.Milliseconds())
+			sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, p.TPOT)
+			sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, p.TPOT, len(sloCtx.PredictedTPOTObservations))
+		}
+		metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds())
+
+		sloCtx.TokenSampler.RecordPrediction(sloCtx.GeneratedTokenCount)
+	}
+
+	// Advance timestamp and refresh metrics
+	sloCtx.LastTokenTimestamp = now
+	RefreshLastSeenMetrics(ctx, sloCtx)
+}
+
+// PredictWithMetrics predicts TTFT or TPOT based on provided metrics state and token count.
+func PredictWithMetrics(
+	ctx context.Context,
+	predictor latencypredictor.PredictorInterface,
+	metricsState *backendmetrics.MetricsState,
+	prompt string,
+	generatedTokenCount int,
+	prefixcachescore float64,
+) (*latencypredictor.PredictionResponse, error) {
+	logger := log.FromContext(ctx)
+
+	if metricsState == nil {
+		return nil, fmt.Errorf("metrics state cannot be nil")
+	}
+
+	// Build prediction request
+	in := latencypredictor.PredictionRequest{
+		KVCachePercentage:  metricsState.KVCacheUsagePercent,
+		InputTokenLength:   len(strings.Fields(prompt)),
+		NumRequestWaiting:  metricsState.WaitingQueueSize,
+		NumRequestRunning:  metricsState.RunningQueueSize,
+		NumTokensGenerated: generatedTokenCount,
+		PrefixCacheScore:   prefixcachescore,
+	}
+
+	// Perform prediction
+	start := time.Now()
+	result, err := predictor.Predict(ctx, in)
+	duration := time.Since(start)
+
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(err, "prediction failed",
+			"duration_ms", duration.Milliseconds(),
+			"input_tokens", in.InputTokenLength,
+			"generated_tokens", generatedTokenCount,
+			"kv_cache_percent", in.KVCachePercentage,
+			"waiting_queue", in.NumRequestWaiting,
+			"running_queue", in.NumRequestRunning,
+			"prefix_cache_score", in.PrefixCacheScore)
+		return nil, err
+	}
+
+	if result == nil {
+		logger.V(logutil.DEBUG).Info("prediction returned nil",
+			"duration_ms", duration.Milliseconds())
+		return nil, fmt.Errorf("prediction returned nil result")
+	}
+
+	logger.V(logutil.DEBUG).Info("prediction succeeded",
+		"tpot_ms", result.TPOT,
+		"ttft_ms", result.TTFT,
+		"duration_ms", duration.Milliseconds(),
+		"input_tokens", in.InputTokenLength,
+		"generated_tokens", generatedTokenCount,
+		"kv_cache_percent", in.KVCachePercentage,
+		"waiting_queue", in.NumRequestWaiting,
+		"running_queue", in.NumRequestRunning,
+		"prefix_cache_score", in.PrefixCacheScore)
+
+	return result, nil
+}
+
+// BulkPredictWithMetrics performs bulk predictions for multiple pods using their metrics states.
+// Returns predictions in the same order as the input slices.
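+//
+// A minimal usage sketch (m1, m2, prompt, and predictor are caller-supplied
+// placeholders, not names defined in this package):
+//
+//	preds, err := BulkPredictWithMetrics(ctx, predictor,
+//		[]*backendmetrics.MetricsState{m1, m2},
+//		[]string{prompt, prompt},
+//		[]int{1, 1},
+//		[]float64{0.9, 0.1})
+//	if err == nil {
+//		_ = preds[0].TTFT // per-pod TTFT/TPOT estimates in milliseconds
+//	}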
+func BulkPredictWithMetrics( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + metricsStates []*backendmetrics.MetricsState, + prompts []string, + generatedTokenCounts []int, + prefixCacheScores []float64, +) ([]*latencypredictor.PredictionResponse, error) { + logger := log.FromContext(ctx) + + // Validate input lengths + if len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) { + return nil, fmt.Errorf("input slice lengths must match: metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d", + len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores)) + } + + if len(metricsStates) == 0 { + return []*latencypredictor.PredictionResponse{}, nil + } + + // Validate that no metrics state is nil + for i, metricsState := range metricsStates { + if metricsState == nil { + return nil, fmt.Errorf("metrics state at index %d cannot be nil", i) + } + } + + // Build bulk prediction requests + bulkRequests := make([]latencypredictor.PredictionRequest, len(metricsStates)) + for i := range metricsStates { + bulkRequests[i] = latencypredictor.PredictionRequest{ + KVCachePercentage: metricsStates[i].KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(prompts[i])), + NumRequestWaiting: metricsStates[i].WaitingQueueSize, + NumRequestRunning: metricsStates[i].RunningQueueSize, + NumTokensGenerated: generatedTokenCounts[i], + PrefixCacheScore: prefixCacheScores[i], + } + } + + // Perform bulk prediction + start := time.Now() + bulkResponse, err := predictor.PredictBulkStrict(ctx, bulkRequests) + duration := time.Since(start) + + if err != nil { + logger.V(logutil.DEBUG).Error(err, "bulk prediction failed", + "duration_ms", duration.Milliseconds(), + "request_count", len(bulkRequests)) + return nil, err + } + + if bulkResponse == nil { + logger.V(logutil.DEBUG).Info("bulk prediction returned nil", + "duration_ms", duration.Milliseconds()) + return nil, fmt.Errorf("bulk prediction returned nil result") + } + + // Convert to pointer slice for consistency with single prediction + results := make([]*latencypredictor.PredictionResponse, len(bulkResponse.Predictions)) + for i := range bulkResponse.Predictions { + results[i] = &bulkResponse.Predictions[i] + } + + logger.V(logutil.DEBUG).Info("bulk prediction succeeded", + "duration_ms", duration.Milliseconds(), + "request_count", len(bulkRequests), + "successful_predictions", bulkResponse.SuccessfulPredictions, + "failed_predictions", bulkResponse.FailedPredictions, + "processing_time_ms", bulkResponse.ProcessingTimeMs) + + // Log detailed results if at trace level + if logger.V(logutil.TRACE).Enabled() { + for i, result := range results { + logger.V(logutil.TRACE).Info("bulk prediction result", + "index", i, + "ttft_ms", result.TTFT, + "tpot_ms", result.TPOT, + "input_tokens", bulkRequests[i].InputTokenLength, + "generated_tokens", bulkRequests[i].NumTokensGenerated, + "kv_cache_percent", bulkRequests[i].KVCachePercentage, + "waiting_queue", bulkRequests[i].NumRequestWaiting, + "running_queue", bulkRequests[i].NumRequestRunning, + "prefix_cache_score", bulkRequests[i].PrefixCacheScore) + } + } + + return results, nil +} + +// calculateRunningAverage calculates the running average efficiently +func calculateRunningAverage(currentAvg float64, newValue float64, count int) float64 { + if count == 0 { + return 0 + } + if count == 1 { + return newValue + } + return currentAvg + (newValue-currentAvg)/float64(count) +} diff --git 
a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go
new file mode 100644
index 000000000..dcd9c40e0
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go
@@ -0,0 +1,122 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains the SLO-aware routing plugin and the
+// helpers that decouple its latency-predictor logic.
+package slo_aware_router
+
+import (
+	"context"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
+)
+
+// generatePredictions creates prediction results for all candidate pods.
+func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *SLORequestContext, candidatePods []schedulingtypes.Pod) []PodPredictionResult {
+	logger := log.FromContext(ctx)
+	predictions := make([]PodPredictionResult, 0, len(candidatePods))
+
+	for _, pod := range candidatePods {
+		predResult := PodPredictionResult{Pod: pod}
+
+		logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String())
+
+		// Get prefix cache score for the pod
+		prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
+
+		sloCtx.PrefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore
+
+		logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore)
+
+		// Generate prediction
+		prediction, err := PredictWithMetrics(ctx, s.latencypredictor, pod.GetMetrics(), request.Body.Completions.Prompt, 1, prefixCacheScore)
+		if err != nil {
+			logger.V(logutil.DEBUG).Info("Skipping pod due to prediction error", "pod", pod.GetPod().String(), "error", err)
+			predResult.Error = err
+			predictions = append(predictions, predResult)
+			continue
+		}
+		predResult.PrefixCacheScore = prefixCacheScore
+		predResult.TTFT = prediction.TTFT
+		predResult.TPOT = prediction.TPOT
+
+		// The lowest TPOT SLO among requests already running on this pod.
+		podMinTPOTSLO := s.getPodMinTPOTSLO(pod)
+		predResult.TTFTValid, predResult.TPOTValid, predResult.IsValid, predResult.Headroom, predResult.TTFTHeadroom = s.validatePrediction(prediction, sloCtx, podMinTPOTSLO)
+
+		logger.V(logutil.DEBUG).Info("Prediction for scheduling",
+			"pod", pod.GetPod().String(),
+			"prefixCacheScore", prefixCacheScore,
+			"TTFT", prediction.TTFT,
+			"TPOT", prediction.TPOT,
+			"buffer", SLOBufferFactor,
+			"podMinTPOTSLO", podMinTPOTSLO,
+			"ttftSLO", sloCtx.TTFTSLO,
+			"requestTPOTSLO", sloCtx.AvgTPOTSLO,
+			"tpotHeadroom", predResult.Headroom,
+			"ttftHeadroom", predResult.TTFTHeadroom,
+			"tpotValid", predResult.TPOTValid,
+			"ttftValid", predResult.TTFTValid,
+			"headroomStrategy", s.headroomStrategy)
+
+		predictions = append(predictions, predResult)
+	}
+
+	return predictions
+}
+
+// updateRequestContextWithPredictions updates the request context with prediction data.
+func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *SLORequestContext, predictions []PodPredictionResult) {
+	for _, pred := range predictions {
+		if pred.Error == nil {
+			podKey := pred.Pod.GetPod().String()
+			if sloCtx.PredictedTTFTForScheduling == nil {
+				sloCtx.PredictedTTFTForScheduling = make(map[string]float64)
+			}
+			if sloCtx.PredictedTPOTForScheduling == nil {
+				sloCtx.PredictedTPOTForScheduling = make(map[string]float64)
+			}
+			sloCtx.PredictedTTFTForScheduling[podKey] = pred.TTFT
+			sloCtx.PredictedTPOTForScheduling[podKey] = pred.TPOT
+		}
+	}
+}
+
+func (s *SLOAwareRouter) validatePrediction(
+	pred *latencypredictor.PredictionResponse,
+	sloCtx *SLORequestContext,
+	podMinTPOTSLO float64,
+) (ttftOk, tpotOk, isValid bool, headroom float64, ttftHeadroom float64) {
+	bufferedTPOT := sloCtx.AvgTPOTSLO * SLOBufferFactor
+	// A podMinTPOTSLO of 0 means either there are no running requests or none
+	// of them specified a TPOT SLO.
+	if podMinTPOTSLO > 0 {
+		if podMinTPOTSLO < sloCtx.AvgTPOTSLO {
+			log.FromContext(context.Background()).V(logutil.DEBUG).Info("Pod min TPOT SLO is less than the request SLO, adjusting", "podMinTPOTSLO", podMinTPOTSLO, "requestTPOTSLO", sloCtx.AvgTPOTSLO)
+		}
+		bufferedTPOT = min(bufferedTPOT, podMinTPOTSLO*SLOBufferFactor)
+	}
+
+	tpotOk = pred.TPOT < bufferedTPOT
+	ttftOk = pred.TTFT < sloCtx.TTFTSLO
+
+	isValid = ttftOk && tpotOk
+	headroom = bufferedTPOT - pred.TPOT
+	ttftHeadroom = sloCtx.TTFTSLO - pred.TTFT
+	return
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go
new file mode 100644
index 000000000..17399c6a8
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go
@@ -0,0 +1,246 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package slo_aware_router
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/go-logr/logr"
+	"github.com/google/uuid"
+	"sigs.k8s.io/controller-runtime/pkg/log"
+
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
+	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+	requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
+)
+
+var _ requestcontrol.PreRequest = &SLOAwareRouter{}
+var _ requestcontrol.ResponseReceived = &SLOAwareRouter{}
+var _ requestcontrol.ResponseStreaming = &SLOAwareRouter{}
+var _ requestcontrol.ResponseComplete = &SLOAwareRouter{}
+
+type SLORequestContext struct {
+	SchedulingRequest         schedulingtypes.LLMRequest
+	TargetPod                 *backend.Pod
+	SchedulingResult          *schedulingtypes.SchedulingResult
+	LastSeenMetrics           map[string]*backendmetrics.MetricsState
+	LastTokenTimestamp        time.Time
+	RequestReceivedTimestamp  time.Time
+	GeneratedTokenCount       int
+	IncomingModelName         string
+	TTFT                      float64
+	PredictedTTFT             float64
+	AvgTPOT                   float64
+	AvgPredictedTPOT          float64
+	TokenSampler              *TokenSampler
+	TPOTObservations          []float64
+	PredictedTPOTObservations []float64
+
+	PrefixCacheScoresForPods map[string]float64
+
+	// TTFTSLO is the target time to first token SLO for the request.
+	TTFTSLO float64
+	// AvgTPOTSLO is the target average time per output token SLO for the request.
+	AvgTPOTSLO float64
+
+	// PredictorBasedScheduling indicates whether to use predictor based scheduling.
+	PredictorBasedScheduling bool
+	// PredictedTTFTForScheduling maps pod names to predicted TTFT values for scheduling.
+	PredictedTTFTForScheduling map[string]float64
+	// PredictedTPOTForScheduling maps pod names to predicted TPOT values for scheduling.
+	PredictedTPOTForScheduling map[string]float64
+
+	// HasValidPod reports whether any candidate pod satisfied the predicted SLOs.
+	HasValidPod bool
+}
+
+func NewSLORequestContext(request *schedulingtypes.LLMRequest) *SLORequestContext {
+	return &SLORequestContext{
+		SchedulingRequest:          *request,
+		LastSeenMetrics:            make(map[string]*backendmetrics.MetricsState),
+		PrefixCacheScoresForPods:   make(map[string]float64),
+		PredictedTTFTForScheduling: make(map[string]float64),
+		PredictedTPOTForScheduling: make(map[string]float64),
+	}
+}
+
+func (s *SLOAwareRouter) getSLOContextForRequest(request *schedulingtypes.LLMRequest) (*SLORequestContext, error) {
+	id := request.Headers[requtil.RequestIdHeaderKey]
+	if ctx, exists := s.sloContextStore.Load(id); exists {
+		return ctx.(*SLORequestContext), nil
+	}
+	return nil, fmt.Errorf("SLO context not found for request ID: %s", id)
+}
+
+func (s *SLOAwareRouter) setSLOContextForRequest(request *schedulingtypes.LLMRequest, ctx *SLORequestContext) {
+	id := request.Headers[requtil.RequestIdHeaderKey]
+	s.sloContextStore.Store(id, ctx)
+}
+
+func (s *SLOAwareRouter) deleteSLOContextForRequest(request *schedulingtypes.LLMRequest) {
+	id := request.Headers[requtil.RequestIdHeaderKey]
+	s.sloContextStore.Delete(id)
+}
+
+// --- RequestControl Hooks ---
+
+func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult) {
+	logger := log.FromContext(ctx)
+
+	if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 {
+		logger.V(logutil.DEBUG).Info("SLOAwareRouter: Skipping PreRequest because no scheduling result was provided.")
+		return
+	}
+
+	primaryResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
+	if primaryResult == nil || len(primaryResult.TargetPods) == 0 {
+		logger.V(logutil.DEBUG).Info("SLOAwareRouter: Skipping PreRequest because the primary profile has no target pods.")
+		return
+	}
+	targetPod := primaryResult.TargetPods[0].GetPod()
+
+	podName := types.NamespacedName{
+		Name:      targetPod.NamespacedName.Name,
+		Namespace: targetPod.NamespacedName.Namespace,
+	}
+
+	logger.V(logutil.DEBUG).Info("request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey], "podName", podName)
+	if request.Headers[requtil.RequestIdHeaderKey] == "" {
+		request.Headers[requtil.RequestIdHeaderKey] = uuid.New().String()
+		logger.V(logutil.DEBUG).Info("Generated new request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey])
+		logger.V(logutil.DEBUG).Info("request headers for SLO tracking", "requestHeaders", request.Headers)
+	}
+
+	id := request.Headers[requtil.RequestIdHeaderKey]
+	podRequestList, ok := t.runningRequestLists[podName]
+	if !ok {
+		podRequestList = NewRequestPriorityQueue()
+		t.runningRequestLists[podName] = podRequestList
+	}
+
+	sloCtx, err := t.getSLOContextForRequest(request)
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.PreRequest: Failed to get SLO context for request", "requestID", id)
+		return
+	}
+
+	added := podRequestList.Add(id, sloCtx.AvgTPOTSLO)
+	if !added {
+		logger.V(logutil.DEBUG).Info("SLOAwareRouter: Item already exists in queue", "podName", podName, "requestID", id)
+	}
+ + // Set up SLO request context + sloCtx.TargetPod = targetPod + sloCtx.SchedulingResult = schedulingResult + sloCtx.RequestReceivedTimestamp = time.Now() + RefreshLastSeenMetrics(ctx, sloCtx) + t.setSLOContextForRequest(request, sloCtx) +} + +func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { + logger := log.FromContext(ctx) + id := request.Headers[requtil.RequestIdHeaderKey] + + sloCtx, err := t.getSLOContextForRequest(request) + if err != nil { + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Failed to get SLO context for request", "requestID", id) + return + } + + if !t.CheckPredictor(logger, targetPod) { + return + } + + if err := ProcessHeaderForLatencyPrediction(ctx, t.latencypredictor, sloCtx); err != nil { + logger.V(logutil.DEBUG).Error(err, "ProcessHeader in latencypredictor failed") + } + +} + +func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) { + logger := log.FromContext(ctx) + if !t.CheckPredictor(logger, pod) { + return + } + + now := time.Now() + sloCtx, err := t.getSLOContextForRequest(request) + if err != nil { + id := request.Headers[requtil.RequestIdHeaderKey] + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.ResponseStreaming: Failed to get SLO context for request", "requestID", id) + return + } + + if sloCtx.TTFT == 0 { + ProcessFirstTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now) + } else { + ProcessTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now) + } + +} + +func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) { + logger := log.FromContext(ctx) + targetPod := pod + if !t.CheckPredictor(logger, targetPod) { + return + } + + sloCtx, err := t.getSLOContextForRequest(request) + if err != nil { + id := request.Headers[requtil.RequestIdHeaderKey] + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.ResponseComplete: Failed to get SLO context for request", "requestID", id) + return + } + + if sloCtx.TTFT > 0 { + logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTTFT", sloCtx.TTFT, "avgPredictedTTFT", sloCtx.PredictedTTFT) + metrics.RecordRequestTTFT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.TTFT/1000) + metrics.RecordRequestPredictedTTFT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.PredictedTTFT/1000) + if sloCtx.TTFTSLO > 0 { + metrics.RecordRequestTTFTWithSLO(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.TTFT, sloCtx.TTFTSLO) + } + } + + if sloCtx.AvgTPOT > 0 { + logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTPOT", sloCtx.AvgTPOT, "avgPredictedTPOT", sloCtx.AvgPredictedTPOT) + metrics.RecordRequestTPOT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgTPOT/1000) + metrics.RecordRequestPredictedTPOT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgPredictedTPOT/1000) + if sloCtx.AvgTPOTSLO > 0 { + metrics.RecordRequestTPOTWithSLO(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgTPOT, sloCtx.AvgTPOTSLO) + } + } + + logger.V(logutil.DEBUG).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", sloCtx.PredictorBasedScheduling) + + podName := types.NamespacedName{ + Name: targetPod.NamespacedName.Name, + Namespace: targetPod.NamespacedName.Namespace, + } + + id := request.Headers[requtil.RequestIdHeaderKey] + 
podRequestList, ok := t.runningRequestLists[podName]
+	if !ok {
+		err := fmt.Errorf("no running request list found for pod %s", podName.String())
+		logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Failed to remove request from queue", "requestID", id)
+	} else if _, removed := podRequestList.Remove(id); !removed {
+		logger.V(logutil.DEBUG).Info("SLOAwareRouter: Item not found in queue", "podName", podName, "requestID", id)
+	}
+	t.deleteSLOContextForRequest(request)
+}
+
+func (t *SLOAwareRouter) CheckPredictor(logger logr.Logger, targetPod *backend.Pod) bool {
+	if targetPod == nil {
+		logger.V(logutil.DEBUG).Info("SLOAwareRouter: Skipping response hook because no target pod was provided.")
+		return false
+	}
+	if t.latencypredictor == nil {
+		logger.V(logutil.DEBUG).Info("SLOAwareRouter: Skipping response hook because no predictor is configured.")
+		return false
+	}
+	return true
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go
new file mode 100644
index 000000000..1199be641
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go
@@ -0,0 +1,227 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package slo_aware_router
+
+import (
+	"container/heap"
+	"fmt"
+	"sort"
+	"strings"
+	"sync"
+)
+
+// Request represents an element in the priority queue.
+// The index is needed by heap.Remove and is maintained by the heap.Interface methods.
+type Request struct {
+	ID    string  // Unique identifier
+	TPOT  float64 // The priority value (lower is higher priority)
+	index int
+}
+
+// RequestPriorityQueue implements a priority queue with item removal by ID.
+type RequestPriorityQueue struct {
+	items  []*Request
+	lookup map[string]*Request
+	mutex  sync.RWMutex
+}
+
+// NewRequestPriorityQueue initializes and returns a new RequestPriorityQueue.
+func NewRequestPriorityQueue() *RequestPriorityQueue {
+	return &RequestPriorityQueue{
+		lookup: make(map[string]*Request),
+		items:  []*Request{},
+	}
+}
+
+// Clone creates a deep copy of the priority queue.
+// The new queue is completely independent of the original.
+func (pq *RequestPriorityQueue) Clone() *RequestPriorityQueue {
+	pq.mutex.RLock()
+	defer pq.mutex.RUnlock()
+
+	// Initialize a new priority queue with pre-allocated capacity.
+	clonedPq := &RequestPriorityQueue{
+		items:  make([]*Request, len(pq.items)),
+		lookup: make(map[string]*Request, len(pq.lookup)),
+	}
+
+	// Iterate through the original items to create deep copies.
+	for i, oldItem := range pq.items {
+		// Create a new Request struct, copying all values.
+		newItem := &Request{
+			ID:    oldItem.ID,
+			TPOT:  oldItem.TPOT,
+			index: oldItem.index,
+		}
+
+		// Assign the new item to the cloned queue's items slice.
+		clonedPq.items[i] = newItem
+		// Update the lookup map in the cloned queue to point to the new item.
+		clonedPq.lookup[newItem.ID] = newItem
+	}
+
+	return clonedPq
+}
+
+// Len is the number of items in the queue.
+func (pq *RequestPriorityQueue) Len() int { return len(pq.items) }
+
+// Less reports whether the item with index i should sort before the item with index j.
+func (pq *RequestPriorityQueue) Less(i, j int) bool {
+	return pq.items[i].TPOT < pq.items[j].TPOT
+}
+
+// Swap swaps the items with indexes i and j.
+func (pq *RequestPriorityQueue) Swap(i, j int) {
+	pq.items[i], pq.items[j] = pq.items[j], pq.items[i]
+	pq.items[i].index = i
+	pq.items[j].index = j
+}
+
+// Push adds an item to the heap.
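+// Push is part of heap.Interface and is invoked by the container/heap package;
+// callers should use Add instead, which also takes the lock and maintains the
+// ID lookup map.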
+func (pq *RequestPriorityQueue) Push(x any) { + item := x.(*Request) + item.index = len(pq.items) + pq.items = append(pq.items, item) +} + +// Pop removes and returns the minimum item from the heap. +func (pq *RequestPriorityQueue) Pop() any { + n := len(pq.items) + item := pq.items[n-1] + pq.items[n-1] = nil // avoid memory leak + item.index = -1 // for safety + pq.items = pq.items[0 : n-1] + return item +} + +// Add adds a new item to the queue. +// Returns true if the item was added, false if an item with the same ID already exists. +func (pq *RequestPriorityQueue) Add(id string, tpot float64) bool { + pq.mutex.Lock() + defer pq.mutex.Unlock() + + // Validate input + if id == "" { + return false + } + if tpot < 0 { + return false + } + + // If item already exists, do not add + if _, exists := pq.lookup[id]; exists { + return false + } + + item := &Request{ + ID: id, + TPOT: tpot, + } + pq.lookup[id] = item + heap.Push(pq, item) + return true +} + +// Update modifies the TPOT value of an existing item in the queue. +// If the item doesn't exist, this method does nothing. +func (pq *RequestPriorityQueue) Update(id string, tpot float64) bool { + pq.mutex.Lock() + defer pq.mutex.Unlock() + + // Validate input + if tpot < 0 { + return false + } + + item, exists := pq.lookup[id] + if !exists { + return false + } + + item.TPOT = tpot + heap.Fix(pq, item.index) + return true +} + +// Remove removes an item from the queue by its ID. +func (pq *RequestPriorityQueue) Remove(id string) (*Request, bool) { + pq.mutex.Lock() + defer pq.mutex.Unlock() + + item, ok := pq.lookup[id] + if !ok { + return nil, false + } + removed := heap.Remove(pq, item.index).(*Request) + delete(pq.lookup, id) + return removed, true +} + +// Peek returns the item with the lowest value without removing it. +func (pq *RequestPriorityQueue) Peek() *Request { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + + if len(pq.items) == 0 { + return nil + } + return pq.items[0] +} + +// GetSize returns the current number of items in the queue. +func (pq *RequestPriorityQueue) GetSize() int { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + return len(pq.items) +} + +// Contains checks if an item with the given ID exists in the queue. +func (pq *RequestPriorityQueue) Contains(id string) bool { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + _, exists := pq.lookup[id] + return exists +} + +// ToSlice returns a copy of all items in the queue, sorted by ID for stable comparison. +// This is primarily intended for testing and validation. +func (pq *RequestPriorityQueue) ToSlice() []*Request { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + + // Create a copy to avoid returning a reference to the internal slice. + itemsCopy := make([]*Request, len(pq.items)) + copy(itemsCopy, pq.items) + + // Sort by ID to have a deterministic order for comparison in tests. + sort.Slice(itemsCopy, func(i, j int) bool { + return itemsCopy[i].ID < itemsCopy[j].ID + }) + + return itemsCopy +} + +// String returns a string representation of the queue for debugging. 
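+// Items are printed in internal heap order, e.g.
+// RequestPriorityQueue: [req1(1.00), req3(4.50), req2(9.00)].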
+func (pq *RequestPriorityQueue) String() string { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + + if len(pq.items) == 0 { + return "RequestPriorityQueue: []" + } + + var builder strings.Builder + builder.WriteString("RequestPriorityQueue: [") + + for i, item := range pq.items { + if i > 0 { + builder.WriteString(", ") + } + builder.WriteString(item.ID) + builder.WriteString("(") + builder.WriteString(fmt.Sprintf("%.2f", item.TPOT)) + builder.WriteString(")") + } + + builder.WriteString("]") + return builder.String() +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go new file mode 100644 index 000000000..a8eba5fe1 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go @@ -0,0 +1,391 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "fmt" + "sync" + "testing" + "time" +) + +func TestNewRequestPriorityQueue(t *testing.T) { + pq := NewRequestPriorityQueue() + + if pq == nil { + t.Fatal("NewRequestPriorityQueue returned nil") + } + + if pq.GetSize() != 0 { + t.Errorf("Expected empty queue, got size %d", pq.GetSize()) + } + + if pq.Peek() != nil { + t.Error("Expected nil from Peek on empty queue") + } +} + +func TestAdd(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test successful add + if !pq.Add("req1", 2.5) { + t.Error("Expected Add to return true for new item") + } + + if pq.GetSize() != 1 { + t.Errorf("Expected size 1, got %d", pq.GetSize()) + } + + // Test duplicate add + if pq.Add("req1", 3.0) { + t.Error("Expected Add to return false for duplicate ID") + } + + if pq.GetSize() != 1 { + t.Errorf("Expected size 1 after duplicate add, got %d", pq.GetSize()) + } + + // Test validation + if pq.Add("", 1.0) { + t.Error("Expected Add to return false for empty ID") + } + + if pq.Add("req2", -1.0) { + t.Error("Expected Add to return false for negative TPOT") + } +} + +func TestPriorityOrdering(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Add items with different priorities + pq.Add("high", 1.0) // highest priority (lowest TPOT) + pq.Add("medium", 5.0) // medium priority + pq.Add("low", 10.0) // lowest priority (highest TPOT) + + // Check that highest priority item is at the top + peek := pq.Peek() + if peek == nil || peek.ID != "high" || peek.TPOT != 1.0 { + t.Errorf("Expected high priority item at top, got %+v", peek) + } + + // Test removal order + expected := []struct { + id string + tpot float64 + }{ + {"high", 1.0}, + {"medium", 5.0}, + {"low", 10.0}, + } + + for _, exp := range expected { + item := pq.Peek() + if item.ID != exp.id || item.TPOT != exp.tpot { + t.Errorf("Expected %s(%.1f), got %s(%.1f)", exp.id, exp.tpot, item.ID, item.TPOT) + } + + removed, ok := pq.Remove(item.ID) + if !ok || removed.ID != exp.id { + t.Errorf("Failed to remove %s", exp.id) + } + } +} + +func TestRemove(t *testing.T) { 
+ pq := NewRequestPriorityQueue() + + // Test remove from empty queue + if _, ok := pq.Remove("nonexistent"); ok { + t.Error("Expected Remove to return false for empty queue") + } + + // Add some items + pq.Add("req1", 1.0) + pq.Add("req2", 2.0) + pq.Add("req3", 3.0) + + // Test successful remove + removed, ok := pq.Remove("req2") + if !ok || removed.ID != "req2" || removed.TPOT != 2.0 { + t.Errorf("Expected to remove req2(2.0), got %+v, ok=%v", removed, ok) + } + + if pq.GetSize() != 2 { + t.Errorf("Expected size 2 after removal, got %d", pq.GetSize()) + } + + // Test remove nonexistent + if _, ok := pq.Remove("req2"); ok { + t.Error("Expected Remove to return false for already removed item") + } + + // Verify remaining items are still in correct order + if peek := pq.Peek(); peek.ID != "req1" { + t.Errorf("Expected req1 at top, got %s", peek.ID) + } +} + +func TestUpdate(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test update nonexistent item + if pq.Update("nonexistent", 1.0) { + t.Error("Expected Update to return false for nonexistent item") + } + + // Add items + pq.Add("req1", 1.0) + pq.Add("req2", 2.0) + pq.Add("req3", 3.0) + + // Update to make req3 highest priority + if !pq.Update("req3", 0.5) { + t.Error("Expected Update to return true for existing item") + } + + // Check that req3 is now at the top + if peek := pq.Peek(); peek.ID != "req3" || peek.TPOT != 0.5 { + t.Errorf("Expected req3(0.5) at top, got %s(%.1f)", peek.ID, peek.TPOT) + } + + // Test validation + if pq.Update("req1", -1.0) { + t.Error("Expected Update to return false for negative TPOT") + } +} + +func TestContains(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test empty queue + if pq.Contains("req1") { + t.Error("Expected Contains to return false for empty queue") + } + + // Add item + pq.Add("req1", 1.0) + + // Test existing item + if !pq.Contains("req1") { + t.Error("Expected Contains to return true for existing item") + } + + // Test nonexistent item + if pq.Contains("req2") { + t.Error("Expected Contains to return false for nonexistent item") + } + + // Test after removal + pq.Remove("req1") + if pq.Contains("req1") { + t.Error("Expected Contains to return false after removal") + } +} + +func TestClone(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test clone of empty queue + clone := pq.Clone() + if clone.GetSize() != 0 { + t.Error("Expected cloned empty queue to be empty") + } + + // Add items to original + pq.Add("req1", 1.0) + pq.Add("req2", 2.0) + pq.Add("req3", 3.0) + + // Clone with items + clone = pq.Clone() + + // Verify clone has same items + if clone.GetSize() != pq.GetSize() { + t.Errorf("Expected clone size %d, got %d", pq.GetSize(), clone.GetSize()) + } + + // Verify independence - modify original + pq.Add("req4", 4.0) + if clone.GetSize() == pq.GetSize() { + t.Error("Clone should be independent of original") + } + + // Verify independence - modify clone + clone.Remove("req1") + if !pq.Contains("req1") { + t.Error("Original should not be affected by clone modifications") + } + + // Verify deep copy - items should be different instances + origPeek := pq.Peek() + clonePeek := clone.Peek() + if origPeek == clonePeek { + t.Error("Clone should create new Request instances, not share pointers") + } +} + +func TestString(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test empty queue + str := pq.String() + expected := "RequestPriorityQueue: []" + if str != expected { + t.Errorf("Expected %q, got %q", expected, str) + } + + // Test with items + pq.Add("req1", 
1.5) + pq.Add("req2", 2.25) + + str = pq.String() + // Should contain both items in priority order + if !contains(str, "req1(1.50)") || !contains(str, "req2(2.25)") { + t.Errorf("String output missing expected items: %s", str) + } +} + +func TestConcurrency(t *testing.T) { + pq := NewRequestPriorityQueue() + const numWorkers = 10 + const itemsPerWorker = 100 + + var wg sync.WaitGroup + + // Launch workers that add items + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func(workerID int) { + defer wg.Done() + for j := 0; j < itemsPerWorker; j++ { + id := fmt.Sprintf("worker%d-item%d", workerID, j) + tpot := float64(j) + float64(workerID)*0.1 + pq.Add(id, tpot) + } + }(i) + } + + // Launch workers that read from the queue + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < itemsPerWorker/2; j++ { + pq.Peek() + pq.GetSize() + time.Sleep(time.Microsecond) + } + }() + } + + wg.Wait() + + // Verify final state + expectedSize := numWorkers * itemsPerWorker + if pq.GetSize() != expectedSize { + t.Errorf("Expected final size %d, got %d", expectedSize, pq.GetSize()) + } +} + +func TestLargeQueue(t *testing.T) { + pq := NewRequestPriorityQueue() + const numItems = 10000 + + // Add many items + for i := 0; i < numItems; i++ { + id := fmt.Sprintf("item%d", i) + tpot := float64(numItems - i) // Reverse order so item0 has highest priority + pq.Add(id, tpot) + } + + if pq.GetSize() != numItems { + t.Errorf("Expected size %d, got %d", numItems, pq.GetSize()) + } + + // Verify priority ordering by removing items + lastTPOT := -1.0 + for i := 0; i < numItems; i++ { + item := pq.Peek() + if item.TPOT < lastTPOT { + t.Errorf("Priority order violated: %.1f < %.1f", item.TPOT, lastTPOT) + } + lastTPOT = item.TPOT + pq.Remove(item.ID) + } + + if pq.GetSize() != 0 { + t.Errorf("Expected empty queue after removing all items, got size %d", pq.GetSize()) + } +} + +func BenchmarkAdd(b *testing.B) { + pq := NewRequestPriorityQueue() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := fmt.Sprintf("item%d", i) + pq.Add(id, float64(i)) + } +} + +func BenchmarkPeek(b *testing.B) { + pq := NewRequestPriorityQueue() + + // Pre-populate queue + for i := 0; i < 1000; i++ { + pq.Add(fmt.Sprintf("item%d", i), float64(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pq.Peek() + } +} + +func BenchmarkRemove(b *testing.B) { + pq := NewRequestPriorityQueue() + + // Pre-populate queue + for i := 0; i < b.N; i++ { + pq.Add(fmt.Sprintf("item%d", i), float64(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pq.Remove(fmt.Sprintf("item%d", i)) + } +} + +// Helper function to check if a string contains a substring +func contains(s, substr string) bool { + return len(s) >= len(substr) && + (s == substr || + s[:len(substr)] == substr || + s[len(s)-len(substr):] == substr || + containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go new file mode 100644 index 000000000..758ae401b --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go @@ -0,0 +1,122 @@ +// NewTokenSampler creates a new sampler with deterministic seeding + +package slo_aware_router + +import ( + "hash/fnv" + "math" + "math/rand" + "time" +) + +// TokenSampler handles 
Poisson-distributed sampling for predictions only.
+// Training happens on every token regardless of sampling.
+type TokenSampler struct {
+	rng             *rand.Rand
+	nextSampleToken int
+	samplingMean    float64
+	maxSamples      int
+	sampleCount     int
+}
+
+// SetSamplingMean sets the sampling mean (lambda) for the Poisson distribution
+func (ts *TokenSampler) SetSamplingMean(mean float64) {
+	ts.samplingMean = mean
+}
+
+// SetMaxSamples sets the maximum number of samples
+func (ts *TokenSampler) SetMaxSamples(max int) {
+	ts.maxSamples = max
+}
+
+// SetSampleCount sets the current number of predictions made
+func (ts *TokenSampler) SetSampleCount(count int) {
+	ts.sampleCount = count
+}
+
+// NewTokenSampler creates a new sampler with deterministic seeding.
+func NewTokenSampler(requestID string, samplingMean float64, maxSamples int) *TokenSampler {
+	// Use request ID hash as seed for reproducibility
+	seed := int64(0)
+	if requestID != "" {
+		hash := fnv.New64a()
+		hash.Write([]byte(requestID))
+		seed = int64(hash.Sum64())
+	}
+	if seed == 0 {
+		seed = time.Now().UnixNano()
+	}
+
+	sampler := &TokenSampler{
+		rng:          rand.New(rand.NewSource(seed)),
+		samplingMean: samplingMean,
+		maxSamples:   maxSamples,
+	}
+
+	// First sample at token 2 (token 1 is covered by the TTFT prediction).
+	sampler.nextSampleToken = 2
+
+	return sampler
+}
+
+// poissonNext generates the next sampling interval from a Poisson distribution.
+func (ts *TokenSampler) poissonNext() int {
+	lambda := ts.samplingMean
+	if lambda <= 0 {
+		return 1
+	}
+
+	// For small lambda, use Knuth's algorithm
+	if lambda < 30 {
+		l := math.Exp(-lambda)
+		k := 0
+		p := 1.0
+
+		for p > l {
+			k++
+			p *= ts.rng.Float64()
+		}
+		// The Poisson sample is k-1; clamp to at least 1 so sampling always advances.
+		if k-1 < 1 {
+			return 1
+		}
+		return k - 1
+	}
+
+	// For larger lambda, use normal approximation
+	normal := ts.rng.NormFloat64()
+	interval := int(math.Round(lambda + math.Sqrt(lambda)*normal))
+	if interval < 1 {
+		return 1
+	}
+	return interval
+}
+
+// ShouldPredict determines if we should make a prediction for the current token
+func (ts *TokenSampler) ShouldPredict(currentToken int) bool {
+	return currentToken == ts.nextSampleToken && ts.sampleCount < ts.maxSamples
+}
+
+// RecordPrediction records that a prediction was made and calculates the next sample token
+func (ts *TokenSampler) RecordPrediction(currentToken int) {
+	if ts.sampleCount >= ts.maxSamples {
+		return
+	}
+
+	ts.sampleCount++
+
+	if ts.sampleCount < ts.maxSamples {
+		interval := ts.poissonNext()
+		ts.nextSampleToken = currentToken + interval
+	}
+}
+
+// GetNextSampleToken returns the next token to predict for
+func (ts *TokenSampler) GetNextSampleToken() int {
+	return ts.nextSampleToken
+}
+
+// SetNextSampleToken sets the next token to predict for
+func (ts *TokenSampler) SetNextSampleToken(token int) {
+	ts.nextSampleToken = token
+}
+
+// GetSampleCount returns the current number of predictions made
+func (ts *TokenSampler) GetSampleCount() int {
+	return ts.sampleCount
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go
new file mode 100644
index 000000000..d48d3a1bc
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go
@@ -0,0 +1,290 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "fmt" + "math/rand" + "sync" + "time" + + "sigs.k8s.io/controller-runtime/pkg/log" + + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +type PodPredictionResult struct { + Pod schedulingtypes.Pod + TTFT float64 + TPOT float64 + TTFTValid bool + TPOTValid bool + IsValid bool + Error error + Headroom float64 // Headroom for the pod, if applicable + TTFTHeadroom float64 // TTFT headroom for the pod + PrefixCacheScore float64 // Prefix cache score for the pod +} + +type SLOAwareRouter struct { + tn plugins.TypedName + latencypredictor latencypredictor.PredictorInterface + runningRequestLists map[types.NamespacedName]*RequestPriorityQueue + sloContextStore sync.Map // map[string]*SLORequestContext + headroomStrategy HeadroomStrategy +} + +var _ framework.Scorer = &SLOAwareRouter{} + +func NewSLOAwareRouter(latencypredictor latencypredictor.PredictorInterface, strategy HeadroomStrategy) *SLOAwareRouter { + return &SLOAwareRouter{ + tn: plugins.TypedName{Type: SLOAwareRouterPluginType, Name: SLOAwareRouterPluginType}, + latencypredictor: latencypredictor, + runningRequestLists: make(map[types.NamespacedName]*RequestPriorityQueue), + sloContextStore: sync.Map{}, + headroomStrategy: strategy, + } +} + +func (s *SLOAwareRouter) TypedName() plugins.TypedName { + return s.tn +} + +func (s *SLOAwareRouter) WithName(name string) *SLOAwareRouter { + s.tn.Name = name + return s +} + +// SetHeadroomStrategy allows runtime configuration of headroom selection strategy +func (s *SLOAwareRouter) SetHeadroomStrategy(strategy HeadroomStrategy) { + s.headroomStrategy = strategy +} + +// GetHeadroomStrategy returns the current headroom selection strategy +func (s *SLOAwareRouter) GetHeadroomStrategy() HeadroomStrategy { + return s.headroomStrategy +} + +func (s *SLOAwareRouter) epsilonGreedyAffinityGate( + ctx context.Context, + candidates []PodPredictionResult, + r *rand.Rand, + label string, // e.g. "positive" or "negative" + prefixStickyThreshold float64, +) ([]PodPredictionResult, bool) { + logger := log.FromContext(ctx) + + eligible := make([]PodPredictionResult, 0, len(candidates)) + for _, p := range candidates { + if p.PrefixCacheScore >= prefixStickyThreshold { + eligible = append(eligible, p) + } + } + + // No eligible sticky pods? Explore (no gating). 
+	if len(eligible) == 0 {
+		return candidates, false
+	}
+
+	// ε-exploration branch
+	if r.Float64() < EpsilonExploreSticky {
+		logger.V(logutil.DEBUG).Info("ε-greedy: exploring (ignoring affinity gate)",
+			"path", label, "epsilon", EpsilonExploreSticky, "eligibleCount", len(eligible))
+		return candidates, false
+	}
+
+	logger.V(logutil.DEBUG).Info("ε-greedy: exploiting (apply affinity gate)",
+		"path", label, "threshold", prefixStickyThreshold, "eligibleCount", len(eligible), "total", len(candidates))
+	return eligible, true
+}
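To pin down the gate's contract, here is a small in-package sketch (the prefix-cache scores, the 0.8 threshold, and the `router`/`ctx` bindings are illustrative only, not part of the patch):

	cands := []PodPredictionResult{
		{PrefixCacheScore: 0.9}, // above the sticky threshold: eligible
		{PrefixCacheScore: 0.6},
		{PrefixCacheScore: 0.0},
	}
	r := rand.New(rand.NewSource(1))
	gated, sticky := router.epsilonGreedyAffinityGate(ctx, cands, r, "example", 0.8)
	// Exploit (probability 1-EpsilonExploreSticky): gated holds only the sticky pod, sticky == true.
	// Explore (probability EpsilonExploreSticky):   gated is all three candidates, sticky == false.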
+
+func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) map[schedulingtypes.Pod]float64 {
+	logger := log.FromContext(ctx)
+	if s.latencypredictor == nil {
+		logger.V(logutil.DEBUG).Info("SLOAwareRouter: no predictor configured, returning nil scores")
+		return nil
+	}
+
+	sloCtx := s.getOrMakeSLORequestContext(request)
+
+	var err error
+	// Parse the request SLOs from the request headers.
+	sloCtx.TTFTSLO, _, err = parseFloatHeader(*request, TTFTSLOHeaderKey)
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", TTFTSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TTFT SLO from header")
+	}
+
+	sloCtx.AvgTPOTSLO, _, err = parseFloatHeader(*request, TPOTSLOHeaderKey)
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", TPOTSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TPOT SLO from header")
+	}
+	sloCtx.PredictorBasedScheduling, err = parseBoolHeader(*request, "x-prediction-based-scheduling")
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-prediction-based-scheduling must be a bool: %v", err)}, "SLOAwareRouter: Error parsing PredictorBasedScheduling from header")
+	}
+
+	// Skip prediction-based filtering when the request did not opt in.
+	if !sloCtx.PredictorBasedScheduling {
+		logger.V(logutil.DEBUG).Info("PredictorBasedScheduling turned off, skipping prediction-based filtering")
+		return nil
+	}
+
+	predictions := s.generatePredictions(ctx, state, request, sloCtx, pods)
+	s.updateRequestContextWithPredictions(sloCtx, predictions)
+
+	allPreds := append([]PodPredictionResult(nil), predictions...)
+
+	// Initialize scores map with all pods having score 0
+	scores := make(map[schedulingtypes.Pod]float64, len(pods))
+	for _, pod := range pods {
+		scores[pod] = 0
+	}
+
+	source := rand.NewSource(time.Now().UnixNano())
+	r := rand.New(source)
+	allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal)
+
+	// Check if all pods are invalid and all have running requests
+	allPodsInvalid := true
+	allPodsHaveRunningRequests := true
+
+	for _, pred := range allPreds {
+		if pred.IsValid {
+			allPodsInvalid = false
+		}
+
+		runningRequestCount := s.getPodRunningRequestCount(pred.Pod)
+		if runningRequestCount == 0 {
+			allPodsHaveRunningRequests = false
+		}
+	}
+
+	// Set HasValidPod to false if all pods are invalid and all have running requests
+	if allPodsInvalid && allPodsHaveRunningRequests && !sticky {
+		sloCtx.HasValidPod = false
+		logger.V(logutil.DEBUG).Info("All pods are invalid and have running requests, setting HasValidPod to false")
+	}
+
+	// Tiered selection: partition pods by headroom. Positive-headroom pods are
+	// strongly preferred; negative-headroom pods are kept only for rare
+	// exploration (probability EpsilonExploreNeg).
+	var posHeadroomPods, negHeadroomPods []PodPredictionResult
+	for _, p := range allPreds {
+		// A pod has positive headroom only if BOTH TTFT and TPOT have positive headroom
+		if p.Headroom > 0 && p.TTFTHeadroom > 0 {
+			posHeadroomPods = append(posHeadroomPods, p)
+		} else {
+			// A pod has negative headroom if EITHER TTFT or TPOT has negative/zero headroom
+			negHeadroomPods = append(negHeadroomPods, p)
+		}
+	}
+
+	logger.V(logutil.DEBUG).Info("Pod headroom distribution",
+		"positivePods", len(posHeadroomPods),
+		"negativePods", len(negHeadroomPods))
+
+	var selectedPod schedulingtypes.Pod
+
+	if s.headroomStrategy == HeadroomStrategyCompositeOnly {
+		logger.V(logutil.DEBUG).Info("Selecting from composite scores only")
+		selectedPod = s.selectFromCompositeScores(ctx, allPreds, r, HeadroomStrategyCompositeOnly)
+	} else if len(posHeadroomPods) > 0 && len(negHeadroomPods) > 0 {
+		// With probability EpsilonExploreNeg explore a negative-headroom pod; otherwise pick a positive-headroom pod
+		if r.Float64() < EpsilonExploreNeg {
+			logger.V(logutil.DEBUG).Info("Selecting from negative headroom pods (exploration)")
+			selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r)
+		} else {
+			logger.V(logutil.DEBUG).Info("Selecting from positive headroom pods")
+			selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r)
+		}
+	} else if len(posHeadroomPods) > 0 {
+		// If only positive headroom pods exist, select from them
+		logger.V(logutil.DEBUG).Info("Only positive headroom pods available")
+		selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r)
+	} else if len(negHeadroomPods) > 0 {
+		// If only negative headroom pods exist, select from them
+		logger.V(logutil.DEBUG).Info("Only negative headroom pods available")
+		selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r)
+	} else if len(allPreds) > 0 {
+		// Fallback: select randomly from the remaining predictions
+		logger.V(logutil.DEBUG).Info("No headroom pods available, selecting randomly from remaining candidates")
+		selectedPod = allPreds[r.Intn(len(allPreds))].Pod
+	} else {
+		// No valid pods - return all zeros
+		logger.V(logutil.DEBUG).Info("No valid pods available, returning all zero scores")
+		return scores
+	}
+
+	// Set score = 1 for selected pod, 0 for all others
+	if selectedPod != nil {
+		scores[selectedPod] = 1
+		logger.V(logutil.DEBUG).Info("Selected pod for scheduling", "pod", selectedPod.GetPod().String())
+	}
+
+	s.setSLOContextForRequest(request, sloCtx)
+
+	return scores
+}
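Because the scorer collapses its decision into a one-hot score map, callers read it like a picker. A minimal sketch of the consumer side (the `router`, `ctx`, `state`, `req`, and `pods` bindings are assumed):

	scores := router.Score(ctx, state, req, pods)
	if scores == nil {
		// Predictor disabled or the request did not opt in; other scorers decide.
	}
	for pod, sc := range scores {
		if sc == 1 {
			// pod is the SLO-aware router's choice for this cycle.
			_ = pod
		}
	}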
+
+func (t *SLOAwareRouter) getOrMakeSLORequestContext(request *schedulingtypes.LLMRequest) *SLORequestContext {
+	sloCtx, err := t.getSLOContextForRequest(request)
+	if err != nil {
+		sloCtx = NewSLORequestContext(request)
+	}
+	return sloCtx
+}
+
+func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleState *schedulingtypes.CycleState, pod schedulingtypes.Pod) float64 {
+	logger := log.FromContext(ctx)
+	logger.V(logutil.DEBUG).Info("Running getPrefixCacheScoreForPod, getting prefix cache score for pod", "pod", pod.GetPod().String())
+	cycleStateKey := (plugins.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}).String()
+	stateData, err := cycleState.Read(plugins.StateKey(cycleStateKey))
+
+	logger.V(logutil.DEBUG).Info("Reading prefix cache state from cycle state", "stateKey", cycleStateKey)
+
+	if err != nil {
+		// The prefix cache plugin might not be enabled, which is a valid scenario.
+		logger.V(logutil.DEBUG).Info("Prefix cache state not found in cycle state, returning prefix cache score of 0.0", "pod", pod.GetPod().String())
+		return 0.0
+	}
+
+	prefixCacheState, ok := stateData.(*prefix.SchedulingContextState)
+	if !ok {
+		// This should not happen if the plugin is configured correctly.
+		logger.Error(fmt.Errorf("unexpected state type: %T", stateData), "failed to read prefix cache state")
+		return 0.0
+	}
+
+	total := len(prefixCacheState.PrefixHashes)
+	if total == 0 {
+		// If the request has no prefixes, there is nothing to match.
+		logger.V(logutil.DEBUG).Info("No prefixes found in request, returning prefix cache score of 0.0")
+		return 0.0
+	}
+
+	matchLen := prefixCacheState.PrefixCacheServers[prefix.ServerID(pod.GetPod().NamespacedName)]
+	logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "matchLen", matchLen, "totalPrefixes", total)
+	return float64(matchLen) / float64(total)
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go
new file mode 100644
index 000000000..34618ce19
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go
@@ -0,0 +1,381 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
+package slo_aware_router
+
+import (
+	"context"
+	"math"
+	"math/rand"
+
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
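To make the positive-headroom weighting concrete, here is the arithmetic the function below performs, with made-up numbers (alpha/beta are the normalized HeadroomTTFTWeight/HeadroomTPOTWeight; Wmax = 100 and minWeight = 1 come from types.go):

	// Two pods, headrooms normalized to [0,1] within the cohort:
	//   pod A: nTTFTH = 0.9, nTPOTH = 0.8
	//   pod B: nTTFTH = 0.2, nTPOTH = 0.1
	// With alpha = beta = 0.5:
	//   combined(A) = 0.5*0.9 + 0.5*0.8 = 0.85
	//   combined(B) = 0.5*0.2 + 0.5*0.1 = 0.15
	// HeadroomStrategyMost:  w = int(combined*(Wmax-minWeight)) + minWeight + 1
	//   w(A) = int(0.85*99) + 2 = 86,  w(B) = int(0.15*99) + 2 = 16
	// HeadroomStrategyLeast: w = int((1-combined)*(Wmax-minWeight)) + minWeight + 1
	//   w(A) = int(0.15*99) + 2 = 16,  w(B) = int(0.85*99) + 2 = 86

+
+// selectFromPositiveHeadroomPods selects a pod from positive headroom pods using the headroom strategy.
+// Updated to incorporate TTFTHeadroom with a configurable blend vs TPOT headroom.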
+func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, posHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod { + logger := log.FromContext(ctx) + + if len(posHeadroomPods) == 1 { + return posHeadroomPods[0].Pod + } + + // Apply perfect stickiness (with exploration) + candidates, sticky := s.epsilonGreedyAffinityGate(ctx, posHeadroomPods, r, "positive", AffinityGateTau) + + // If perfect stickiness collapsed us to a single pod, short-circuit + if sticky && len(candidates) == 1 { + return candidates[0].Pod + } + switch s.headroomStrategy { + case HeadroomStrategyCompositeMost: + return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeMost) + case HeadroomStrategyCompositeLeast: + return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeLeast) + } + + // Find min/max for TPOT (Headroom) and TTFTHeadroom across positive pods to normalize to [0,1] + minTPOTH, maxTPOTH := math.MaxFloat64, -math.MaxFloat64 + minTTFTH, maxTTFTH := math.MaxFloat64, -math.MaxFloat64 + + for _, p := range candidates { + if p.Headroom < minTPOTH { + minTPOTH = p.Headroom + } + if p.Headroom > maxTPOTH { + maxTPOTH = p.Headroom + } + if p.TTFTHeadroom < minTTFTH { + minTTFTH = p.TTFTHeadroom + } + if p.TTFTHeadroom > maxTTFTH { + maxTTFTH = p.TTFTHeadroom + } + } + + tpotRange := maxTPOTH - minTPOTH + ttftRange := maxTTFTH - minTTFTH + + // Precompute blend weights (renormalize if user sets both to 0) + alpha := HeadroomTTFTWeight + beta := HeadroomTPOTWeight + if alpha+beta <= 0 { + alpha = 1.0 + beta = 0.0 + } + sum := alpha + beta + alpha /= sum + beta /= sum + + logger.V(logutil.DEBUG).Info("Positive headroom normalization ranges", + "minTPOTHeadroom", minTPOTH, "maxTPOTHeadroom", maxTPOTH, + "minTTFTHeadroom", minTTFTH, "maxTTFTHeadroom", maxTTFTH, + "alphaTTFT", alpha, "betaTPOT", beta, "strategy", s.headroomStrategy) + + // Calculate weights for weighted random selection + weightedChoices := make([]Choice, 0, len(candidates)) + total := 0 + + for _, p := range candidates { + // Normalize to [0,1] within the cohort + nTPOTH := 0.5 + if tpotRange > eps { + nTPOTH = (p.Headroom - minTPOTH) / (tpotRange + eps) + } + nTTFTH := 0.5 + if ttftRange > eps { + nTTFTH = (p.TTFTHeadroom - minTTFTH) / (ttftRange + eps) + } + + // Blend: larger combined -> "safer"; smaller -> "tighter packing" + combined := alpha*nTTFTH + beta*nTPOTH + + // Map to integer weights + var w int + switch s.headroomStrategy { + case HeadroomStrategyLeast: + // prefer smaller combined headroom (pack closer to limits) + w = int((1.0-combined)*float64(Wmax-minWeight)) + minWeight + 1 + case HeadroomStrategyMost: + // prefer larger combined headroom (more conservative / spread) + w = int(combined*float64(Wmax-minWeight)) + minWeight + 1 + default: + // Fallback to least + w = int((1.0-combined)*float64(Wmax-minWeight)) + minWeight + 1 + } + + weightedChoices = append(weightedChoices, Choice{PodName: p.Pod, Weight: w}) + total += w + + logger.V(logutil.TRACE).Info("Positive headroom blended weight", + "pod", p.Pod.GetPod().String(), + "ttftHeadroom", p.TTFTHeadroom, "normTTFTHeadroom", nTTFTH, + "tpotHeadroom", p.Headroom, "normTPOTHeadroom", nTPOTH, + "combined", combined, "weight", w) + } + + return s.performWeightedRandomSelection(weightedChoices, total, candidates, r) + +} + +// selectFromNegativeHeadroomPods selects a pod from negative headroom pods using hierarchical TTFT/TPOT logic +// Modified to strictly prefer pods with 0 running requests +func (s 
*SLOAwareRouter) selectFromNegativeHeadroomPods(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	logger := log.FromContext(ctx)
+
+	if len(negHeadroomPods) == 1 {
+		return negHeadroomPods[0].Pod
+	}
+
+	// First, separate pods by running request count
+	var zeroRunningRequestPods, nonZeroRunningRequestPods []PodPredictionResult
+
+	for _, p := range negHeadroomPods {
+		runningRequestCount := s.getPodRunningRequestCount(p.Pod)
+		if runningRequestCount == 0 {
+			zeroRunningRequestPods = append(zeroRunningRequestPods, p)
+		} else {
+			nonZeroRunningRequestPods = append(nonZeroRunningRequestPods, p)
+		}
+	}
+
+	logger.V(logutil.DEBUG).Info("Negative headroom pods by running request count",
+		"zeroRunningRequests", len(zeroRunningRequestPods),
+		"nonZeroRunningRequests", len(nonZeroRunningRequestPods))
+
+	// If we have pods with 0 running requests, strictly prefer them
+	if len(zeroRunningRequestPods) > 0 {
+		logger.V(logutil.DEBUG).Info("Selecting from pods with zero running requests")
+		return s.selectFromNegativeHeadroomPodsInternal(ctx, zeroRunningRequestPods, r)
+	}
+
+	// Otherwise, fall back to pods with running requests
+	logger.V(logutil.DEBUG).Info("No pods with zero running requests, selecting from pods with running requests")
+	return s.selectFromNegativeHeadroomPodsInternal(ctx, nonZeroRunningRequestPods, r)
+}
+
+// selectFromNegativeHeadroomPodsInternal handles the actual selection logic for negative headroom pods
+func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	if len(negHeadroomPods) == 1 {
+		return negHeadroomPods[0].Pod
+	}
+
+	// Apply perfect stickiness (with exploration)
+	candidates, sticky := s.epsilonGreedyAffinityGate(ctx, negHeadroomPods, r, "negative", AffinityGateTau)
+
+	// If perfect stickiness collapsed us to a single pod, short-circuit
+	if sticky && len(candidates) == 1 {
+		return candidates[0].Pod
+	}
+
+	switch s.headroomStrategy {
+	case HeadroomStrategyCompositeMost:
+		return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeMost)
+	case HeadroomStrategyCompositeLeast:
+		return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeLeast)
+	}
+
+	// Build weighted choices for selection
+	weightedChoices := make([]Choice, 0, len(candidates))
+	total := 0
+
+	s.handleNegativeHeadroomPodsHierarchical(ctx, candidates, &weightedChoices, &total, minWeight)
+
+	// Perform weighted random selection
+	return s.performWeightedRandomSelection(weightedChoices, total, candidates, r)
+}
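The negative-headroom path mirrors the positive one but works on deficits (how far a pod is past its SLO). A worked example of the weighting it produces (bucket values illustrative; Wrange = 80 and minWeight = 1 from the code below):

	// Two pods in the "both_negative" bucket, deficits normalized within the bucket:
	//   pod A: nTTFT = 1.0, nTPOT = 1.0 (worst violations) -> blended = 1.0
	//   pod B: nTTFT = 0.0, nTPOT = 0.0 (mildest)          -> blended = 0.0
	// Weight: w = int((1-blended)*Wrange) + minWeight + 1
	//   w(A) = int(0*80) + 2 = 2    (floor weight, still selectable)
	//   w(B) = int(1*80) + 2 = 82   (mildest violation is ~40x more likely)

+
+// weightPodsByBlendedDeficit applies blended weighting using TTFT and TPOT deficits.
+// Lower blended deficit => higher weight.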
+func (ps *SLOAwareRouter) weightPodsByBlendedDeficit( + ctx context.Context, + pods []PodPredictionResult, + choices *[]Choice, + total *int, + minWeight int, + alpha, beta float64, // weights for TTFT and TPOT deficits + category string, +) { + logger := log.FromContext(ctx) + if len(pods) == 0 { + return + } + + const Wrange = 80 + const eps = 1e-9 + + // Compute raw deficits (only when headroom is negative) + type deficits struct { + pod PodPredictionResult + ttftDef float64 + tpotDef float64 + } + defs := make([]deficits, 0, len(pods)) + + minTTFT, maxTTFT := math.MaxFloat64, -math.MaxFloat64 + minTPOT, maxTPOT := math.MaxFloat64, -math.MaxFloat64 + + for _, p := range pods { + ttftDef := 0.0 + if p.TTFTHeadroom < 0 { + ttftDef = -p.TTFTHeadroom + } + tpotDef := 0.0 + if p.Headroom < 0 { + tpotDef = -p.Headroom + } + defs = append(defs, deficits{pod: p, ttftDef: ttftDef, tpotDef: tpotDef}) + + if ttftDef < minTTFT { + minTTFT = ttftDef + } + if ttftDef > maxTTFT { + maxTTFT = ttftDef + } + if tpotDef < minTPOT { + minTPOT = tpotDef + } + if tpotDef > maxTPOT { + maxTPOT = tpotDef + } + } + + ttftRange := maxTTFT - minTTFT + tpotRange := maxTPOT - minTPOT + + // Normalize alpha/beta + if alpha+beta <= 0 { + alpha, beta = 1.0, 0.0 + } else { + sum := alpha + beta + alpha /= sum + beta /= sum + } + + logger.V(logutil.DEBUG).Info("Negative headroom blended deficits", + "category", category, + "minTTFTDef", minTTFT, "maxTTFTDef", maxTTFT, + "minTPOTDef", minTPOT, "maxTPOTDef", maxTPOT, + "alphaTTFT", alpha, "betaTPOT", beta, "podCount", len(pods)) + + for _, d := range defs { + // Normalize deficits to [0,1] within this bucket (0 = best / least violation) + nTTFT := 0.0 + if ttftRange > eps { + nTTFT = (d.ttftDef - minTTFT) / (ttftRange + eps) + } + nTPOT := 0.0 + if tpotRange > eps { + nTPOT = (d.tpotDef - minTPOT) / (tpotRange + eps) + } + + // Blended "badness": higher = worse violation + blended := alpha*nTTFT + beta*nTPOT + + // Convert to selection weight: lower badness -> higher weight + // Ensure a floor so no pod is completely excluded within the bucket. 
+ w := int((1.0-blended)*float64(Wrange)) + minWeight + 1 + + *choices = append(*choices, Choice{PodName: d.pod.Pod, Weight: w}) + *total += w + + logger.V(logutil.TRACE).Info("Negative bucket blended weighting", + "pod", d.pod.Pod.GetPod().String(), + "ttftDef", d.ttftDef, "tpotDef", d.tpotDef, + "normTTFT", nTTFT, "normTPOT", nTPOT, + "blendedBadness", blended, "weight", w) + } +} + +func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical( + ctx context.Context, + negHeadroomPods []PodPredictionResult, + choices *[]Choice, + total *int, + minWeightForNegative int, +) { + logger := log.FromContext(ctx) + + // Categorize pods by their headroom status + var negTTFTNegTPOT, negTTFTNonNegTPOT, nonNegTTFTNegTPOT, nonNegTTFTNonNegTPOT []PodPredictionResult + + for _, p := range negHeadroomPods { + if p.TTFTHeadroom < 0 && p.Headroom < 0 { + negTTFTNegTPOT = append(negTTFTNegTPOT, p) + } else if p.TTFTHeadroom < 0 && p.Headroom >= 0 { + negTTFTNonNegTPOT = append(negTTFTNonNegTPOT, p) + } else if p.TTFTHeadroom >= 0 && p.Headroom < 0 { + nonNegTTFTNegTPOT = append(nonNegTTFTNegTPOT, p) + } else { + nonNegTTFTNonNegTPOT = append(nonNegTTFTNonNegTPOT, p) + } + } + + logger.V(logutil.DEBUG).Info("Hierarchical negative headroom pod distribution", + "totalNegative", len(negHeadroomPods), + "negTTFT_negTPOT", len(negTTFTNegTPOT), + "negTTFT_nonNegTPOT", len(negTTFTNonNegTPOT), + "nonNegTTFT_negTPOT", len(nonNegTTFTNegTPOT), + "nonNegTTFT_nonNegTPOT", len(nonNegTTFTNonNegTPOT)) + + // Priority 1: both TTFT and TPOT negative -> blended deficits (both active) + if len(negTTFTNegTPOT) > 0 { + s.weightPodsByBlendedDeficit(ctx, negTTFTNegTPOT, choices, total, minWeightForNegative, + NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "both_negative") + } + + // Priority 2: TTFT negative, TPOT non-negative -> blended still works (TPOT deficit=0) + if len(negTTFTNonNegTPOT) > 0 { + s.weightPodsByBlendedDeficit(ctx, negTTFTNonNegTPOT, choices, total, minWeightForNegative, + NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "ttft_negative") + } + + // Priority 3: TTFT non-negative, TPOT negative -> blended (TTFT deficit=0) + if len(nonNegTTFTNegTPOT) > 0 { + s.weightPodsByBlendedDeficit(ctx, nonNegTTFTNegTPOT, choices, total, minWeightForNegative, + NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "tpot_negative") + } + + // Priority 4: edge-case bucket -> minimal weight + for _, p := range nonNegTTFTNonNegTPOT { + *choices = append(*choices, Choice{PodName: p.Pod, Weight: minWeightForNegative}) + *total += minWeightForNegative + } +} + +func (s *SLOAwareRouter) getPodMinTPOTSLO(pod schedulingtypes.Pod) float64 { + podName := types.NamespacedName{ + Name: pod.GetPod().NamespacedName.Name, + Namespace: pod.GetPod().NamespacedName.Namespace, + } + if runningReqs, ok := s.runningRequestLists[podName]; ok && runningReqs.GetSize() > 0 { + if topReq := runningReqs.Peek(); topReq != nil { + return topReq.TPOT + } + } + return 0 // no running requests or no TPOT SLOs +} + +func (s *SLOAwareRouter) getPodRunningRequestCount(pod schedulingtypes.Pod) int { + podName := types.NamespacedName{ + Name: pod.GetPod().NamespacedName.Name, + Namespace: pod.GetPod().NamespacedName.Namespace, + } + if runningReqs, ok := s.runningRequestLists[podName]; ok { + return runningReqs.GetSize() + } + return 0 // no running requests +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go new file mode 100644 index 000000000..036b5ee76 --- 
/dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go
@@ -0,0 +1,53 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
+package slo_aware_router
+
+import schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+
+type HeadroomStrategy string
+
+type Choice struct {
+	PodName schedulingtypes.Pod
+	Weight  int
+}
+
+const (
+	// HeadroomStrategyLeast prioritizes pods with least positive headroom (better packing)
+	HeadroomStrategyLeast HeadroomStrategy = "least"
+	// HeadroomStrategyMost prioritizes pods with most positive headroom (more conservative)
+	HeadroomStrategyMost HeadroomStrategy = "most"
+
+	HeadroomStrategyCompositeLeast HeadroomStrategy = "composite-least"
+	HeadroomStrategyCompositeMost  HeadroomStrategy = "composite-most"
+	HeadroomStrategyCompositeOnly  HeadroomStrategy = "composite-only"
+
+	// TTFT SLO header key
+	TTFTSLOHeaderKey = "x-slo-ttft-ms"
+	// TPOT SLO header key
+	TPOTSLOHeaderKey = "x-slo-tpot-ms"
+)
+
+const (
+	SLOAwareRouterPluginType = "slo-aware-routing"
+	eps                      = 1e-9
+	Wmax                     = 100
+	minWeight                = 1
+)
+
+type PodSelectionMode string
+
+const (
+	PodSelectionLinear PodSelectionMode = "linear" // weighted-random (current behavior)
+	PodSelectionMax    PodSelectionMode = "max"    // pick argmax weight
+)
diff --git a/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go b/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go
new file mode 100644
index 000000000..900335c9e
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go
@@ -0,0 +1,154 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package profile
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"strconv"
+
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+const (
+	SLOAwareProfileHandlerType = "slo-aware-profile-handler"
+	DefaultProfileName         = "default"
+	PrefixProfileName          = "prefix"
+	SLOProfileName             = "slo"
+
+	// Boolean header key for whether to use predictor-based scheduling
+	PredictionBasedSchedulingHeaderKey = "x-prediction-based-scheduling"
+)
+
+// compile-time type assertion
+var _ framework.ProfileHandler = &SLOAwareProfileHandler{}
+
+// SLOAwareProfileHandlerFactory defines the factory function for SLOAwareProfileHandler.
+func SLOAwareProfileHandlerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+	return NewSLOAwareProfileHandler().WithName(name), nil
+}
+
+// NewSLOAwareProfileHandler initializes a new SLOAwareProfileHandler and returns its pointer.
+func NewSLOAwareProfileHandler() *SLOAwareProfileHandler {
+	return &SLOAwareProfileHandler{
+		typedName: plugins.TypedName{Type: SLOAwareProfileHandlerType, Name: SLOAwareProfileHandlerType},
+	}
+}
+
+// SLOAwareProfileHandler handles two profiles: the default profile and the SLO profile.
+// When the request has PredictorBasedScheduling=true, it uses the SLO profile result to select
+// the destination pod. Otherwise, it uses the default profile result.
+type SLOAwareProfileHandler struct {
+	typedName     plugins.TypedName
+	prefixProfile string // the profile that should be executed first
+}
+
+// TypedName returns the type and name tuple of this plugin instance.
+func (h *SLOAwareProfileHandler) TypedName() plugins.TypedName {
+	return h.typedName
+}
+
+// WithName sets the name of the profile handler.
+func (h *SLOAwareProfileHandler) WithName(name string) *SLOAwareProfileHandler {
+	h.typedName.Name = name
+	return h
+}
+
+// Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the
+// previously executed cycles along with their results.
+func (h *SLOAwareProfileHandler) Pick(_ context.Context, _ *types.CycleState, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile,
+	profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile {
+	if len(profiles) == len(profileResults) { // all profiles have been executed already in previous call
+		return map[string]*framework.SchedulerProfile{}
+	}
+
+	if _, executed := profileResults[PrefixProfileName]; !executed {
+		// if the prefix profile was not executed yet, first let the scheduler run the prefix profile
+		return map[string]*framework.SchedulerProfile{
+			PrefixProfileName: profiles[PrefixProfileName],
+		}
+	}
+	// otherwise, prefix was already executed.
+
+	// return all profiles except prefix.
+	profilesToRun := make(map[string]*framework.SchedulerProfile)
+	for name, profile := range profiles {
+		if name != PrefixProfileName {
+			profilesToRun[name] = profile
+		}
+	}
+	return profilesToRun
+}
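The scheduler calls Pick once per cycle until it returns an empty map, giving a two-pass flow: first the prefix profile, then everything else. A sketch of that driver loop (the loop itself and the runProfile helper are hypothetical, shown only to illustrate the contract):

	results := map[string]*types.ProfileRunResult{}
	for {
		toRun := handler.Pick(ctx, cycleState, req, profiles, results)
		if len(toRun) == 0 {
			break // every profile has run; ProcessResults then picks the primary result
		}
		for name, p := range toRun {
			results[name] = runProfile(ctx, p) // pass 1 runs "prefix"; pass 2 runs "default" and "slo"
		}
	}

+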
+// ProcessResults handles the outcome of the profile runs after all profiles ran.
+// It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the
+// key of the primary profile that should be used to get the request selected destination.
+// When a profile run fails, its result in the profileResults map is nil.
+func (h *SLOAwareProfileHandler) ProcessResults(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
+	if len(profileResults) < 2 {
+		return nil, errors.New("SLOAwareProfileHandler requires at least two profiles to operate")
+	}
+
+	predictorBasedScheduling, err := parseBoolHeader(*request, PredictionBasedSchedulingHeaderKey)
+	if err != nil {
+		return nil, fmt.Errorf("failed to choose scheduling profile: error parsing %s header: %v", PredictionBasedSchedulingHeaderKey, err)
+	}
+
+	if predictorBasedScheduling { // TODO grab header directly from request.Headers instead of request field
+		if profileResults[SLOProfileName] == nil { // there was an error while running the SLO profile
+			return nil, fmt.Errorf("failed to run scheduler profile '%s'", SLOProfileName)
+		}
+		return &types.SchedulingResult{
+			ProfileResults:     profileResults,
+			PrimaryProfileName: SLOProfileName,
+		}, nil
+	}
+
+	if profileResults[DefaultProfileName] == nil { // there was an error while running the default profile
+		return nil, fmt.Errorf("failed to run scheduler profile '%s'", DefaultProfileName)
+	}
+
+	return &types.SchedulingResult{
+		ProfileResults:     profileResults,
+		PrimaryProfileName: DefaultProfileName,
+	}, nil
+}
+
+// parseBoolHeader retrieves a header by name, parses it as a bool, and returns
+// the parsed value, or an error if the header is present but invalid.
+// A missing header defaults to false.
+func parseBoolHeader(request types.LLMRequest, headerName string) (bool, error) {
+	// Get the header value from the map.
+	headerValue, ok := request.Headers[headerName]
+	if !ok {
+		return false, nil // header not found; default to false
+	}
+
+	// Parse the header value as a bool.
+	parsedBool, err := strconv.ParseBool(headerValue)
+	if err != nil {
+		return false, fmt.Errorf("header %q must be a bool: %v", headerName, err)
+	}
+
+	return parsedBool, nil
+}
diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go
index c3037175e..014c51d82 100644
--- a/pkg/epp/server/runserver.go
+++ b/pkg/epp/server/runserver.go
@@ -43,6 +43,7 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector"
+	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
 )
 
 // ExtProcServerRunner provides methods to manage an external process server.
@@ -59,6 +60,7 @@ type ExtProcServerRunner struct {
 	Director                   *requestcontrol.Director
 	SaturationDetector         *saturationdetector.Detector
 	UseExperimentalDatalayerV2 bool // Pluggable data layer feature flag
+	LatencyPredictor           latencypredictor.PredictorInterface
 
 	// This should only be used in tests. We won't need this once we do not inject metrics in the tests.
	// TODO:(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/432) Cleanup
@@ -78,6 +80,7 @@ const (
 	DefaultHealthChecking               = false                           // default for --health-checking
 	DefaultEnablePprof                  = true                            // default for --enable-pprof
 	DefaultTotalQueuedRequestsMetric    = "vllm:num_requests_waiting"     // default for --total-queued-requests-metric
+	DefaultTotalRunningRequestsMetric   = "vllm:num_requests_running"     // default for --total-running-requests-metric
 	DefaultKvCacheUsagePercentageMetric = "vllm:gpu_cache_usage_perc"     // default for --kv-cache-usage-percentage-metric
 	DefaultLoraInfoMetric               = "vllm:lora_requests_info"       // default for --lora-info-metric
 	DefaultCacheInfoMetric              = "vllm:cache_config_info"        // default for --cache-info-metric
diff --git a/sidecars/latencypredictorasync/types.go b/sidecars/latencypredictorasync/types.go
index 4b4a1ca0b..c8eadefe2 100644
--- a/sidecars/latencypredictorasync/types.go
+++ b/sidecars/latencypredictorasync/types.go
@@ -120,7 +120,6 @@ type PredictorInterface interface {
 	PredictBulk(ctx context.Context, requests []PredictionRequest) (*BulkPredictionResponse, error)
 	PredictBulkStrict(ctx context.Context, requests []PredictionRequest) (*BulkPredictionResponse, error)
 	AddTrainingDataBulk(entry []TrainingEntry) error
-	GetServerStatus(ctx context.Context) (*ServerStatusResponse, error)
 }
 
 // --- Data Models ---

From dc5b077e1a35605c7d879e606e10ae1bde4e84da Mon Sep 17 00:00:00 2001
From: BenjaminBraunDev
Date: Fri, 7 Nov 2025 00:55:31 +0000
Subject: [PATCH 2/6] Update Dockerfile, fix issues with SLO context not being set when prediction is off

---
 Dockerfile                                    |   1 +
 .../manifests/inferencepool-resources-lp.yaml | 384 ++++++++++++-
 .../plugins/multi/slo_aware_router/config.go  |  18 +
 .../plugins/multi/slo_aware_router/helpers.go |   2 +-
 .../latencypredictor_helper.go                |  28 +-
 .../multi/slo_aware_router/prediction.go      |  22 +-
 .../slo_aware_router/requestcontrol_hooks.go  |  25 +-
 .../plugins/multi/slo_aware_router/scorer.go  |  71 ++-
 .../multi/slo_aware_router/scorer_test.go     | 525 ++++++++++++++++++
 9 files changed, 1011 insertions(+), 65 deletions(-)
 create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go

diff --git a/Dockerfile b/Dockerfile
index d2fd2300b..70038a077 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -24,6 +24,7 @@ COPY internal ./internal
 COPY apix ./apix
 COPY api ./api
 COPY version ./version
+COPY sidecars ./sidecars
 
 WORKDIR /src/cmd/epp
 RUN go build -ldflags="-X sigs.k8s.io/gateway-api-inference-extension/version.CommitSHA=${COMMIT_SHA} -X sigs.k8s.io/gateway-api-inference-extension/version.BuildRef=${BUILD_REF}" -o /epp

diff --git a/config/manifests/inferencepool-resources-lp.yaml b/config/manifests/inferencepool-resources-lp.yaml
index c2a3528de..8a7951e96 100644
--- a/config/manifests/inferencepool-resources-lp.yaml
+++ b/config/manifests/inferencepool-resources-lp.yaml
@@ -17,6 +17,7 @@ data:
   LATENCY_TPOT_SCALER_PATH: "/models/tpot_scaler.joblib"
   LATENCY_MODEL_TYPE: "xgboost"
   LATENCY_MAX_TRAINING_DATA_SIZE_PER_BUCKET: "5000"
+  LATENCY_QUANTILE_ALPHA: "0.9"
 ---
 apiVersion: v1
 kind: ConfigMap
@@ -74,6 +75,34 @@ spec:
       protocol: TCP
       port: 8003
       targetPort: 8003
+    - name: latency-predictor-4
+      protocol: TCP
+      port: 8004
+      targetPort: 8004
+    - name: latency-predictor-5
+      protocol: TCP
+      port: 8005
+      targetPort: 8005
+    - name: latency-predictor-6
+      protocol: TCP
+      port: 8006
+      targetPort: 8006
+    - name: latency-predictor-7
+      protocol: TCP
+      port: 8007
+      targetPort: 8007
+    - name: latency-predictor-8
+      protocol: TCP
+
port: 8008 + targetPort: 8008 + - name: latency-predictor-9 + protocol: TCP + port: 8009 + targetPort: 8009 + - name: latency-predictor-10 + protocol: TCP + port: 8010 + targetPort: 8010 - name: prometheus protocol: TCP port: 9090 @@ -106,7 +135,6 @@ spec: spec: serviceAccountName: vllm-llama3-8b-instruct-epp # Conservatively, this timeout should mirror the longest grace period of the pods within the pool - terminationGracePeriodSeconds: 130 containers: # EPP Container - name: epp @@ -132,15 +160,15 @@ spec: - "-enable-latency-predictor" env: - name: PREDICTION_SERVER_URL - value: "http://localhost:8001,http://localhost:8002,http://localhost:8003" # Multiple prediction servers + value: "http://localhost:8001,http://localhost:8002,http://localhost:8003,http://localhost:8004,http://localhost:8005,http://localhost:8006,http://localhost:8007,http://localhost:8008,http://localhost:8009,http://localhost:8010" # All 10 prediction servers - name: TRAINING_SERVER_URL value: "http://localhost:8000" # Single training server for sending training data - name: LATENCY_MAX_SAMPLE_SIZE value: "10000" # Maximum sample size for latency prediction - - name: NEG_HEADROOM_TPOT_WEIGHT - value: "0.2" # Weight for TPOT in negative headroom calculation - - name: NEG_HEADROOM_TTFT_WEIGHT - value: "0.8" # Weight for TTFT in negative headroom calculation + + + + ports: - containerPort: 9002 - containerPort: 9003 @@ -338,6 +366,328 @@ spec: volumeMounts: - name: prediction-server-3-storage mountPath: /server_models + # Prediction Server Sidecar Container 4 + - name: prediction-server-4 + image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest + imagePullPolicy: Always + command: ["uvicorn"] + args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8004"] + ports: + - containerPort: 8004 + name: predict-port-4 + livenessProbe: + httpGet: + path: /healthz + port: 8004 + initialDelaySeconds: 15 + periodSeconds: 15 + readinessProbe: + httpGet: + path: /readyz + port: 8004 + initialDelaySeconds: 10 + periodSeconds: 5 + failureThreshold: 10 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + envFrom: + - configMapRef: + name: prediction-server-config + env: + - name: PREDICT_PORT + value: "8004" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "prediction-4" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + volumeMounts: + - name: prediction-server-4-storage + mountPath: /server_models + # Prediction Server Sidecar Container 5 + - name: prediction-server-5 + image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest + imagePullPolicy: Always + command: ["uvicorn"] + args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8005"] + ports: + - containerPort: 8005 + name: predict-port-5 + livenessProbe: + httpGet: + path: /healthz + port: 8005 + initialDelaySeconds: 15 + periodSeconds: 15 + readinessProbe: + httpGet: + path: /readyz + port: 8005 + initialDelaySeconds: 10 + periodSeconds: 5 + failureThreshold: 10 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + envFrom: + - configMapRef: + name: prediction-server-config + env: + - name: PREDICT_PORT + value: "8005" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "prediction-5" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + volumeMounts: + - name: 
prediction-server-5-storage + mountPath: /server_models + # Prediction Server Sidecar Container 6 + - name: prediction-server-6 + image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest + imagePullPolicy: Always + command: ["uvicorn"] + args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8006"] + ports: + - containerPort: 8006 + name: predict-port-6 + livenessProbe: + httpGet: + path: /healthz + port: 8006 + initialDelaySeconds: 15 + periodSeconds: 15 + readinessProbe: + httpGet: + path: /readyz + port: 8006 + initialDelaySeconds: 10 + periodSeconds: 5 + failureThreshold: 10 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + envFrom: + - configMapRef: + name: prediction-server-config + env: + - name: PREDICT_PORT + value: "8006" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "prediction-6" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + volumeMounts: + - name: prediction-server-6-storage + mountPath: /server_models + # Prediction Server Sidecar Container 7 + - name: prediction-server-7 + image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest + imagePullPolicy: Always + command: ["uvicorn"] + args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8007"] + ports: + - containerPort: 8007 + name: predict-port-7 + livenessProbe: + httpGet: + path: /healthz + port: 8007 + initialDelaySeconds: 15 + periodSeconds: 15 + readinessProbe: + httpGet: + path: /readyz + port: 8007 + initialDelaySeconds: 10 + periodSeconds: 5 + failureThreshold: 10 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + envFrom: + - configMapRef: + name: prediction-server-config + env: + - name: PREDICT_PORT + value: "8007" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "prediction-7" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + volumeMounts: + - name: prediction-server-7-storage + mountPath: /server_models + # Prediction Server Sidecar Container 8 + - name: prediction-server-8 + image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest + imagePullPolicy: Always + command: ["uvicorn"] + args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8008"] + ports: + - containerPort: 8008 + name: predict-port-8 + livenessProbe: + httpGet: + path: /healthz + port: 8008 + initialDelaySeconds: 15 + periodSeconds: 15 + readinessProbe: + httpGet: + path: /readyz + port: 8008 + initialDelaySeconds: 10 + periodSeconds: 5 + failureThreshold: 10 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + envFrom: + - configMapRef: + name: prediction-server-config + env: + - name: PREDICT_PORT + value: "8008" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "prediction-8" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + volumeMounts: + - name: prediction-server-8-storage + mountPath: /server_models + # Prediction Server Sidecar Container 9 + - name: prediction-server-9 + image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest + imagePullPolicy: Always + command: ["uvicorn"] + args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8009"] + ports: + - containerPort: 8009 + name: predict-port-9 + livenessProbe: + 
httpGet: + path: /healthz + port: 8009 + initialDelaySeconds: 15 + periodSeconds: 15 + readinessProbe: + httpGet: + path: /readyz + port: 8009 + initialDelaySeconds: 10 + periodSeconds: 5 + failureThreshold: 10 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + envFrom: + - configMapRef: + name: prediction-server-config + env: + - name: PREDICT_PORT + value: "8009" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "prediction-9" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + volumeMounts: + - name: prediction-server-9-storage + mountPath: /server_models + # Prediction Server Sidecar Container 10 + - name: prediction-server-10 + image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest + imagePullPolicy: Always + command: ["uvicorn"] + args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8010"] + ports: + - containerPort: 8010 + name: predict-port-10 + livenessProbe: + httpGet: + path: /healthz + port: 8010 + initialDelaySeconds: 15 + periodSeconds: 15 + readinessProbe: + httpGet: + path: /readyz + port: 8010 + initialDelaySeconds: 10 + periodSeconds: 5 + failureThreshold: 10 + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + envFrom: + - configMapRef: + name: prediction-server-config + env: + - name: PREDICT_PORT + value: "8010" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: SERVER_TYPE + value: "prediction-10" + - name: TRAINING_SERVER_URL + value: "http://localhost:8000" + volumeMounts: + - name: prediction-server-10-storage + mountPath: /server_models volumes: - name: training-server-storage emptyDir: @@ -351,6 +701,27 @@ spec: - name: prediction-server-3-storage emptyDir: sizeLimit: "10Gi" # Dedicated volume for prediction server 3 + - name: prediction-server-4-storage + emptyDir: + sizeLimit: "10Gi" # Dedicated volume for prediction server 4 + - name: prediction-server-5-storage + emptyDir: + sizeLimit: "10Gi" # Dedicated volume for prediction server 5 + - name: prediction-server-6-storage + emptyDir: + sizeLimit: "10Gi" # Dedicated volume for prediction server 6 + - name: prediction-server-7-storage + emptyDir: + sizeLimit: "10Gi" # Dedicated volume for prediction server 7 + - name: prediction-server-8-storage + emptyDir: + sizeLimit: "10Gi" # Dedicated volume for prediction server 8 + - name: prediction-server-9-storage + emptyDir: + sizeLimit: "10Gi" # Dedicated volume for prediction server 9 + - name: prediction-server-10-storage + emptyDir: + sizeLimit: "10Gi" # Dedicated volume for prediction server 10 - name: plugins-config-volume configMap: name: plugins-config @@ -386,6 +757,7 @@ data: plugins: - pluginRef: slo-aware-routing - pluginRef: max-score-picker + --- # --- RBAC --- kind: Role diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go index cf9d8ee33..046870948 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go @@ -19,6 +19,24 @@ import ( "strings" ) +var DefaultSamplingMean = func() float64 { + if value, exists := os.LookupEnv("SAMPLING_MEAN"); exists { + if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue > 0 { + return parsedValue + } + } + return 100.0 // default value +}() + +var 
MaxSampledTokens = func() int { + if value, exists := os.LookupEnv("MAX_SAMPLED_TOKENS"); exists { + if parsedValue, err := strconv.Atoi(value); err == nil && parsedValue > 0 { + return parsedValue + } + } + return 20 // default value +}() + var SLOBufferFactor = func() float64 { if value, exists := os.LookupEnv("SLO_BUFFER_FACTOR"); exists { if parsedValue, err := strconv.ParseFloat(value, 64); err == nil { diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go index 3f02d5e52..1d5568243 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go @@ -133,7 +133,7 @@ func (s *SLOAwareRouter) buildCompositeChoices( *total += w choices = append(choices, Choice{PodName: p.Pod, Weight: w}) - log.FromContext(ctx).V(logutil.DEBUG).Info("Composite (neg/pos) score", + log.FromContext(ctx).V(logutil.TRACE).Info("Composite (neg/pos) score", "pod", p.Pod.GetPod().String(), "kvUsage", kvUsage, "kvFree", kvFree, "queue", q, "relQueue", relQueue, diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go index e9a1b128c..8961bda44 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go @@ -28,12 +28,6 @@ import ( latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" ) -const ( - // Poisson sampling parameters for predictions - defaultSamplingMean = 100 // Mean interval between prediction samples (tokens) - maxSampledTokens = 20 // Maximum number of prediction samples per request -) - // RefreshLastSeenMetrics updates sloCtx.LastSeenMetrics from the latest scheduling result. func RefreshLastSeenMetrics(ctx context.Context, sloCtx *SLORequestContext) { if sr := sloCtx.SchedulingResult; sr != nil { @@ -50,18 +44,11 @@ func RefreshLastSeenMetrics(ctx context.Context, sloCtx *SLORequestContext) { } // GetMetricsForPrediction retrieves the latest metrics for prediction from sloCtx.LastSeenMetrics. 
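The constants removed in this file are superseded by the environment-driven `DefaultSamplingMean` and `MaxSampledTokens` variables introduced in config.go above (env vars `SAMPLING_MEAN` and `MAX_SAMPLED_TOKENS`, defaults 100.0 and 20). Below is a minimal, self-contained sketch of the lookup-with-fallback pattern those initializers follow; the `envFloat`/`envInt` helpers are hypothetical illustrations only — the patch inlines this logic per variable rather than sharing a helper:

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// envFloat returns the value of the named environment variable if it parses
// as a positive float64, and def otherwise — the same accept/reject rule
// DefaultSamplingMean applies to SAMPLING_MEAN.
func envFloat(name string, def float64) float64 {
	if v, ok := os.LookupEnv(name); ok {
		if f, err := strconv.ParseFloat(v, 64); err == nil && f > 0 {
			return f
		}
	}
	return def
}

// envInt applies the same rule to integers, as MaxSampledTokens does for
// MAX_SAMPLED_TOKENS.
func envInt(name string, def int) int {
	if v, ok := os.LookupEnv(name); ok {
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			return n
		}
	}
	return def
}

func main() {
	os.Setenv("SAMPLING_MEAN", "250")
	fmt.Println(envFloat("SAMPLING_MEAN", 100.0)) // 250: a valid override wins
	fmt.Println(envFloat("UNSET_VAR", 100.0))     // 100: unset, so default kept
	fmt.Println(envInt("MAX_SAMPLED_TOKENS", 20)) // 20: unset here, default kept
}
```

Because these are package-level variables evaluated once at init, overrides must be present in the EPP container's environment at startup (for example via the Deployment `env:` block in the manifest above); changing them at runtime has no effect.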
-func GetLatestMetricsForProfile(ctx context.Context, sloCtx *SLORequestContext, profileName string) (*backendmetrics.MetricsState, error) { +func GetLatestMetricsForProfile(ctx context.Context, sloCtx *SLORequestContext) (*backendmetrics.MetricsState, error) { if len(sloCtx.LastSeenMetrics) == 0 { return nil, fmt.Errorf("no last seen metrics available for prediction") } - // Use the primary profile's metrics for prediction - if metrics, exists := sloCtx.LastSeenMetrics[profileName]; exists { - return metrics, nil - } - - log.FromContext(ctx).V(logutil.DEBUG).Info("No metrics found for profile, trying primary profile", "profile_name", profileName) - primaryProfileName := sloCtx.SchedulingResult.PrimaryProfileName if metrics, exists := sloCtx.LastSeenMetrics[primaryProfileName]; exists { return metrics, nil @@ -82,8 +69,7 @@ func ProcessHeaderForLatencyPrediction( //print the raw scores in scheduling result // Build prediction request - //check if prefill profile name is set, if not use primary profile name - m, err := GetLatestMetricsForProfile(ctx, sloCtx, "prefill") + m, err := GetLatestMetricsForProfile(ctx, sloCtx) if err != nil { logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err) return err @@ -136,14 +122,14 @@ func ProcessFirstTokenForLatencyPrediction( // Initialize sampler if sloCtx.TokenSampler == nil { requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey] - sloCtx.TokenSampler = NewTokenSampler(requestID, defaultSamplingMean, maxSampledTokens) + sloCtx.TokenSampler = NewTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens) logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken()) } // Actual TTFT sloCtx.TTFT = float64(now.Sub(sloCtx.RequestReceivedTimestamp).Milliseconds()) sloCtx.GeneratedTokenCount = 1 - m, err := GetLatestMetricsForProfile(ctx, sloCtx, "prefill") + m, err := GetLatestMetricsForProfile(ctx, sloCtx) if err != nil { logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err) return @@ -166,7 +152,7 @@ func ProcessFirstTokenForLatencyPrediction( if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil { logger.V(logutil.DEBUG).Error(err, "record TTFT training failed") } - m, err = GetLatestMetricsForProfile(ctx, sloCtx, sloCtx.SchedulingResult.PrimaryProfileName) + m, err = GetLatestMetricsForProfile(ctx, sloCtx) if err != nil { logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics", "error", err) @@ -214,7 +200,7 @@ func ProcessTokenForLatencyPrediction( // Initialize sampler if not yet if sloCtx.TokenSampler == nil { requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey] - sloCtx.TokenSampler = NewTokenSampler(requestID, defaultSamplingMean, maxSampledTokens) + sloCtx.TokenSampler = NewTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens) logger.V(logutil.DEBUG).Info("Initialized token sampler for subsequent tokens", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken()) } @@ -228,7 +214,7 @@ func ProcessTokenForLatencyPrediction( sloCtx.AvgTPOT = calculateRunningAverage(sloCtx.AvgTPOT, latencyMs, len(sloCtx.TPOTObservations)) } - m, err := GetLatestMetricsForProfile(ctx, sloCtx, sloCtx.SchedulingResult.PrimaryProfileName) + m, err := GetLatestMetricsForProfile(ctx, sloCtx) if err != nil { logger.V(logutil.DEBUG).Info("Skipping 
first TPOT prediction due to missing metrics", "error", err) diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go index dcd9c40e0..7a857fa0d 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go @@ -22,8 +22,21 @@ import ( latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" ) +type PodPredictionResult struct { + Pod schedulingtypes.Pod + TTFT float64 + TPOT float64 + TTFTValid bool + TPOTValid bool + IsValid bool + Error error + Headroom float64 // Headroom for the pod, if applicable + TTFTHeadroom float64 // TTFT headroom for the pod + PrefixCacheScore float64 // Prefix cache score for the pod +} + // generatePredictions creates prediction results for all candidate pods -func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *SLORequestContext, candidatePods []schedulingtypes.Pod) []PodPredictionResult { +func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *SLORequestContext, candidatePods []schedulingtypes.Pod) ([]PodPredictionResult, error) { logger := log.FromContext(ctx) predictions := make([]PodPredictionResult, 0, len(candidatePods)) @@ -42,10 +55,9 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul // Generate prediction prediction, err := PredictWithMetrics(ctx, s.latencypredictor, pod.GetMetrics(), request.Body.Completions.Prompt, 1, prefixCacheScore) if err != nil { - logger.V(logutil.DEBUG).Info("Skipping pod due to prediction error", "pod", pod.GetPod().String(), "error", err) + logger.V(logutil.DEBUG).Error(err, "Skipping pod due to prediction error", "pod", pod.GetPod().String(), "error", err) predResult.Error = err - predictions = append(predictions, predResult) - continue + return nil, err } predResult.PrefixCacheScore = prefixCacheScore predResult.TTFT = prediction.TTFT @@ -76,7 +88,7 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul predictions = append(predictions, predResult) } - return predictions + return predictions, nil } // updateRequestContextWithPredictions updates the request context with prediction data diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go index 17399c6a8..072e2e2fb 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go @@ -6,7 +6,6 @@ import ( "time" "github.com/go-logr/logr" - "github.com/google/uuid" "sigs.k8s.io/controller-runtime/pkg/log" "k8s.io/apimachinery/pkg/types" @@ -93,7 +92,7 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype logger := log.FromContext(ctx) if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 { - logger.V(logutil.DEBUG).Info("SLOAwareRouter: Skipping PreRequest because no scheduling result was provided.") + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping PreRequest because no scheduling result was provided.") return } @@ -104,11 +103,9 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, 
request *schedulingtype Namespace: targetPod.NamespacedName.Namespace, } - logger.V(logutil.DEBUG).Info("request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey], "podName", podName) + logger.V(logutil.TRACE).Info("request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey], "podName", podName) if request.Headers[requtil.RequestIdHeaderKey] == "" { - request.Headers[requtil.RequestIdHeaderKey] = uuid.New().String() - logger.V(logutil.DEBUG).Info("Generated new request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey]) - logger.V(logutil.DEBUG).Info("request headers for SLO tracking", "requestHeaders", request.Headers) + logger.V(logutil.DEBUG).Error(fmt.Errorf("missing request ID"), "SLOAwareRouter.PreRequest: Request is missing request ID header") } id := request.Headers[requtil.RequestIdHeaderKey] @@ -127,7 +124,7 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype added := podRequestList.Add(id, sloCtx.AvgTPOTSLO) if !added { - logger.V(logutil.DEBUG).Info("SLOAwareRouter: Item already exists in queue", "podName", podName, "requestID", id) + logger.V(logutil.TRACE).Info("SLOAwareRouter: Item already exists in queue", "podName", podName, "requestID", id) } // Set up SLO request context @@ -168,7 +165,7 @@ func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedul sloCtx, err := t.getSLOContextForRequest(request) if err != nil { id := request.Headers[requtil.RequestIdHeaderKey] - logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.ResponseStreaming: Failed to get SLO context for request", "requestID", id) + logger.V(logutil.TRACE).Error(err, "SLOAwareRouter.ResponseStreaming: Failed to get SLO context for request", "requestID", id) return } @@ -195,7 +192,7 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *scheduli } if sloCtx.TTFT > 0 { - logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTTFT", sloCtx.TTFT, "avgPredictedTTFT", sloCtx.PredictedTTFT) + logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTTFT", sloCtx.TTFT, "avgPredictedTTFT", sloCtx.PredictedTTFT) metrics.RecordRequestTTFT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.TTFT/1000) metrics.RecordRequestPredictedTTFT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.PredictedTTFT/1000) if sloCtx.TTFTSLO > 0 { @@ -204,7 +201,7 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *scheduli } if sloCtx.AvgTPOT > 0 { - logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTPOT", sloCtx.AvgTPOT, "avgPredictedTPOT", sloCtx.AvgPredictedTPOT) + logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTPOT", sloCtx.AvgTPOT, "avgPredictedTPOT", sloCtx.AvgPredictedTPOT) metrics.RecordRequestTPOT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgTPOT/1000) metrics.RecordRequestPredictedTPOT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgPredictedTPOT/1000) if sloCtx.AvgTPOTSLO > 0 { @@ -212,7 +209,7 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *scheduli } } - logger.V(logutil.DEBUG).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", sloCtx.PredictorBasedScheduling) + logger.V(logutil.TRACE).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", sloCtx.PredictorBasedScheduling) podName := types.NamespacedName{ Name: targetPod.NamespacedName.Name, @@ -228,18 +225,18 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, 
request *scheduli _, removed := podRequestList.Remove(id) if !removed { - logger.V(logutil.DEBUG).Info("SLOAwareRouter: Item not found in queue", "podName", podName, "requestID", id) + logger.V(logutil.TRACE).Info("SLOAwareRouter: Item not found in queue", "podName", podName, "requestID", id) } t.deleteSLOContextForRequest(request) } func (t *SLOAwareRouter) CheckPredictor(logger logr.Logger, targetPod *backend.Pod) bool { if targetPod == nil { - logger.V(logutil.DEBUG).Info("SLOAwareRouter: Skipping PostResponse because no target pod was provided.") + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping PostResponse because no target pod was provided.") return false } if t.latencypredictor == nil { - logger.V(logutil.DEBUG).Info("SLOAwareRouter: Skipping PostResponse because predictor missing") + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping PostResponse because predictor missing") return false } return true diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go index d48d3a1bc..b476579b5 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go @@ -35,19 +35,6 @@ import ( latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" ) -type PodPredictionResult struct { - Pod schedulingtypes.Pod - TTFT float64 - TPOT float64 - TTFTValid bool - TPOTValid bool - IsValid bool - Error error - Headroom float64 // Headroom for the pod, if applicable - TTFTHeadroom float64 // TTFT headroom for the pod - PrefixCacheScore float64 // Prefix cache score for the pod -} - type SLOAwareRouter struct { tn plugins.TypedName latencypredictor latencypredictor.PredictorInterface @@ -120,6 +107,48 @@ func (s *SLOAwareRouter) epsilonGreedyAffinityGate( return eligible, true } +// scoreWithoutPredictions provides fallback scoring based only on prefix cache scores +// when latency predictions are unavailable +func (s *SLOAwareRouter) scoreWithoutPredictions( + ctx context.Context, + state *schedulingtypes.CycleState, + pods []schedulingtypes.Pod, + r *rand.Rand, +) map[schedulingtypes.Pod]float64 { + logger := log.FromContext(ctx) + logger.V(logutil.TRACE).Info("Using composite-only scoring without predictions") + + scores := make(map[schedulingtypes.Pod]float64, len(pods)) + for _, pod := range pods { + scores[pod] = 0 + } + + if len(pods) == 0 { + return scores + } + + // Build prediction results with only prefix cache scores + podResults := make([]PodPredictionResult, 0, len(pods)) + for _, pod := range pods { + prefixScore := s.getPrefixCacheScoreForPod(ctx, state, pod) + podResults = append(podResults, PodPredictionResult{ + Pod: pod, + PrefixCacheScore: prefixScore, + IsValid: true, // All pods are valid when we don't check predictions + }) + } + + // Select based on composite scores (prefix cache + other non-prediction metrics) + selectedPod := s.selectFromCompositeScores(ctx, podResults, r, HeadroomStrategyCompositeOnly) + + if selectedPod != nil { + scores[selectedPod] = 1 + logger.V(logutil.TRACE).Info("Selected pod using composite-only scoring", "pod", selectedPod.GetPod().String()) + } + + return scores +} + func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) map[schedulingtypes.Pod]float64 { logger := log.FromContext(ctx) if s.latencypredictor == nil { @@ 
-149,14 +178,10 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle // Check if SLOs are provided if !sloCtx.PredictorBasedScheduling { logger.V(logutil.DEBUG).Info("PredictorBasedScheduling turned off, skipping prediction-based filtering") + s.setSLOContextForRequest(request, sloCtx) return nil } - predictions := s.generatePredictions(ctx, state, request, sloCtx, pods) - s.updateRequestContextWithPredictions(sloCtx, predictions) - - allPreds := append([]PodPredictionResult(nil), predictions...) - // Initialize scores map with all pods having score 0 scores := make(map[schedulingtypes.Pod]float64, len(pods)) for _, pod := range pods { @@ -165,6 +190,16 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle source := rand.NewSource(time.Now().UnixNano()) r := rand.New(source) + predictions, err := s.generatePredictions(ctx, state, request, sloCtx, pods) + if err != nil { + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Error generating predictions, falling back to composite-only scoring") + // Fall back to composite-only scoring using prefix cache scores + s.setSLOContextForRequest(request, sloCtx) + return s.scoreWithoutPredictions(ctx, state, pods, r) + } + s.updateRequestContextWithPredictions(sloCtx, predictions) + + allPreds := append([]PodPredictionResult(nil), predictions...) allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal) // Check if all pods are invalid and all have running requests diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go new file mode 100644 index 000000000..d4637b96f --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go @@ -0,0 +1,525 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package slo_aware_router + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/types" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +// mockPredictor implements PredictorInterface for testing +type mockPredictor struct { + predictions map[string]*latencypredictor.PredictionResponse + err error +} + +func (m *mockPredictor) Predict(ctx context.Context, request latencypredictor.PredictionRequest) (*latencypredictor.PredictionResponse, error) { + if m.err != nil { + return nil, m.err + } + // Generate a key based on KV cache percentage to return different predictions for different pods + key := fmt.Sprintf("%.1f", request.KVCachePercentage) + if pred, ok := m.predictions[key]; ok { + return pred, nil + } + // Default prediction + return &latencypredictor.PredictionResponse{TTFT: 0.5, TPOT: 0.03}, nil +} + +func (m *mockPredictor) PredictBulk(ctx context.Context, requests []latencypredictor.PredictionRequest) (*latencypredictor.BulkPredictionResponse, error) { + if m.err != nil { + return nil, m.err + } + // Generate a key based on KV cache percentage to return different predictions for different pods + responses := make([]latencypredictor.PredictionResponse, 0, len(requests)) + for _, request := range requests { + key := fmt.Sprintf("%.1f", request.KVCachePercentage) + if pred, ok := m.predictions[key]; ok { + responses = append(responses, *pred) + } else { + return nil, fmt.Errorf("no prediction for key %s", key) + } + } + return &latencypredictor.BulkPredictionResponse{Predictions: responses}, nil +} + +func (m *mockPredictor) PredictBulkStrict(ctx context.Context, requests []latencypredictor.PredictionRequest) (*latencypredictor.BulkPredictionResponse, error) { + if m.err != nil { + return nil, m.err + } + // Generate a key based on KV cache percentage to return different predictions for different pods + responses := make([]latencypredictor.PredictionResponse, 0, len(requests)) + for _, request := range requests { + key := fmt.Sprintf("%.1f", request.KVCachePercentage) + if pred, ok := m.predictions[key]; ok { + responses = append(responses, *pred) + } else { + return nil, fmt.Errorf("no prediction for key %s", key) + } + } + return &latencypredictor.BulkPredictionResponse{Predictions: responses}, nil +} + +func (m *mockPredictor) AddTrainingDataBulk(data []latencypredictor.TrainingEntry) error { + return nil +} + +func (m *mockPredictor) AddTrainingData(data latencypredictor.TrainingEntry) error { + return nil +} + +func (m *mockPredictor) HealthCheck() error { + return nil +} + +func (m *mockPredictor) GetServerStatus(ctx context.Context) (*latencypredictor.ServerStatusResponse, error) { + return &latencypredictor.ServerStatusResponse{}, nil +} + +func createTestPod(name string, kvCacheUsage float64, runningQueueSize, waitingQueueSize int) schedulingtypes.Pod { + return &schedulingtypes.PodMetrics{ + Pod: &backend.Pod{ + NamespacedName: types.NamespacedName{ + Name: name, + Namespace: "default", + }, + }, + MetricsState: &backendmetrics.MetricsState{ + KVCacheUsagePercent: kvCacheUsage, + RunningQueueSize: runningQueueSize, + WaitingQueueSize: waitingQueueSize, + }, + } +} + +func createTestLLMRequest(ttftSLO, tpotSLO float64, 
predictionBased bool) *schedulingtypes.LLMRequest { + headers := make(map[string]string) + if ttftSLO > 0 { + headers["x-ttft-slo"] = fmt.Sprintf("%f", ttftSLO) + } + if tpotSLO > 0 { + headers["x-avg-tpot-slo"] = fmt.Sprintf("%f", tpotSLO) + } + headers["x-prediction-based-scheduling"] = fmt.Sprintf("%t", predictionBased) + + return &schedulingtypes.LLMRequest{ + Headers: headers, + Body: &schedulingtypes.LLMRequestBody{ + Completions: &schedulingtypes.CompletionsRequest{ + Prompt: "test prompt", + }, + }, + } +} + +func TestSLOAwareRouter_Score(t *testing.T) { + tests := []struct { + name string + predictor *mockPredictor + strategy HeadroomStrategy + request *schedulingtypes.LLMRequest + pods []schedulingtypes.Pod + expectedScores map[string]float64 // Map of pod name to expected score + expectNil bool + }{ + { + name: "Prediction-based scheduling disabled", + predictor: &mockPredictor{}, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest(1.0, 0.05, false), // predictionBased = false + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), // 50% KV cache, 2 running, 1 waiting + createTestPod("pod2", 0.7, 3, 2), // 70% KV cache, 3 running, 2 waiting + }, + expectNil: true, + }, + { + name: "No predictor configured", + predictor: nil, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest(1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), + }, + expectNil: true, + }, + { + name: "All pods have positive headroom", + predictor: &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.5": {TTFT: 0.5, TPOT: 0.03}, // 50% KV cache + "0.6": {TTFT: 0.6, TPOT: 0.04}, // 60% KV cache + "0.3": {TTFT: 0.4, TPOT: 0.02}, // 30% KV cache + }, + }, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest(1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), // 50% KV cache + createTestPod("pod2", 0.6, 3, 2), // 60% KV cache + createTestPod("pod3", 0.3, 1, 0), // 30% KV cache + }, + // One pod should be selected with score 1, others 0 + expectedScores: map[string]float64{ + // We can't predict which one due to randomness, but exactly one should be 1 + }, + }, + { + name: "All pods have negative headroom", + predictor: &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.8": {TTFT: 1.5, TPOT: 0.08}, // 80% KV cache - high load + "0.9": {TTFT: 1.8, TPOT: 0.09}, // 90% KV cache - very high load + }, + }, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest(1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.8, 5, 3), // 80% KV cache, high load + createTestPod("pod2", 0.9, 6, 4), // 90% KV cache, very high load + }, + // One pod should still be selected even with negative headroom + expectedScores: map[string]float64{}, + }, + { + name: "Mixed positive and negative headroom", + predictor: &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.3": {TTFT: 0.5, TPOT: 0.03}, // 30% KV cache - Positive headroom + "0.9": {TTFT: 1.5, TPOT: 0.08}, // 90% KV cache - Negative headroom + }, + }, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest(1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod-positive", 0.3, 1, 0), // Low KV cache, positive headroom + createTestPod("pod-negative", 0.9, 6, 4), // High KV cache, negative headroom + }, + // With 99% probability, positive headroom pod should be selected + expectedScores: map[string]float64{}, + }, + 
{ + name: "Prediction errors - fallback to composite scoring", + predictor: &mockPredictor{ + err: fmt.Errorf("prediction failed"), + }, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest(1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), + createTestPod("pod2", 0.6, 3, 2), + }, + // Should fall back to composite-only scoring and select one pod + expectedScores: map[string]float64{ + // One pod should be selected with score 1, verified in general validation below + }, + }, + { + name: "Empty pod list", + predictor: &mockPredictor{}, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest(1.0, 0.05, true), + pods: []schedulingtypes.Pod{}, + // Should return empty scores map + expectedScores: map[string]float64{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var router *SLOAwareRouter + if tt.predictor == nil { + router = NewSLOAwareRouter(nil, tt.strategy) + } else { + router = NewSLOAwareRouter(tt.predictor, tt.strategy) + } + + scores := router.Score(context.Background(), schedulingtypes.NewCycleState(), tt.request, tt.pods) + + if tt.expectNil { + assert.Nil(t, scores, "Expected nil scores") + return + } + + assert.NotNil(t, scores, "Expected non-nil scores") + + // If we have specific expected scores, verify them + if len(tt.expectedScores) > 0 { + for _, pod := range tt.pods { + podName := pod.GetPod().NamespacedName.Name + if expectedScore, ok := tt.expectedScores[podName]; ok { + assert.InDelta(t, expectedScore, scores[pod], 0.0001, "Pod %s should have score %f", podName, expectedScore) + } + } + } + + // General validation: exactly one pod should have score 1 (selected), others should have score 0 + // This applies even when predictions fail because we fall back to composite scoring + if !tt.expectNil && len(tt.pods) > 0 && tt.predictor != nil { + selectedCount := 0 + for _, score := range scores { + if score == 1.0 { + selectedCount++ + } else { + assert.InDelta(t, 0.0, score, 0.0001, "Non-selected pods should have score 0") + } + } + assert.Equal(t, 1, selectedCount, "Exactly one pod should be selected with score 1") + } + }) + } +} + +func TestSLOAwareRouter_Strategies(t *testing.T) { + tests := []struct { + name string + strategy HeadroomStrategy + }{ + { + name: "HeadroomStrategyLeast", + strategy: HeadroomStrategyLeast, + }, + { + name: "HeadroomStrategyMost", + strategy: HeadroomStrategyMost, + }, + { + name: "HeadroomStrategyCompositeMost", + strategy: HeadroomStrategyCompositeMost, + }, + { + name: "HeadroomStrategyCompositeLeast", + strategy: HeadroomStrategyCompositeLeast, + }, + { + name: "HeadroomStrategyCompositeOnly", + strategy: HeadroomStrategyCompositeOnly, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predictor := &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.5": {TTFT: 0.5, TPOT: 0.03}, + "0.6": {TTFT: 0.6, TPOT: 0.04}, + "0.3": {TTFT: 0.4, TPOT: 0.02}, + }, + } + router := NewSLOAwareRouter(predictor, tt.strategy) + + request := createTestLLMRequest(1.0, 0.05, true) + pods := []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), + createTestPod("pod2", 0.6, 3, 2), + createTestPod("pod3", 0.3, 1, 0), + } + + scores := router.Score(context.Background(), schedulingtypes.NewCycleState(), request, pods) + + assert.NotNil(t, scores, "Expected non-nil scores for strategy %s", tt.strategy) + + // Verify exactly one pod is selected + selectedCount := 0 + for _, score := range scores { + if score == 
1.0 { + selectedCount++ + } + } + assert.Equal(t, 1, selectedCount, "Strategy %s should select exactly one pod", tt.strategy) + }) + } +} + +func TestSLOAwareRouter_SetHeadroomStrategy(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + + assert.Equal(t, HeadroomStrategyLeast, router.GetHeadroomStrategy(), "Initial strategy should be Least") + + router.SetHeadroomStrategy(HeadroomStrategyMost) + assert.Equal(t, HeadroomStrategyMost, router.GetHeadroomStrategy(), "Strategy should be updated to Most") + + router.SetHeadroomStrategy(HeadroomStrategyCompositeOnly) + assert.Equal(t, HeadroomStrategyCompositeOnly, router.GetHeadroomStrategy(), "Strategy should be updated to CompositeOnly") +} + +func TestSLOAwareRouter_TypedName(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + + tn := router.TypedName() + assert.Equal(t, "slo-aware-routing", tn.Type, "Type should be slo-aware-routing") + assert.Equal(t, "slo-aware-routing", tn.Name, "Default name should be slo-aware-routing") +} + +func TestSLOAwareRouter_WithName(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + + customName := "custom-router" + router = router.WithName(customName) + + tn := router.TypedName() + assert.Equal(t, "slo-aware-routing", tn.Type, "Type should remain slo-aware-routing") + assert.Equal(t, customName, tn.Name, "Name should be updated to custom name") +} + +func TestSLOAwareRouter_GetPodRunningRequestCount(t *testing.T) { + tests := []struct { + name string + setupRequests func(*SLOAwareRouter, schedulingtypes.Pod) + expectedCount int + }{ + { + name: "No running requests", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {}, + expectedCount: 0, + }, + { + name: "One running request", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) { + podName := types.NamespacedName{ + Name: p.GetPod().NamespacedName.Name, + Namespace: p.GetPod().NamespacedName.Namespace, + } + r.runningRequestLists[podName] = NewRequestPriorityQueue() + r.runningRequestLists[podName].Add("req1", 0.04) + }, + expectedCount: 1, + }, + { + name: "Multiple running requests", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) { + podName := types.NamespacedName{ + Name: p.GetPod().NamespacedName.Name, + Namespace: p.GetPod().NamespacedName.Namespace, + } + r.runningRequestLists[podName] = NewRequestPriorityQueue() + r.runningRequestLists[podName].Add("req1", 0.04) + r.runningRequestLists[podName].Add("req2", 0.03) + r.runningRequestLists[podName].Add("req3", 0.05) + }, + expectedCount: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + pod := createTestPod("test-pod", 0.5, 2, 1) + + tt.setupRequests(router, pod) + + count := router.getPodRunningRequestCount(pod) + assert.Equal(t, tt.expectedCount, count, "Running request count should match expected") + }) + } +} + +func TestSLOAwareRouter_GetPodMinTPOTSLO(t *testing.T) { + tests := []struct { + name string + setupRequests func(*SLOAwareRouter, schedulingtypes.Pod) + expectedSLO float64 + }{ + { + name: "No running requests", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {}, + expectedSLO: 0.0, + }, + { + name: "One running request", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) { + podName := types.NamespacedName{ + 
Name: p.GetPod().NamespacedName.Name, + Namespace: p.GetPod().NamespacedName.Namespace, + } + r.runningRequestLists[podName] = NewRequestPriorityQueue() + r.runningRequestLists[podName].Add("req1", 0.04) + }, + expectedSLO: 0.04, + }, + { + name: "Multiple running requests - should return minimum", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) { + podName := types.NamespacedName{ + Name: p.GetPod().NamespacedName.Name, + Namespace: p.GetPod().NamespacedName.Namespace, + } + r.runningRequestLists[podName] = NewRequestPriorityQueue() + // Add in any order - heap will maintain minimum at top + r.runningRequestLists[podName].Add("req1", 0.05) + r.runningRequestLists[podName].Add("req2", 0.03) // This is the minimum + r.runningRequestLists[podName].Add("req3", 0.04) + }, + expectedSLO: 0.03, // Minimum TPOT (heap guarantees this is at items[0]) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + pod := createTestPod("test-pod", 0.5, 2, 1) + + tt.setupRequests(router, pod) + + minSLO := router.getPodMinTPOTSLO(pod) + assert.InDelta(t, tt.expectedSLO, minSLO, 0.0001, "Min TPOT SLO should match expected") + }) + } +} + +func TestSLOAwareRouter_GetPrefixCacheScoreForPod(t *testing.T) { + tests := []struct { + name string + setupState func(*schedulingtypes.CycleState) + expectedScore float64 + }{ + { + name: "No prefix cache state", + setupState: func(s *schedulingtypes.CycleState) {}, + expectedScore: 0.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + + state := schedulingtypes.NewCycleState() + tt.setupState(state) + + pod := createTestPod("test-pod", 0.5, 2, 1) + + score := router.getPrefixCacheScoreForPod(context.Background(), state, pod) + assert.InDelta(t, tt.expectedScore, score, 0.0001, "Prefix cache score should match expected") + }) + } +} From 3db55c465765b14e6664412f1be7824420ac15cd Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Sat, 8 Nov 2025 00:10:22 +0000 Subject: [PATCH 3/6] Remove outdated inferencepool-resources deployment --- .../manifests/inferencepool-resources-lp.yaml | 822 ------------------ 1 file changed, 822 deletions(-) delete mode 100644 config/manifests/inferencepool-resources-lp.yaml diff --git a/config/manifests/inferencepool-resources-lp.yaml b/config/manifests/inferencepool-resources-lp.yaml deleted file mode 100644 index 8a7951e96..000000000 --- a/config/manifests/inferencepool-resources-lp.yaml +++ /dev/null @@ -1,822 +0,0 @@ -# Note: If you change this file, please also change the file used for e2e tests! 
-# -# https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/test/testdata/inferencepool-e2e.yaml - -# --- ConfigMaps --- -apiVersion: v1 -kind: ConfigMap -metadata: - name: latency-predictor-config - namespace: default -data: - LATENCY_RETRAINING_INTERVAL_SEC: "1" - LATENCY_MIN_SAMPLES_FOR_RETRAIN: "100" - LATENCY_TTFT_MODEL_PATH: "/models/ttft.joblib" - LATENCY_TPOT_MODEL_PATH: "/models/tpot.joblib" - LATENCY_TTFT_SCALER_PATH: "/models/ttft_scaler.joblib" - LATENCY_TPOT_SCALER_PATH: "/models/tpot_scaler.joblib" - LATENCY_MODEL_TYPE: "xgboost" - LATENCY_MAX_TRAINING_DATA_SIZE_PER_BUCKET: "5000" - LATENCY_QUANTILE_ALPHA: "0.9" ---- -apiVersion: v1 -kind: ConfigMap -metadata: - name: prediction-server-config - namespace: default -data: - LATENCY_MODEL_TYPE: "xgboost" - PREDICT_HOST: "0.0.0.0" - LOCAL_TTFT_MODEL_PATH: "/server_models/ttft.joblib" # Use individual storage - LOCAL_TPOT_MODEL_PATH: "/server_models/tpot.joblib" - LOCAL_TTFT_SCALER_PATH: "/server_models/ttft_scaler.joblib" - LOCAL_TPOT_SCALER_PATH: "/server_models/tpot_scaler.joblib" ---- -# --- InferencePool --- -apiVersion: inference.networking.x-k8s.io/v1alpha2 -kind: InferencePool -metadata: - name: vllm-llama3-8b-instruct -spec: - targetPortNumber: 8000 - selector: - app: vllm-llama3-8b-instruct - extensionRef: - name: vllm-llama3-8b-instruct-epp ---- -# --- EPP Service --- -apiVersion: v1 -kind: Service -metadata: - name: vllm-llama3-8b-instruct-epp - namespace: default -spec: - selector: - app: vllm-llama3-8b-instruct-epp - ports: - - name: epp-grpc - protocol: TCP - port: 9002 - targetPort: 9002 - appProtocol: http2 - - name: latency-predictor-training - protocol: TCP - port: 8000 - targetPort: 8000 - - name: latency-predictor-1 - protocol: TCP - port: 8001 - targetPort: 8001 - - name: latency-predictor-2 - protocol: TCP - port: 8002 - targetPort: 8002 - - name: latency-predictor-3 - protocol: TCP - port: 8003 - targetPort: 8003 - - name: latency-predictor-4 - protocol: TCP - port: 8004 - targetPort: 8004 - - name: latency-predictor-5 - protocol: TCP - port: 8005 - targetPort: 8005 - - name: latency-predictor-6 - protocol: TCP - port: 8006 - targetPort: 8006 - - name: latency-predictor-7 - protocol: TCP - port: 8007 - targetPort: 8007 - - name: latency-predictor-8 - protocol: TCP - port: 8008 - targetPort: 8008 - - name: latency-predictor-9 - protocol: TCP - port: 8009 - targetPort: 8009 - - name: latency-predictor-10 - protocol: TCP - port: 8010 - targetPort: 8010 - - name: prometheus - protocol: TCP - port: 9090 - targetPort: 9090 - type: LoadBalancer ---- -apiVersion: v1 -kind: ServiceAccount -metadata: - name: vllm-llama3-8b-instruct-epp - namespace: default ---- -# --- EPP Deployment with Individual Container Volumes --- -apiVersion: apps/v1 -kind: Deployment -metadata: - name: vllm-llama3-8b-instruct-epp - namespace: default - labels: - app: vllm-llama3-8b-instruct-epp -spec: - replicas: 1 # Multiple EPP pods for scaling - selector: - matchLabels: - app: vllm-llama3-8b-instruct-epp - template: - metadata: - labels: - app: vllm-llama3-8b-instruct-epp - spec: - serviceAccountName: vllm-llama3-8b-instruct-epp - # Conservatively, this timeout should mirror the longest grace period of the pods within the pool - containers: - # EPP Container - - name: epp - image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/slo-routing-epp-exp - imagePullPolicy: Always - args: - - -pool-name - - "vllm-llama3-8b-instruct" - - "-pool-namespace" - - "default" - - --pool-group - - 
"inference.networking.x-k8s.io" - - -v - - "4" - - --zap-encoder - - "json" - - -grpc-port - - "9002" - - -grpc-health-port - - "9003" - - "--config-file" - - "/config/default-plugins.yaml" - - "-enable-latency-predictor" - env: - - name: PREDICTION_SERVER_URL - value: "http://localhost:8001,http://localhost:8002,http://localhost:8003,http://localhost:8004,http://localhost:8005,http://localhost:8006,http://localhost:8007,http://localhost:8008,http://localhost:8009,http://localhost:8010" # All 10 prediction servers - - name: TRAINING_SERVER_URL - value: "http://localhost:8000" # Single training server for sending training data - - name: LATENCY_MAX_SAMPLE_SIZE - value: "10000" # Maximum sample size for latency prediction - - - - - ports: - - containerPort: 9002 - - containerPort: 9003 - - name: metrics - containerPort: 9090 - livenessProbe: - grpc: - port: 9003 - service: inference-extension - initialDelaySeconds: 5 - periodSeconds: 10 - readinessProbe: - grpc: - port: 9003 - service: inference-extension - initialDelaySeconds: 5 - periodSeconds: 10 - volumeMounts: - - name: plugins-config-volume - mountPath: "/config" - # Training Server Sidecar Container - - name: training-server - image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_training:latest - imagePullPolicy: Always - ports: - - containerPort: 8000 - name: training-port - livenessProbe: - httpGet: - path: /healthz - port: 8000 - initialDelaySeconds: 30 - periodSeconds: 20 - readinessProbe: - httpGet: - path: /readyz - port: 8000 - initialDelaySeconds: 45 - periodSeconds: 10 - resources: - requests: - cpu: "2000m" - memory: "4Gi" - limits: - cpu: "4000m" - memory: "8Gi" - envFrom: - - configMapRef: - name: latency-predictor-config - env: - - name: POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: SERVER_TYPE - value: "training" - volumeMounts: - - name: training-server-storage - mountPath: /models - # Prediction Server Sidecar Container 1 - - name: prediction-server-1 - image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest - imagePullPolicy: Always - command: ["uvicorn"] - args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8001"] - ports: - - containerPort: 8001 - name: predict-port-1 - livenessProbe: - httpGet: - path: /healthz - port: 8001 - initialDelaySeconds: 15 - periodSeconds: 15 - readinessProbe: - httpGet: - path: /readyz - port: 8001 - initialDelaySeconds: 10 - periodSeconds: 5 - failureThreshold: 10 - resources: - requests: - cpu: "500m" - memory: "1Gi" - limits: - cpu: "1000m" - memory: "2Gi" - envFrom: - - configMapRef: - name: prediction-server-config - env: - - name: PREDICT_PORT - value: "8001" - - name: POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: SERVER_TYPE - value: "prediction-1" - - name: TRAINING_SERVER_URL - value: "http://localhost:8000" - volumeMounts: - - name: prediction-server-1-storage - mountPath: /server_models - # Prediction Server Sidecar Container 2 - - name: prediction-server-2 - image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest - imagePullPolicy: Always - command: ["uvicorn"] - args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8002"] - ports: - - containerPort: 8002 - name: predict-port-2 - livenessProbe: - httpGet: - path: /healthz - port: 8002 - initialDelaySeconds: 15 - periodSeconds: 15 - readinessProbe: - httpGet: - path: /readyz - port: 8002 - initialDelaySeconds: 10 - periodSeconds: 5 - failureThreshold: 10 - 
resources: - requests: - cpu: "500m" - memory: "1Gi" - limits: - cpu: "1000m" - memory: "2Gi" - envFrom: - - configMapRef: - name: prediction-server-config - env: - - name: PREDICT_PORT - value: "8002" - - name: POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: SERVER_TYPE - value: "prediction-2" - - name: TRAINING_SERVER_URL - value: "http://localhost:8000" - volumeMounts: - - name: prediction-server-2-storage - mountPath: /server_models - # Prediction Server Sidecar Container 3 - - name: prediction-server-3 - image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest - imagePullPolicy: Always - command: ["uvicorn"] - args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8003"] - ports: - - containerPort: 8003 - name: predict-port-3 - livenessProbe: - httpGet: - path: /healthz - port: 8003 - initialDelaySeconds: 15 - periodSeconds: 15 - readinessProbe: - httpGet: - path: /readyz - port: 8003 - initialDelaySeconds: 10 - periodSeconds: 5 - failureThreshold: 10 - resources: - requests: - cpu: "500m" - memory: "1Gi" - limits: - cpu: "1000m" - memory: "2Gi" - envFrom: - - configMapRef: - name: prediction-server-config - env: - - name: PREDICT_PORT - value: "8003" - - name: POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: SERVER_TYPE - value: "prediction-3" - - name: TRAINING_SERVER_URL - value: "http://localhost:8000" - volumeMounts: - - name: prediction-server-3-storage - mountPath: /server_models - # Prediction Server Sidecar Container 4 - - name: prediction-server-4 - image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest - imagePullPolicy: Always - command: ["uvicorn"] - args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8004"] - ports: - - containerPort: 8004 - name: predict-port-4 - livenessProbe: - httpGet: - path: /healthz - port: 8004 - initialDelaySeconds: 15 - periodSeconds: 15 - readinessProbe: - httpGet: - path: /readyz - port: 8004 - initialDelaySeconds: 10 - periodSeconds: 5 - failureThreshold: 10 - resources: - requests: - cpu: "500m" - memory: "1Gi" - limits: - cpu: "1000m" - memory: "2Gi" - envFrom: - - configMapRef: - name: prediction-server-config - env: - - name: PREDICT_PORT - value: "8004" - - name: POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: SERVER_TYPE - value: "prediction-4" - - name: TRAINING_SERVER_URL - value: "http://localhost:8000" - volumeMounts: - - name: prediction-server-4-storage - mountPath: /server_models - # Prediction Server Sidecar Container 5 - - name: prediction-server-5 - image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest - imagePullPolicy: Always - command: ["uvicorn"] - args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8005"] - ports: - - containerPort: 8005 - name: predict-port-5 - livenessProbe: - httpGet: - path: /healthz - port: 8005 - initialDelaySeconds: 15 - periodSeconds: 15 - readinessProbe: - httpGet: - path: /readyz - port: 8005 - initialDelaySeconds: 10 - periodSeconds: 5 - failureThreshold: 10 - resources: - requests: - cpu: "500m" - memory: "1Gi" - limits: - cpu: "1000m" - memory: "2Gi" - envFrom: - - configMapRef: - name: prediction-server-config - env: - - name: PREDICT_PORT - value: "8005" - - name: POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: SERVER_TYPE - value: "prediction-5" - - name: TRAINING_SERVER_URL - value: "http://localhost:8000" - volumeMounts: - - name: 
prediction-server-5-storage - mountPath: /server_models - # Prediction Server Sidecar Container 6 - - name: prediction-server-6 - image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest - imagePullPolicy: Always - command: ["uvicorn"] - args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8006"] - ports: - - containerPort: 8006 - name: predict-port-6 - livenessProbe: - httpGet: - path: /healthz - port: 8006 - initialDelaySeconds: 15 - periodSeconds: 15 - readinessProbe: - httpGet: - path: /readyz - port: 8006 - initialDelaySeconds: 10 - periodSeconds: 5 - failureThreshold: 10 - resources: - requests: - cpu: "500m" - memory: "1Gi" - limits: - cpu: "1000m" - memory: "2Gi" - envFrom: - - configMapRef: - name: prediction-server-config - env: - - name: PREDICT_PORT - value: "8006" - - name: POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: SERVER_TYPE - value: "prediction-6" - - name: TRAINING_SERVER_URL - value: "http://localhost:8000" - volumeMounts: - - name: prediction-server-6-storage - mountPath: /server_models - # Prediction Server Sidecar Container 7 - - name: prediction-server-7 - image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest - imagePullPolicy: Always - command: ["uvicorn"] - args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8007"] - ports: - - containerPort: 8007 - name: predict-port-7 - livenessProbe: - httpGet: - path: /healthz - port: 8007 - initialDelaySeconds: 15 - periodSeconds: 15 - readinessProbe: - httpGet: - path: /readyz - port: 8007 - initialDelaySeconds: 10 - periodSeconds: 5 - failureThreshold: 10 - resources: - requests: - cpu: "500m" - memory: "1Gi" - limits: - cpu: "1000m" - memory: "2Gi" - envFrom: - - configMapRef: - name: prediction-server-config - env: - - name: PREDICT_PORT - value: "8007" - - name: POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: SERVER_TYPE - value: "prediction-7" - - name: TRAINING_SERVER_URL - value: "http://localhost:8000" - volumeMounts: - - name: prediction-server-7-storage - mountPath: /server_models - # Prediction Server Sidecar Container 8 - - name: prediction-server-8 - image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest - imagePullPolicy: Always - command: ["uvicorn"] - args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8008"] - ports: - - containerPort: 8008 - name: predict-port-8 - livenessProbe: - httpGet: - path: /healthz - port: 8008 - initialDelaySeconds: 15 - periodSeconds: 15 - readinessProbe: - httpGet: - path: /readyz - port: 8008 - initialDelaySeconds: 10 - periodSeconds: 5 - failureThreshold: 10 - resources: - requests: - cpu: "500m" - memory: "1Gi" - limits: - cpu: "1000m" - memory: "2Gi" - envFrom: - - configMapRef: - name: prediction-server-config - env: - - name: PREDICT_PORT - value: "8008" - - name: POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: SERVER_TYPE - value: "prediction-8" - - name: TRAINING_SERVER_URL - value: "http://localhost:8000" - volumeMounts: - - name: prediction-server-8-storage - mountPath: /server_models - # Prediction Server Sidecar Container 9 - - name: prediction-server-9 - image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest - imagePullPolicy: Always - command: ["uvicorn"] - args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8009"] - ports: - - containerPort: 8009 - name: predict-port-9 - livenessProbe: - 
httpGet: - path: /healthz - port: 8009 - initialDelaySeconds: 15 - periodSeconds: 15 - readinessProbe: - httpGet: - path: /readyz - port: 8009 - initialDelaySeconds: 10 - periodSeconds: 5 - failureThreshold: 10 - resources: - requests: - cpu: "500m" - memory: "1Gi" - limits: - cpu: "1000m" - memory: "2Gi" - envFrom: - - configMapRef: - name: prediction-server-config - env: - - name: PREDICT_PORT - value: "8009" - - name: POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: SERVER_TYPE - value: "prediction-9" - - name: TRAINING_SERVER_URL - value: "http://localhost:8000" - volumeMounts: - - name: prediction-server-9-storage - mountPath: /server_models - # Prediction Server Sidecar Container 10 - - name: prediction-server-10 - image: us-central1-docker.pkg.dev/benjaminbraun-gke-dev/slo-routing/latency_prediction:latest - imagePullPolicy: Always - command: ["uvicorn"] - args: ["prediction_server:app", "--host", "0.0.0.0", "--port", "8010"] - ports: - - containerPort: 8010 - name: predict-port-10 - livenessProbe: - httpGet: - path: /healthz - port: 8010 - initialDelaySeconds: 15 - periodSeconds: 15 - readinessProbe: - httpGet: - path: /readyz - port: 8010 - initialDelaySeconds: 10 - periodSeconds: 5 - failureThreshold: 10 - resources: - requests: - cpu: "500m" - memory: "1Gi" - limits: - cpu: "1000m" - memory: "2Gi" - envFrom: - - configMapRef: - name: prediction-server-config - env: - - name: PREDICT_PORT - value: "8010" - - name: POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: SERVER_TYPE - value: "prediction-10" - - name: TRAINING_SERVER_URL - value: "http://localhost:8000" - volumeMounts: - - name: prediction-server-10-storage - mountPath: /server_models - volumes: - - name: training-server-storage - emptyDir: - sizeLimit: "20Gi" # Dedicated volume for training server - - name: prediction-server-1-storage - emptyDir: - sizeLimit: "10Gi" # Dedicated volume for prediction server 1 - - name: prediction-server-2-storage - emptyDir: - sizeLimit: "10Gi" # Dedicated volume for prediction server 2 - - name: prediction-server-3-storage - emptyDir: - sizeLimit: "10Gi" # Dedicated volume for prediction server 3 - - name: prediction-server-4-storage - emptyDir: - sizeLimit: "10Gi" # Dedicated volume for prediction server 4 - - name: prediction-server-5-storage - emptyDir: - sizeLimit: "10Gi" # Dedicated volume for prediction server 5 - - name: prediction-server-6-storage - emptyDir: - sizeLimit: "10Gi" # Dedicated volume for prediction server 6 - - name: prediction-server-7-storage - emptyDir: - sizeLimit: "10Gi" # Dedicated volume for prediction server 7 - - name: prediction-server-8-storage - emptyDir: - sizeLimit: "10Gi" # Dedicated volume for prediction server 8 - - name: prediction-server-9-storage - emptyDir: - sizeLimit: "10Gi" # Dedicated volume for prediction server 9 - - name: prediction-server-10-storage - emptyDir: - sizeLimit: "10Gi" # Dedicated volume for prediction server 10 - - name: plugins-config-volume - configMap: - name: plugins-config ---- -apiVersion: v1 -kind: ConfigMap -metadata: - name: plugins-config - namespace: default -data: - default-plugins.yaml: | - apiVersion: inference.networking.x-k8s.io/v1alpha1 - kind: EndpointPickerConfig - plugins: - - type: queue-scorer - - type: kv-cache-utilization-scorer - - type: prefix-cache-scorer - - type: slo-aware-routing - - type: slo-aware-profile-handler - - type: max-score-picker - schedulingProfiles: - - name: prefix - plugins: - - pluginRef: prefix-cache-scorer - - name: default - plugins: 
- - pluginRef: slo-aware-routing - weight: 0 - - pluginRef: queue-scorer - - pluginRef: kv-cache-utilization-scorer - - pluginRef: max-score-picker - - name: slo - plugins: - - pluginRef: slo-aware-routing - - pluginRef: max-score-picker - ---- -# --- RBAC --- -kind: Role -apiVersion: rbac.authorization.k8s.io/v1 -metadata: - name: pod-read - namespace: default -rules: -- apiGroups: [ "inference.networking.x-k8s.io" ] - resources: [ "inferenceobjectives", "inferencepools" ] - verbs: [ "get", "watch", "list" ] -- apiGroups: [ "inference.networking.k8s.io" ] - resources: [ "inferencepools" ] - verbs: [ "get", "watch", "list" ] -- apiGroups: [ "" ] - resources: [ "pods" ] - verbs: [ "get", "watch", "list" ] ---- -kind: RoleBinding -apiVersion: rbac.authorization.k8s.io/v1 -metadata: - name: pod-read-binding - namespace: default -subjects: -- kind: ServiceAccount - name: vllm-llama3-8b-instruct-epp - namespace: default -roleRef: - apiGroup: rbac.authorization.k8s.io - kind: Role - name: pod-read ---- -kind: ClusterRole -apiVersion: rbac.authorization.k8s.io/v1 -metadata: - name: auth-reviewer -rules: -- apiGroups: - - authentication.k8s.io - resources: - - tokenreviews - verbs: - - create -- apiGroups: - - authorization.k8s.io - resources: - - subjectaccessreviews - verbs: - - create ---- -kind: ClusterRoleBinding -apiVersion: rbac.authorization.k8s.io/v1 -metadata: - name: auth-reviewer-binding -subjects: -- kind: ServiceAccount - name: vllm-llama3-8b-instruct-epp - namespace: default -roleRef: - apiGroup: rbac.authorization.k8s.io - kind: ClusterRole - name: auth-reviewer \ No newline at end of file From ed3b0cd2bfc188077682c450c6f7f04f0ecc150c Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Sat, 8 Nov 2025 00:25:02 +0000 Subject: [PATCH 4/6] Add requestcontrol_hooks test and fix boilerplate in slo_aware_router to be consistent --- .../plugins/multi/slo_aware_router/config.go | 8 +- .../plugins/multi/slo_aware_router/headers.go | 8 +- .../latencypredictor_helper.go | 8 +- .../multi/slo_aware_router/prediction.go | 8 +- .../slo_aware_router/requestcontrol_hooks.go | 16 + .../requestcontrol_hooks_test.go | 953 ++++++++++++++++++ .../slo_aware_router/running_request_queue.go | 16 + .../plugins/multi/slo_aware_router/sampler.go | 16 +- .../multi/slo_aware_router/scorer_test.go | 20 +- .../multi/slo_aware_router/selection.go | 8 +- .../plugins/multi/slo_aware_router/types.go | 8 +- 11 files changed, 1047 insertions(+), 22 deletions(-) create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go index 046870948..fcb4b7223 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go @@ -1,5 +1,6 @@ /* -© 2025 The Kubernetes Authors. +Copyright 2025 The Kubernetes Authors. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -7,7 +8,10 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License. +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ // Package requestcontrol contains helpers to decouple latency-predictor logic. diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go index 2588b3104..8574ec41b 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go @@ -1,5 +1,6 @@ /* -© 2025 The Kubernetes Authors. +Copyright 2025 The Kubernetes Authors. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -7,7 +8,10 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License. +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ // Package requestcontrol contains helpers to decouple latency-predictor logic. diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go index 8961bda44..aa47f93c9 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go @@ -1,5 +1,6 @@ /* -© 2025 The Kubernetes Authors. +Copyright 2025 The Kubernetes Authors. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -7,7 +8,10 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License. +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ // Package requestcontrol contains helpers to decouple latency-predictor logic. diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go index 7a857fa0d..0c2cfa0a9 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go @@ -1,5 +1,6 @@ /* -© 2025 The Kubernetes Authors. +Copyright 2025 The Kubernetes Authors. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -7,7 +8,10 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License. +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ // Package requestcontrol contains helpers to decouple latency-predictor logic. diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go index 072e2e2fb..789a9d51c 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go @@ -1,3 +1,19 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package slo_aware_router import ( diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go new file mode 100644 index 000000000..cc40612df --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go @@ -0,0 +1,953 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package slo_aware_router + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/go-logr/logr" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/types" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" +) + +// Helper functions + +func createTestSchedulingResult(pod *backend.Pod, kvUsage float64, runningQueue int, waitingQueue int) *schedulingtypes.SchedulingResult { + + mockPod := createTestPod(pod.NamespacedName.Name, kvUsage, runningQueue, waitingQueue) + + return &schedulingtypes.SchedulingResult{ + PrimaryProfileName: "default", + ProfileResults: map[string]*schedulingtypes.ProfileRunResult{ + "default": { + TargetPods: []schedulingtypes.Pod{mockPod}, + }, + }, + } +} + +func createTestRouter() *SLOAwareRouter { + return &SLOAwareRouter{ + sloContextStore: sync.Map{}, + runningRequestLists: make(map[types.NamespacedName]*RequestPriorityQueue), + latencypredictor: nil, + } +} + +// Test cases + +func TestNewSLORequestContext(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + + ctx := NewSLORequestContext(request) + + assert.NotNil(t, ctx) + assert.Equal(t, *request, ctx.SchedulingRequest) + assert.NotNil(t, ctx.LastSeenMetrics) + assert.NotNil(t, ctx.PrefixCacheScoresForPods) + assert.NotNil(t, ctx.PredictedTTFTForScheduling) + assert.NotNil(t, ctx.PredictedTPOTForScheduling) + assert.Empty(t, ctx.LastSeenMetrics) + assert.Empty(t, ctx.PrefixCacheScoresForPods) +} + +func TestSLOAwareRouter_SetAndGetSLOContext(t *testing.T) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + sloCtx := NewSLORequestContext(request) + + // Set context + router.setSLOContextForRequest(request, sloCtx) + + // Get context + retrievedCtx, err := router.getSLOContextForRequest(request) + + require.NoError(t, err) + assert.Equal(t, sloCtx, retrievedCtx) +} + +func TestSLOAwareRouter_GetSLOContext_NotFound(t *testing.T) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + + // Try to get context that doesn't exist + ctx, err := router.getSLOContextForRequest(request) + + assert.Error(t, err) + assert.Nil(t, ctx) + assert.Contains(t, err.Error(), "SLO context not found") +} + +func TestSLOAwareRouter_DeleteSLOContext(t *testing.T) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + sloCtx := NewSLORequestContext(request) + + // Set and then delete context + router.setSLOContextForRequest(request, sloCtx) + router.deleteSLOContextForRequest(request) + + // Verify it's deleted + ctx, err := router.getSLOContextForRequest(request) + assert.Error(t, err) + assert.Nil(t, ctx) +} + +func TestSLOAwareRouter_PreRequest_NoSchedulingResult(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + + // Call PreRequest with nil scheduling result + router.PreRequest(ctx, request, nil) + + // Should not create SLO context + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_PreRequest_EmptySchedulingResult(t *testing.T) { 
+ router := createTestRouter() + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + + schedulingResult := &schedulingtypes.SchedulingResult{ + ProfileResults: map[string]*schedulingtypes.ProfileRunResult{}, + } + + // Call PreRequest with empty scheduling result + router.PreRequest(ctx, request, schedulingResult) + + // Should not create SLO context + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_PreRequest_Success(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create and set initial SLO context + sloCtx := NewSLORequestContext(request) + sloCtx.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request, sloCtx) + + // Initialize the request priority queue + router.runningRequestLists[pod.GetPod().NamespacedName] = NewRequestPriorityQueue() + + beforeTime := time.Now() + router.PreRequest(ctx, request, schedulingResult) + afterTime := time.Now() + + // Verify SLO context was updated + retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.Equal(t, pod.GetPod(), retrievedCtx.TargetPod) + assert.Equal(t, schedulingResult, retrievedCtx.SchedulingResult) + assert.True(t, retrievedCtx.RequestReceivedTimestamp.After(beforeTime) || + retrievedCtx.RequestReceivedTimestamp.Equal(beforeTime)) + assert.True(t, retrievedCtx.RequestReceivedTimestamp.Before(afterTime) || + retrievedCtx.RequestReceivedTimestamp.Equal(afterTime)) +} + +func TestSLOAwareRouter_PreRequest_GeneratesRequestID(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("", 100, 50, true) + request.Headers[requtil.RequestIdHeaderKey] = "" // Explicitly empty + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create and set initial SLO context + sloCtx := NewSLORequestContext(request) + sloCtx.AvgTPOTSLO = 50 + + // Since request ID is empty initially, we need to handle this + // The PreRequest should generate a new ID, so let's test that + router.PreRequest(ctx, request, schedulingResult) + + // Request ID should now be set + assert.NotEmpty(t, request.Headers[requtil.RequestIdHeaderKey]) + // Verify it's a valid UUID format + _, err := uuid.Parse(request.Headers[requtil.RequestIdHeaderKey]) + assert.NoError(t, err, "Generated request ID should be a valid UUID") +} + +func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create and set initial SLO context + sloCtx := NewSLORequestContext(request) + sloCtx.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request, sloCtx) + + // PreRequest should create the queue + router.PreRequest(ctx, request, schedulingResult) + + // Verify queue was created and request was added + queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] + assert.True(t, exists, "Queue should be created for pod") + assert.NotNil(t, queue) +} + +func TestSLOAwareRouter_PreRequest_QueueAlreadyExists(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + pod := 
createTestPod("test-pod", 1, 1, 1) + request1 := createTestLLMRequest("test-id-1", 100, 50, true) + request2 := createTestLLMRequest("test-id-2", 100, 50, true) + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create and set initial SLO contexts + sloCtx1 := NewSLORequestContext(request1) + sloCtx1.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request1, sloCtx1) + + sloCtx2 := NewSLORequestContext(request2) + sloCtx2.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request2, sloCtx2) + + // Add first request + router.PreRequest(ctx, request1, schedulingResult) + + // Add second request to same pod + router.PreRequest(ctx, request2, schedulingResult) + + // Verify both are in the same queue + queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] + assert.True(t, exists) + assert.NotNil(t, queue) +} + +func TestSLOAwareRouter_ResponseReceived_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic and should return early + router.ResponseReceived(ctx, request, response, pod.GetPod()) + + // Context should still exist + _, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) +} + +func TestSLOAwareRouter_ResponseReceived_NoPod(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic with nil pod + router.ResponseReceived(ctx, request, response, nil) + + // Predictor should not be called + +} + +func TestSLOAwareRouter_ResponseReceived_NoContext(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Don't set SLO context + router.ResponseReceived(ctx, request, response, pod.GetPod()) + + // Should handle missing context gracefully + +} + +func TestSLOAwareRouter_ResponseStreaming_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic and should return early + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // Context should still exist + _, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) +} +func TestSLOAwareRouter_ResponseStreaming_FirstToken(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + schedulingResult := 
createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + sloCtx := NewSLORequestContext(request) + sloCtx.RequestReceivedTimestamp = time.Now() + sloCtx.SchedulingResult = schedulingResult + sloCtx.SchedulingRequest = *request + sloCtx.TTFTSLO = 100 + sloCtx.AvgTPOTSLO = 50 + sloCtx.IncomingModelName = "test-model" + sloCtx.PredictedTTFT = 80.0 + sloCtx.AvgPredictedTPOT = 30.0 + // Populate last-seen metrics for both scheduling profiles + sloCtx.LastSeenMetrics["prefill"] = &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 1, + RunningQueueSize: 1, + } + sloCtx.LastSeenMetrics["default"] = &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 1, + RunningQueueSize: 1, + } + router.setSLOContextForRequest(request, sloCtx) + + // Initialize the queue and add the request + queue := NewRequestPriorityQueue() + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + + beforeTime := time.Now() + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + afterTime := time.Now() + + // Verify first token timestamp was set + retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.True(t, retrievedCtx.LastTokenTimestamp.After(beforeTime) || + retrievedCtx.LastTokenTimestamp.Equal(beforeTime)) + assert.True(t, retrievedCtx.LastTokenTimestamp.Before(afterTime) || + retrievedCtx.LastTokenTimestamp.Equal(afterTime)) +} + +func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + sloCtx := NewSLORequestContext(request) + sloCtx.RequestReceivedTimestamp = time.Now() + sloCtx.SchedulingResult = schedulingResult + sloCtx.SchedulingRequest = *request + sloCtx.TTFTSLO = 100 + sloCtx.AvgTPOTSLO = 50 + sloCtx.IncomingModelName = "test-model" + sloCtx.PredictedTTFT = 80.0 + sloCtx.AvgPredictedTPOT = 30.0 + // Populate last-seen metrics for both scheduling profiles + sloCtx.LastSeenMetrics["prefill"] = &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 1, + RunningQueueSize: 1, + } + sloCtx.LastSeenMetrics["default"] = &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 1, + RunningQueueSize: 1, + } + firstTokenTime := time.Now().Add(-100 * time.Millisecond) + + router.setSLOContextForRequest(request, sloCtx) + + // Initialize the queue and add the request + queue := NewRequestPriorityQueue() + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // Verify token timestamp was updated + retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.True(t, retrievedCtx.LastTokenTimestamp.After(firstTokenTime)) +} + +func TestSLOAwareRouter_ResponseComplete_QueueNotFound(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := 
NewSLORequestContext(request) + sloCtx.IncomingModelName = "test-model" + sloCtx.TargetPod = pod.GetPod() // Set the target pod so completion handling sees a valid pod + router.setSLOContextForRequest(request, sloCtx) + + // Create an EMPTY queue (not nil, but empty) to test queue.Remove behavior + router.runningRequestLists[pod.GetPod().NamespacedName] = NewRequestPriorityQueue() + + // Should handle gracefully when request is not in queue + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Context should be deleted + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} +func TestSLOAwareRouter_ResponseStreaming_NoContext(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Don't set SLO context - should handle gracefully + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // Should not panic + +} + +func TestSLOAwareRouter_ResponseComplete_Success(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Create queue and add request + queue := NewRequestPriorityQueue() + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) + + sloCtx := NewSLORequestContext(request) + sloCtx.TTFT = 80 + sloCtx.AvgTPOT = 30 + sloCtx.PredictedTTFT = 85 + sloCtx.AvgPredictedTPOT = 32 + sloCtx.TTFTSLO = 100 + sloCtx.AvgTPOTSLO = 50 + sloCtx.IncomingModelName = "incoming-model" + router.setSLOContextForRequest(request, sloCtx) + + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Verify context was deleted + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) + + // Verify request was removed from queue + assert.Equal(t, 0, queue.Len()) +} + +func TestSLOAwareRouter_ResponseComplete_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Context should still exist (deletion happens only with predictor) + _, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) +} + +func TestSLOAwareRouter_ResponseComplete_NoPod(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic with nil pod + router.ResponseComplete(ctx, request, response, nil) + + // Context should still exist (deletion happens only with a valid pod) + _, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) +} + +func TestSLOAwareRouter_ResponseComplete_NoContext(t 
*testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Don't set SLO context - should handle gracefully + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Should not panic + +} + +func TestSLOAwareRouter_ResponseComplete_WithMetrics(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Create queue + queue := NewRequestPriorityQueue() + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) + + sloCtx := NewSLORequestContext(request) + sloCtx.TTFT = 80 + sloCtx.AvgTPOT = 30 + sloCtx.PredictedTTFT = 85 + sloCtx.AvgPredictedTPOT = 32 + sloCtx.TTFTSLO = 100 + sloCtx.AvgTPOTSLO = 50 + sloCtx.IncomingModelName = "incoming-model" + router.setSLOContextForRequest(request, sloCtx) + + // Should record metrics without panicking + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Verify cleanup + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_ResponseComplete_NoSLOs(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test-id", 0, 0, true) // No SLOs + response := &requestcontrol.Response{} + + // Create queue + queue := NewRequestPriorityQueue() + router.runningRequestLists[pod.GetPod().NamespacedName] = queue + queue.Add(request.Headers[requtil.RequestIdHeaderKey], 0) + + sloCtx := NewSLORequestContext(request) + sloCtx.TTFT = 80 + sloCtx.AvgTPOT = 30 + sloCtx.IncomingModelName = "test-model" + router.setSLOContextForRequest(request, sloCtx) + + // Should handle missing SLOs gracefully + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Verify cleanup + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_CheckPredictor_NilPod(t *testing.T) { + router := createTestRouter() + logger := logr.Discard() + + result := router.CheckPredictor(logger, nil) + + assert.False(t, result) +} + +func TestSLOAwareRouter_CheckPredictor_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + logger := logr.Discard() + pod := createTestPod("test-pod", 1, 1, 1) + + result := router.CheckPredictor(logger, pod.GetPod()) + + assert.False(t, result) +} + +func TestSLOAwareRouter_CheckPredictor_Success(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + logger := logr.Discard() + pod := createTestPod("test-pod", 1, 1, 1) + + result := router.CheckPredictor(logger, pod.GetPod()) + + assert.True(t, result) +} + +func TestSLORequestContext_Fields(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := NewSLORequestContext(request) + + // Test all field initialization + assert.NotNil(t, ctx.LastSeenMetrics) + assert.NotNil(t, ctx.PrefixCacheScoresForPods) + assert.NotNil(t, 
ctx.PredictedTTFTForScheduling) + assert.NotNil(t, ctx.PredictedTPOTForScheduling) + assert.Empty(t, ctx.TPOTObservations) + assert.Empty(t, ctx.PredictedTPOTObservations) + assert.Zero(t, ctx.GeneratedTokenCount) + assert.Zero(t, ctx.TTFT) + assert.Zero(t, ctx.AvgTPOT) + assert.Nil(t, ctx.TargetPod) + assert.Nil(t, ctx.SchedulingResult) + assert.Nil(t, ctx.TokenSampler) +} + +func TestSLORequestContext_UpdateMetrics(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := NewSLORequestContext(request) + + // Add some metrics + metricsState := &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 3, + } + ctx.LastSeenMetrics["test-pod"] = metricsState + + assert.Len(t, ctx.LastSeenMetrics, 1) + assert.Equal(t, 0.5, ctx.LastSeenMetrics["test-pod"].KVCacheUsagePercent) + assert.Equal(t, 3, ctx.LastSeenMetrics["test-pod"].WaitingQueueSize) +} + +func TestSLORequestContext_PredictionData(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := NewSLORequestContext(request) + + // Set prediction data + ctx.PredictedTTFTForScheduling["pod1"] = 80.0 + ctx.PredictedTPOTForScheduling["pod1"] = 30.0 + ctx.PredictedTTFTForScheduling["pod2"] = 90.0 + ctx.PredictedTPOTForScheduling["pod2"] = 35.0 + + assert.Len(t, ctx.PredictedTTFTForScheduling, 2) + assert.Len(t, ctx.PredictedTPOTForScheduling, 2) + assert.Equal(t, 80.0, ctx.PredictedTTFTForScheduling["pod1"]) + assert.Equal(t, 30.0, ctx.PredictedTPOTForScheduling["pod1"]) +} + +func TestSLORequestContext_PrefixCacheScores(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := NewSLORequestContext(request) + + // Set prefix cache scores + ctx.PrefixCacheScoresForPods["pod1"] = 0.8 + ctx.PrefixCacheScoresForPods["pod2"] = 0.6 + ctx.PrefixCacheScoresForPods["pod3"] = 0.9 + + assert.Len(t, ctx.PrefixCacheScoresForPods, 3) + assert.Equal(t, 0.8, ctx.PrefixCacheScoresForPods["pod1"]) + assert.Equal(t, 0.9, ctx.PrefixCacheScoresForPods["pod3"]) +} + +func TestSLOAwareRouter_ConcurrentContextAccess(t *testing.T) { + router := createTestRouter() + + // Test concurrent access to context store + var wg sync.WaitGroup + numGoroutines := 100 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + requestID := uuid.New().String() + request := createTestLLMRequest(requestID, 100, 50, true) + sloCtx := NewSLORequestContext(request) + + // Set context + router.setSLOContextForRequest(request, sloCtx) + + // Get context + retrievedCtx, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) + assert.NotNil(t, retrievedCtx) + + // Delete context + router.deleteSLOContextForRequest(request) + }(i) + } + + wg.Wait() +} + +func TestSLOAwareRouter_MultipleRequests_SamePod(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + + request1 := createTestLLMRequest("test-id-1", 100, 50, true) + request2 := createTestLLMRequest("test-id-2", 100, 50, true) + request3 := createTestLLMRequest("test-id-3", 100, 50, true) + + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create and set SLO contexts + for _, req := range []*schedulingtypes.LLMRequest{request1, request2, request3} { + sloCtx := NewSLORequestContext(req) + sloCtx.AvgTPOTSLO = 50 + router.setSLOContextForRequest(req, sloCtx) + } + + // Add all requests + router.PreRequest(ctx, request1, schedulingResult) + router.PreRequest(ctx, request2, schedulingResult) + 
router.PreRequest(ctx, request3, schedulingResult) + + // Verify queue has all requests + queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] + assert.True(t, exists) + assert.NotNil(t, queue) +} + +func TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create initial context + sloCtx := NewSLORequestContext(request) + sloCtx.AvgTPOTSLO = 50 + sloCtx.IncomingModelName = "test-model" + router.setSLOContextForRequest(request, sloCtx) + + // 1. PreRequest + router.PreRequest(ctx, request, schedulingResult) + + // Verify context exists + retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.NotNil(t, retrievedCtx.TargetPod) + + // 2. ResponseReceived + router.ResponseReceived(ctx, request, response, pod.GetPod()) + + // 3. ResponseStreaming (first token) + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // 4. ResponseStreaming (subsequent tokens) + retrievedCtx, _ = router.getSLOContextForRequest(request) + retrievedCtx.TTFT = 100 // Mark first token received + router.setSLOContextForRequest(request, retrievedCtx) + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // 5. ResponseComplete + retrievedCtx, _ = router.getSLOContextForRequest(request) + retrievedCtx.TTFT = 80 + retrievedCtx.AvgTPOT = 30 + router.setSLOContextForRequest(request, retrievedCtx) + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Verify context was cleaned up + _, err = router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + + pod1 := createTestPod("test-pod-1", 1, 1, 1) + pod2 := createTestPod("test-pod-2", 1, 1, 1) + + request1 := createTestLLMRequest("test-id-1", 100, 50, true) + request2 := createTestLLMRequest("test-id-2", 100, 50, true) + + schedulingResult1 := createTestSchedulingResult(pod1.GetPod(), 1, 1, 1) + schedulingResult2 := createTestSchedulingResult(pod2.GetPod(), 1, 1, 1) + + // Create and set SLO contexts + sloCtx1 := NewSLORequestContext(request1) + sloCtx1.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request1, sloCtx1) + + sloCtx2 := NewSLORequestContext(request2) + sloCtx2.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request2, sloCtx2) + + // Add requests to different pods + router.PreRequest(ctx, request1, schedulingResult1) + router.PreRequest(ctx, request2, schedulingResult2) + + // Verify separate queues were created + queue1, exists1 := router.runningRequestLists[pod1.GetPod().NamespacedName] + queue2, exists2 := router.runningRequestLists[pod2.GetPod().NamespacedName] + + assert.True(t, exists1) + assert.True(t, exists2) + assert.NotNil(t, queue1) + assert.NotNil(t, queue2) + assert.NotEqual(t, queue1, queue2) +} + +func TestSLORequestContext_SLOValidation(t *testing.T) { + tests := []struct { + name string + ttftSLO float64 + tpotSLO float64 + expectSLOs bool + }{ + { + name: "Both SLOs set", + ttftSLO: 100, + tpotSLO: 50, + expectSLOs: true, + }, + { + name: "No SLOs", + ttftSLO: 0, + tpotSLO: 0, + expectSLOs: false, + }, + { + name: "Only TTFT SLO", + 
ttftSLO: 100, + tpotSLO: 0, + expectSLOs: false, + }, + { + name: "Only TPOT SLO", + ttftSLO: 0, + tpotSLO: 50, + expectSLOs: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := createTestLLMRequest("test-id", tt.ttftSLO, tt.tpotSLO, true) + ctx := NewSLORequestContext(request) + ctx.TTFTSLO = tt.ttftSLO + ctx.AvgTPOTSLO = tt.tpotSLO + + hasBothSLOs := ctx.TTFTSLO > 0 && ctx.AvgTPOTSLO > 0 + assert.Equal(t, tt.expectSLOs, hasBothSLOs) + }) + } +} + +// Benchmark tests + +func BenchmarkSLOAwareRouter_PreRequest(b *testing.B) { + router := createTestRouter() + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + requestID := uuid.New().String() + request := createTestLLMRequest(requestID, 100, 50, true) + sloCtx := NewSLORequestContext(request) + sloCtx.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request, sloCtx) + router.PreRequest(ctx, request, schedulingResult) + } +} + +func BenchmarkSLOAwareRouter_ContextOperations(b *testing.B) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + sloCtx := NewSLORequestContext(request) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + router.setSLOContextForRequest(request, sloCtx) + _, _ = router.getSLOContextForRequest(request) + router.deleteSLOContextForRequest(request) + } +} + +func BenchmarkSLORequestContext_Creation(b *testing.B) { + request := createTestLLMRequest("test", 100, 50, true) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = NewSLORequestContext(request) + } +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go index 1199be641..ce1e997b0 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go @@ -1,3 +1,19 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package slo_aware_router import ( diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go index 758ae401b..bdeca3037 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go @@ -1,4 +1,18 @@ -// NewTokenSampler creates a new sampler with deterministic seeding +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ package slo_aware_router diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go index d4637b96f..da073ff65 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go @@ -27,6 +27,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" ) @@ -115,8 +116,9 @@ func createTestPod(name string, kvCacheUsage float64, runningQueueSize, waitingQ } } -func createTestLLMRequest(ttftSLO, tpotSLO float64, predictionBased bool) *schedulingtypes.LLMRequest { +func createTestLLMRequest(reqID string, ttftSLO, tpotSLO float64, predictionBased bool) *schedulingtypes.LLMRequest { headers := make(map[string]string) + headers[requtil.RequestIdHeaderKey] = reqID if ttftSLO > 0 { headers["x-ttft-slo"] = fmt.Sprintf("%f", ttftSLO) } @@ -149,7 +151,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { name: "Prediction-based scheduling disabled", predictor: &mockPredictor{}, strategy: HeadroomStrategyLeast, - request: createTestLLMRequest(1.0, 0.05, false), // predictionBased = false + request: createTestLLMRequest("test", 1.0, 0.05, false), // predictionBased = false pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), // 50% KV cache, 2 running, 1 waiting createTestPod("pod2", 0.7, 3, 2), // 70% KV cache, 3 running, 2 waiting @@ -160,7 +162,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { name: "No predictor configured", predictor: nil, strategy: HeadroomStrategyLeast, - request: createTestLLMRequest(1.0, 0.05, true), + request: createTestLLMRequest("test", 1.0, 0.05, true), pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), }, @@ -176,7 +178,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { }, }, strategy: HeadroomStrategyLeast, - request: createTestLLMRequest(1.0, 0.05, true), + request: createTestLLMRequest("test", 1.0, 0.05, true), pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), // 50% KV cache createTestPod("pod2", 0.6, 3, 2), // 60% KV cache @@ -196,7 +198,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { }, }, strategy: HeadroomStrategyLeast, - request: createTestLLMRequest(1.0, 0.05, true), + request: createTestLLMRequest("test", 1.0, 0.05, true), pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.8, 5, 3), // 80% KV cache, high load createTestPod("pod2", 0.9, 6, 4), // 90% KV cache, very high load @@ -213,7 +215,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { }, }, strategy: HeadroomStrategyLeast, - request: createTestLLMRequest(1.0, 0.05, true), + request: createTestLLMRequest("test", 1.0, 0.05, true), pods: []schedulingtypes.Pod{ createTestPod("pod-positive", 0.3, 1, 0), // Low KV cache, 
positive headroom createTestPod("pod-negative", 0.9, 6, 4), // High KV cache, negative headroom @@ -227,7 +229,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { err: fmt.Errorf("prediction failed"), }, strategy: HeadroomStrategyLeast, - request: createTestLLMRequest(1.0, 0.05, true), + request: createTestLLMRequest("test", 1.0, 0.05, true), pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), createTestPod("pod2", 0.6, 3, 2), @@ -241,7 +243,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { name: "Empty pod list", predictor: &mockPredictor{}, strategy: HeadroomStrategyLeast, - request: createTestLLMRequest(1.0, 0.05, true), + request: createTestLLMRequest("test", 1.0, 0.05, true), pods: []schedulingtypes.Pod{}, // Should return empty scores map expectedScores: map[string]float64{}, @@ -331,7 +333,7 @@ func TestSLOAwareRouter_Strategies(t *testing.T) { } router := NewSLOAwareRouter(predictor, tt.strategy) - request := createTestLLMRequest(1.0, 0.05, true) + request := createTestLLMRequest("test", 1.0, 0.05, true) pods := []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), createTestPod("pod2", 0.6, 3, 2), diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go index 34618ce19..eeab50433 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go @@ -1,5 +1,6 @@ /* -© 2025 The Kubernetes Authors. +Copyright 2025 The Kubernetes Authors. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -7,7 +8,10 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License. +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ // Package requestcontrol contains helpers to decouple latency-predictor logic. diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go index 036b5ee76..8030866d8 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go @@ -1,5 +1,6 @@ /* -© 2025 The Kubernetes Authors. +Copyright 2025 The Kubernetes Authors. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -7,7 +8,10 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License. +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ // Package requestcontrol contains helpers to decouple latency-predictor logic. 
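Note: the table-driven SLO-validation cases added in the test file above all reduce to a single gate: a request only qualifies for SLO-aware treatment when both its TTFT and average-TPOT SLOs are set. Below is a minimal, self-contained Go sketch of that gate, reusing the x-ttft-slo and x-tpot-slo headers the tests populate; the helper names are illustrative only, not the plugin's exported API.

package main

import (
	"fmt"
	"strconv"
)

// parseSLO reads an SLO value from a request header. An absent or
// malformed header counts as "no SLO set" and yields zero.
func parseSLO(headers map[string]string, key string) float64 {
	v, err := strconv.ParseFloat(headers[key], 64)
	if err != nil || v < 0 {
		return 0
	}
	return v
}

// hasBothSLOs mirrors the predicate TestSLORequestContext_SLOValidation
// exercises: only requests carrying both SLOs qualify.
func hasBothSLOs(headers map[string]string) bool {
	return parseSLO(headers, "x-ttft-slo") > 0 && parseSLO(headers, "x-tpot-slo") > 0
}

func main() {
	for _, h := range []map[string]string{
		{"x-ttft-slo": "100", "x-tpot-slo": "50"}, // both set  -> true
		{"x-ttft-slo": "100"},                     // TTFT only -> false
		{"x-tpot-slo": "50"},                      // TPOT only -> false
		{},                                        // neither   -> false
	} {
		fmt.Println(hasBothSLOs(h))
	}
}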
From 37fe013335cefd2ac8fe8fb514b482fef753cd2e Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Sat, 8 Nov 2025 00:41:13 +0000 Subject: [PATCH 5/6] Remove unnecessary test in requestcontrol_hooks_test --- .../requestcontrol_hooks_test.go | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go index cc40612df..71903b626 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go @@ -178,29 +178,6 @@ func TestSLOAwareRouter_PreRequest_Success(t *testing.T) { retrievedCtx.RequestReceivedTimestamp.Equal(afterTime)) } -func TestSLOAwareRouter_PreRequest_GeneratesRequestID(t *testing.T) { - router := createTestRouter() - ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("", 100, 50, true) - request.Headers[requtil.RequestIdHeaderKey] = "" // Explicitly empty - schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) - - // Create and set initial SLO context - sloCtx := NewSLORequestContext(request) - sloCtx.AvgTPOTSLO = 50 - - // Since request ID is empty initially, we need to handle this - // The PreRequest should generate a new ID, so let's test that - router.PreRequest(ctx, request, schedulingResult) - - // Request ID should now be set - assert.NotEmpty(t, request.Headers[requtil.RequestIdHeaderKey]) - // Verify it's a valid UUID format - _, err := uuid.Parse(request.Headers[requtil.RequestIdHeaderKey]) - assert.NoError(t, err, "Generated request ID should be a valid UUID") -} - func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) { router := createTestRouter() ctx := context.Background() From b2a7d45ec560fc28287c87a6bdb8fe917cd61f69 Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Mon, 10 Nov 2025 21:20:22 +0000 Subject: [PATCH 6/6] Fix streaming hook being called one final time after the request completes, add predictor check to the beginning of each requestcontrol hook --- pkg/epp/requestcontrol/director.go | 5 +++-- .../slo_aware_router/requestcontrol_hooks.go | 17 ++++++++++------- .../requestcontrol_hooks_test.go | 15 +++++++++++++++ 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index f6f7deebe..55eec0694 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -269,8 +269,9 @@ func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *hand logger := log.FromContext(ctx).WithValues("stage", "bodyChunk") logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk") response := &Response{ - RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], - Headers: reqCtx.Response.Headers, + RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], + Headers: reqCtx.Response.Headers, + EndOfStream: reqCtx.ResponseComplete, } d.runResponseStreamingPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod) diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go index 789a9d51c..f865bbeb3 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go +++ 
b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go @@ -113,6 +113,9 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype } targetPod := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName].TargetPods[0].GetPod() + if !t.CheckPredictor(logger, targetPod) { + return + } podName := types.NamespacedName{ Name: targetPod.NamespacedName.Name, @@ -153,6 +156,10 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { logger := log.FromContext(ctx) + if !t.CheckPredictor(logger, targetPod) { + return + } + id := request.Headers[requtil.RequestIdHeaderKey] sloCtx, err := t.getSLOContextForRequest(request) @@ -161,10 +168,6 @@ func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *scheduli return } - if !t.CheckPredictor(logger, targetPod) { - return - } - if err := ProcessHeaderForLatencyPrediction(ctx, t.latencypredictor, sloCtx); err != nil { logger.V(logutil.DEBUG).Error(err, "ProcessHeader in latencypredictor failed") } @@ -173,7 +176,7 @@ func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *scheduli func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) { logger := log.FromContext(ctx) - if !t.CheckPredictor(logger, pod) { + if !t.CheckPredictor(logger, pod) || response.EndOfStream { return } @@ -248,11 +251,11 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *scheduli func (t *SLOAwareRouter) CheckPredictor(logger logr.Logger, targetPod *backend.Pod) bool { if targetPod == nil { - logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping PostResponse because no target pod was provided.") + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because no target pod was provided.") return false } if t.latencypredictor == nil { - logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping PostResponse because predictor missing") + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because predictor missing") return false } return true diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go index 71903b626..96999af2f 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go @@ -150,6 +150,9 @@ func TestSLOAwareRouter_PreRequest_EmptySchedulingResult(t *testing.T) { func TestSLOAwareRouter_PreRequest_Success(t *testing.T) { router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) @@ -180,6 +183,9 @@ func TestSLOAwareRouter_PreRequest_Success(t *testing.T) { func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) { router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) @@ -201,6 +207,9 @@ func TestSLOAwareRouter_PreRequest_AddsToQueue(t 
*testing.T) { func TestSLOAwareRouter_PreRequest_QueueAlreadyExists(t *testing.T) { router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) request1 := createTestLLMRequest("test-id-1", 100, 50, true) @@ -729,6 +738,9 @@ func TestSLOAwareRouter_ConcurrentContextAccess(t *testing.T) { func TestSLOAwareRouter_MultipleRequests_SamePod(t *testing.T) { router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) @@ -807,6 +819,9 @@ func TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) { func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) { router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + ctx := context.Background() pod1 := createTestPod("test-pod-1", 1, 1, 1)