Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 126 additions & 30 deletions cmd/epp/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"net/http/pprof"
"os"
"runtime"
"strconv"
"strings"
"sync/atomic"

"github.com/go-logr/logr"
Expand Down Expand Up @@ -100,6 +102,8 @@ var (
poolName = flag.String("pool-name", runserver.DefaultPoolName, "Name of the InferencePool this Endpoint Picker is associated with.")
poolGroup = flag.String("pool-group", runserver.DefaultPoolGroup, "group of the InferencePool this Endpoint Picker is associated with.")
poolNamespace = flag.String("pool-namespace", "", "Namespace of the InferencePool this Endpoint Picker is associated with.")
selector = flag.String("selector", "", "selector to filter pods on. Format: a comma-separated list of labels, e.g., 'app: vllm-llama3-8b-instruct,env=prod'.")
targetPorts = flag.String("target-ports", "", "target ports of model server pods. Format: a comma-separated list of labels, e.g., '3000,3001,3002'")
logVerbosity = flag.Int("v", logging.DEFAULT, "number for the log level verbosity")
secureServing = flag.Bool("secure-serving", runserver.DefaultSecureServing, "Enables secure serving. Defaults to true.")
healthChecking = flag.Bool("health-checking", runserver.DefaultHealthChecking, "Enables health checking")
Expand Down Expand Up @@ -194,14 +198,72 @@ func (r *Runner) Run(ctx context.Context) error {
setupLog.Error(err, "Failed to get Kubernetes rest config")
return err
}
// Setup EndPointsPool
endPointsPool := datalayer.NewEndPointsPool(false, common.GKNN{})
if *poolName != "" {
// Determine pool namespace: if --pool-namespace is non-empty, use it; else NAMESPACE env var; else default
resolvePoolNamespace := func() string {
if *poolNamespace != "" {
return *poolNamespace
}
if nsEnv := os.Getenv("NAMESPACE"); nsEnv != "" {
return nsEnv
}
return runserver.DefaultPoolNamespace
}
resolvedPoolNamespace := resolvePoolNamespace()
poolNamespacedName := types.NamespacedName{
Name: *poolName,
Namespace: resolvedPoolNamespace,
}
poolGroupKind := schema.GroupKind{
Group: *poolGroup,
Kind: "InferencePool",
}
poolGKNN := common.GKNN{
NamespacedName: poolNamespacedName,
GroupKind: poolGroupKind,
}
endPointsPool.GKNN = poolGKNN
}

if *selector != "" {
endPointsPool.EndPoints.Selector, err = strToMap(*selector)
if err != nil {
setupLog.Error(err, "Failed to parse flag %q with error: %w", "selector", err)
return err
}
endPointsPool.EndPoints.TargetPorts, err = strToUniqueIntSlice(*targetPorts)
if err != nil {
setupLog.Error(err, "Failed to parse flag %q with error: %w", "target-ports", err)
}
endPointsPool.StandaloneMode = true

// Determine EPP namespace from the EPP_NAMESPACE env var; an error is logged when it is unset
eppNsEnv := os.Getenv("EPP_NAMESPACE")
if eppNsEnv == "" {
setupLog.Error(err, "Failed to get environment variable EPP_NAMESPACE")
}
// Determine EPP name: EPP_NAME env var
eppNameEnv := os.Getenv("EPP_NAME")
if eppNameEnv == "" {
setupLog.Error(err, "Failed to get environment variable EPP_NAME")

}
endPointsPool.GKNN = common.GKNN{
NamespacedName: types.NamespacedName{Namespace: eppNsEnv, Name: eppNameEnv},
GroupKind: schema.GroupKind{Kind: "apps", Group: "Deployment"},
}

}

// --- Setup Datastore ---
useDatalayerV2 := env.GetEnvBool(enableExperimentalDatalayerV2, false, setupLog)
epf, err := r.setupMetricsCollection(setupLog, useDatalayerV2)
if err != nil {
return err
}
datastore := datastore.NewDatastore(ctx, epf, int32(*modelServerMetricsPort))
datastore := datastore.NewDatastore(ctx, epf, int32(*modelServerMetricsPort), endPointsPool)

// --- Setup Metrics Server ---
customCollectors := []prometheus.Collector{collectors.NewInferencePoolMetricsCollector(datastore)}
Expand All @@ -223,34 +285,10 @@ func (r *Runner) Run(ctx context.Context) error {
}(),
}

// Determine pool namespace: if --pool-namespace is non-empty, use it; else NAMESPACE env var; else default
resolvePoolNamespace := func() string {
if *poolNamespace != "" {
return *poolNamespace
}
if nsEnv := os.Getenv("NAMESPACE"); nsEnv != "" {
return nsEnv
}
return runserver.DefaultPoolNamespace
}
resolvedPoolNamespace := resolvePoolNamespace()
poolNamespacedName := types.NamespacedName{
Name: *poolName,
Namespace: resolvedPoolNamespace,
}
poolGroupKind := schema.GroupKind{
Group: *poolGroup,
Kind: "InferencePool",
}
poolGKNN := common.GKNN{
NamespacedName: poolNamespacedName,
GroupKind: poolGroupKind,
}

isLeader := &atomic.Bool{}
isLeader.Store(false)

mgr, err := runserver.NewDefaultManager(poolGKNN, cfg, metricsServerOptions, *haEnableLeaderElection)
mgr, err := runserver.NewDefaultManager(endPointsPool, cfg, metricsServerOptions, *haEnableLeaderElection)
if err != nil {
setupLog.Error(err, "Failed to create controller manager")
return err
Expand Down Expand Up @@ -339,8 +377,7 @@ func (r *Runner) Run(ctx context.Context) error {
// --- Setup ExtProc Server Runner ---
serverRunner := &runserver.ExtProcServerRunner{
GrpcPort: *grpcPort,
PoolNamespacedName: poolNamespacedName,
PoolGKNN: poolGKNN,
EndPointsPool: endPointsPool,
Datastore: datastore,
SecureServing: *secureServing,
HealthChecking: *healthChecking,
Expand Down Expand Up @@ -547,9 +584,19 @@ func registerHealthServer(mgr manager.Manager, logger logr.Logger, ds datastore.
}

func validateFlags() error {
if *poolName == "" {
return fmt.Errorf("required %q flag not set", "poolName")
if (*poolName != "" && *selector != "") || (*poolName == "" && *selector == "") {
return errors.New("either poolName or selector must be set")
}
if *selector != "" {
targetPortsList, err := strToUniqueIntSlice(*targetPorts)
if err != nil {
return fmt.Errorf("unexpected value for %q flag with error %w", "target-ports", err)
}
if len(targetPortsList) == 0 || len(targetPortsList) > 8 {
return fmt.Errorf("flag %q should have length from 1 to 8", "target-ports")
}
}

if *configText != "" && *configFile != "" {
return fmt.Errorf("both the %q and %q flags can not be set at the same time", "configText", "configFile")
}
Expand All @@ -560,6 +607,55 @@ func validateFlags() error {
return nil
}

// strToUniqueIntSlice converts a comma-separated list of integers into a slice
// that keeps the first occurrence of every value, in input order.
//
// Whitespace around each entry is ignored and blank entries (such as those
// produced by a trailing or doubled comma) are skipped. An empty input yields
// a nil slice and no error. If any entry is not parseable by strconv.Atoi the
// function returns a nil slice together with an error naming that entry.
func strToUniqueIntSlice(s string) ([]int, error) {
	var result []int
	if len(s) == 0 {
		return result, nil
	}

	// Tracks values already emitted so duplicates are dropped.
	visited := map[int]bool{}
	for _, field := range strings.Split(s, ",") {
		token := strings.TrimSpace(field)
		if token == "" {
			continue
		}

		value, convErr := strconv.Atoi(token)
		if convErr != nil {
			return nil, fmt.Errorf("invalid number: '%s' is not an integer", token)
		}

		if visited[value] {
			continue
		}
		visited[value] = true
		result = append(result, value)
	}
	return result, nil
}

// strToMap parses a comma-separated list of label pairs into a map.
//
// Each pair is written as "key:value". Because the --selector flag help also
// shows the "key=value" spelling, a pair that contains no ':' may use '=' as
// its separator instead. Whitespace around pairs, keys and values is trimmed,
// blank entries (e.g. from a trailing comma) are skipped, and a key that
// appears more than once keeps its last value.
//
// An empty input yields an empty, non-nil map. An error is returned when a
// pair has no separator, has more than one separator, or has an empty key.
func strToMap(s string) (map[string]string, error) {
	m := make(map[string]string)
	if s == "" {
		return m, nil
	}

	for _, pair := range strings.Split(s, ",") {
		trimmedPair := strings.TrimSpace(pair)
		if trimmedPair == "" {
			continue
		}

		// Prefer ':' and fall back to '=' only when no ':' is present, so every
		// input that parsed before this fallback existed keeps its meaning.
		sep := ":"
		if !strings.Contains(trimmedPair, sep) {
			sep = "="
		}

		key, value, found := strings.Cut(trimmedPair, sep)
		if !found || strings.Contains(value, sep) {
			return nil, fmt.Errorf("invalid format %q, expected key:value pairs", trimmedPair)
		}

		key = strings.TrimSpace(key)
		if key == "" {
			// An empty label key can never match anything; fail loudly instead
			// of silently building a selector that selects nothing.
			return nil, fmt.Errorf("invalid format %q, key must not be empty", trimmedPair)
		}
		m[key] = strings.TrimSpace(value)
	}
	return m, nil
}

func verifyMetricMapping(mapping backendmetrics.MetricMapping, logger logr.Logger) {
if mapping.TotalQueuedRequests == nil {
logger.Info("Not scraping metric: TotalQueuedRequests")
Expand Down
6 changes: 3 additions & 3 deletions pkg/epp/backend/metrics/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func refreshPrometheusMetrics(logger logr.Logger, datastore datalayer.PoolInfo,
}

podTotalCount := len(podMetrics)
metrics.RecordInferencePoolAvgKVCache(pool.Name, kvCacheTotal/float64(podTotalCount))
metrics.RecordInferencePoolAvgQueueSize(pool.Name, float64(queueTotal/podTotalCount))
metrics.RecordInferencePoolReadyPods(pool.Name, float64(podTotalCount))
metrics.RecordInferencePoolAvgKVCache(pool.GKNN.Name, kvCacheTotal/float64(podTotalCount))
metrics.RecordInferencePoolAvgQueueSize(pool.GKNN.Name, float64(queueTotal/podTotalCount))
metrics.RecordInferencePoolReadyPods(pool.GKNN.Name, float64(podTotalCount))
}
5 changes: 2 additions & 3 deletions pkg/epp/backend/metrics/pod_metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/types"

v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)

Expand Down Expand Up @@ -86,8 +85,8 @@ func TestMetricsRefresh(t *testing.T) {

// fakeDataStore is a field-less stub type; its methods (PoolGet, PodList) are
// declared below. NOTE(review): presumably stands in for the real datastore
// in TestMetricsRefresh — confirm against the test body.
type fakeDataStore struct{}

func (f *fakeDataStore) PoolGet() (*v1.InferencePool, error) {
return &v1.InferencePool{Spec: v1.InferencePoolSpec{TargetPorts: []v1.Port{{Number: 8000}}}}, nil
func (f *fakeDataStore) PoolGet() (*datalayer.EndPointsPool, error) {
return &datalayer.EndPointsPool{}, nil
}

func (f *fakeDataStore) PodList(func(PodMetrics) bool) []PodMetrics {
Expand Down
35 changes: 19 additions & 16 deletions pkg/epp/controller/inferenceobjective_reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"testing"
"time"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"

"github.com/google/go-cmp/cmp"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
Expand All @@ -36,16 +38,17 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool"
utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
)

var (
pool = utiltest.MakeInferencePool("test-pool1").Namespace("ns1").ObjRef()
inferencePool = utiltest.MakeInferencePool("test-pool1").Namespace("ns1").ObjRef()
infObjective1 = utiltest.MakeInferenceObjective("model1").
Namespace(pool.Namespace).
Namespace(inferencePool.Namespace).
Priority(1).
CreationTimestamp(metav1.Unix(1000, 0)).
PoolName(pool.Name).
PoolName(inferencePool.Name).
PoolGroup("inference.networking.k8s.io").ObjRef()
infObjective1Pool2 = utiltest.MakeInferenceObjective(infObjective1.Name).
Namespace(infObjective1.Namespace).
Expand All @@ -57,24 +60,24 @@ var (
Namespace(infObjective1.Namespace).
Priority(2).
CreationTimestamp(metav1.Unix(1003, 0)).
PoolName(pool.Name).
PoolName(inferencePool.Name).
PoolGroup("inference.networking.k8s.io").ObjRef()
infObjective1Deleted = utiltest.MakeInferenceObjective(infObjective1.Name).
Namespace(infObjective1.Namespace).
CreationTimestamp(metav1.Unix(1004, 0)).
DeletionTimestamp().
PoolName(pool.Name).
PoolName(inferencePool.Name).
PoolGroup("inference.networking.k8s.io").ObjRef()
infObjective1DiffGroup = utiltest.MakeInferenceObjective(infObjective1.Name).
Namespace(pool.Namespace).
Namespace(inferencePool.Namespace).
Priority(1).
CreationTimestamp(metav1.Unix(1005, 0)).
PoolName(pool.Name).
PoolName(inferencePool.Name).
PoolGroup("inference.networking.x-k8s.io").ObjRef()
infObjective2 = utiltest.MakeInferenceObjective("model2").
Namespace(pool.Namespace).
Namespace(inferencePool.Namespace).
CreationTimestamp(metav1.Unix(1000, 0)).
PoolName(pool.Name).
PoolName(inferencePool.Name).
PoolGroup("inference.networking.k8s.io").ObjRef()
)

Expand Down Expand Up @@ -120,7 +123,7 @@ func TestInferenceObjectiveReconciler(t *testing.T) {
{
name: "Objective not found, no matching existing objective to delete",
objectivessInStore: []*v1alpha2.InferenceObjective{infObjective1},
incomingReq: &types.NamespacedName{Name: "non-existent-objective", Namespace: pool.Namespace},
incomingReq: &types.NamespacedName{Name: "non-existent-objective", Namespace: inferencePool.Namespace},
wantObjectives: []*v1alpha2.InferenceObjective{infObjective1},
},
{
Expand Down Expand Up @@ -160,17 +163,18 @@ func TestInferenceObjectiveReconciler(t *testing.T) {
WithObjects(initObjs...).
Build()
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
ds := datastore.NewDatastore(t.Context(), pmf, 0)
ds := datastore.NewDatastore(t.Context(), pmf, 0, datalayer.NewEndPointsPool(false, pool.ToGKNN(inferencePool)))
for _, m := range test.objectivessInStore {
ds.ObjectiveSet(m)
}
_ = ds.PoolSet(context.Background(), fakeClient, pool)
endPointsPool := pool.InferencePoolToEndPointsPool(inferencePool)
_ = ds.PoolSet(context.Background(), fakeClient, endPointsPool)
reconciler := &InferenceObjectiveReconciler{
Reader: fakeClient,
Datastore: ds,
PoolGKNN: common.GKNN{
NamespacedName: types.NamespacedName{Name: pool.Name, Namespace: pool.Namespace},
GroupKind: schema.GroupKind{Group: pool.GroupVersionKind().Group, Kind: pool.GroupVersionKind().Kind},
NamespacedName: types.NamespacedName{Name: inferencePool.Name, Namespace: inferencePool.Namespace},
GroupKind: schema.GroupKind{Group: inferencePool.GroupVersionKind().Group, Kind: inferencePool.GroupVersionKind().Kind},
},
}
if test.incomingReq == nil {
Expand All @@ -190,8 +194,7 @@ func TestInferenceObjectiveReconciler(t *testing.T) {
if len(test.wantObjectives) != len(ds.ObjectiveGetAll()) {
t.Errorf("Unexpected; want: %d, got:%d", len(test.wantObjectives), len(ds.ObjectiveGetAll()))
}

if diff := diffStore(ds, diffStoreParams{wantPool: pool, wantObjectives: test.wantObjectives}); diff != "" {
if diff := diffStore(ds, diffStoreParams{wantPool: endPointsPool, wantObjectives: test.wantObjectives}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}

Expand Down
18 changes: 6 additions & 12 deletions pkg/epp/controller/inferencepool_reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ import (
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"

v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
pooltuil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool"
)

// InferencePoolReconciler utilizes the controller runtime to reconcile Instance Gateway resources
Expand Down Expand Up @@ -75,25 +77,17 @@ func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Reques
c.Datastore.Clear()
return ctrl.Result{}, nil
}

// 4. Convert the fetched object to the canonical v1.InferencePool.
v1infPool := &v1.InferencePool{}

var endPointsPool *datalayer.EndPointsPool
switch pool := obj.(type) {
case *v1.InferencePool:
// If it's already a v1 object, just use it.
v1infPool = pool
endPointsPool = pooltuil.InferencePoolToEndPointsPool(pool)
case *v1alpha2.InferencePool:
var err error
err = pool.ConvertTo(v1infPool)
if err != nil {
return ctrl.Result{}, fmt.Errorf("failed to convert XInferencePool to InferencePool - %w", err)
}
endPointsPool = pooltuil.AlphaInferencePoolToEndPointsPool(pool)
default:
return ctrl.Result{}, fmt.Errorf("unsupported API group: %s", c.PoolGKNN.Group)
}

if err := c.Datastore.PoolSet(ctx, c.Reader, v1infPool); err != nil {
if err := c.Datastore.PoolSet(ctx, c.Reader, endPointsPool); err != nil {
return ctrl.Result{}, fmt.Errorf("failed to update datastore - %w", err)
}

Expand Down
Loading