Skip to content

Commit 15aac16

Browse files
committed
infModelRewrite reconciler logic.
1 parent c749e2d commit 15aac16

File tree

3 files changed

+353
-0
lines changed

3 files changed

+353
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package controller
18+
19+
import (
20+
"context"
21+
"fmt"
22+
23+
"k8s.io/apimachinery/pkg/api/errors"
24+
ctrl "sigs.k8s.io/controller-runtime"
25+
"sigs.k8s.io/controller-runtime/pkg/client"
26+
"sigs.k8s.io/controller-runtime/pkg/event"
27+
"sigs.k8s.io/controller-runtime/pkg/log"
28+
"sigs.k8s.io/controller-runtime/pkg/predicate"
29+
30+
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
31+
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
32+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
33+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
34+
)
35+
36+
type InferenceModelRewriteReconciler struct {
37+
client.Reader
38+
Datastore datastore.Datastore
39+
PoolGKNN common.GKNN
40+
}
41+
42+
func (c *InferenceModelRewriteReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
43+
logger := log.FromContext(ctx).V(logutil.DEFAULT)
44+
ctx = ctrl.LoggerInto(ctx, logger)
45+
46+
logger.Info("Reconciling InferenceModelRewrite")
47+
48+
infModelRewrite := &v1alpha2.InferenceModelRewrite{}
49+
notFound := false
50+
if err := c.Get(ctx, req.NamespacedName, infModelRewrite); err != nil {
51+
if !errors.IsNotFound(err) {
52+
return ctrl.Result{}, fmt.Errorf("unable to get InferenceModelRewrite - %w", err)
53+
}
54+
notFound = true
55+
}
56+
57+
if notFound || !infModelRewrite.DeletionTimestamp.IsZero() || infModelRewrite.Spec.PoolRef.Name != v1alpha2.ObjectName(c.PoolGKNN.Name) {
58+
// InferenceModelRewrite object got deleted or changed the referenced pool.
59+
c.Datastore.RewriteDelete(req.NamespacedName)
60+
return ctrl.Result{}, nil
61+
}
62+
63+
// Add or update if the InferenceModelRewrite instance has a creation timestamp older than the existing entry of the model.
64+
logger = logger.WithValues("poolRef", infModelRewrite.Spec.PoolRef)
65+
c.Datastore.RewriteSet(infModelRewrite)
66+
logger.Info("Added/Updated InferenceModelRewrite")
67+
68+
return ctrl.Result{}, nil
69+
}
70+
71+
func (c *InferenceModelRewriteReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error {
72+
return ctrl.NewControllerManagedBy(mgr).
73+
For(&v1alpha2.InferenceModelRewrite{}).
74+
WithEventFilter(predicate.Funcs{
75+
CreateFunc: func(e event.CreateEvent) bool { return c.eventPredicate(e.Object.(*v1alpha2.InferenceModelRewrite)) },
76+
UpdateFunc: func(e event.UpdateEvent) bool {
77+
return c.eventPredicate(e.ObjectOld.(*v1alpha2.InferenceModelRewrite)) || c.eventPredicate(e.ObjectNew.(*v1alpha2.InferenceModelRewrite))
78+
},
79+
DeleteFunc: func(e event.DeleteEvent) bool { return c.eventPredicate(e.Object.(*v1alpha2.InferenceModelRewrite)) },
80+
GenericFunc: func(e event.GenericEvent) bool { return c.eventPredicate(e.Object.(*v1alpha2.InferenceModelRewrite)) },
81+
}).
82+
Complete(c)
83+
}
84+
85+
func (c *InferenceModelRewriteReconciler) eventPredicate(infModelRewrite *v1alpha2.InferenceModelRewrite) bool {
86+
return string(infModelRewrite.Spec.PoolRef.Name) == c.PoolGKNN.Name && string(infModelRewrite.Spec.PoolRef.Group) == c.PoolGKNN.Group
87+
}
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package controller
18+
19+
import (
20+
"context"
21+
"testing"
22+
"time"
23+
24+
"github.com/google/go-cmp/cmp"
25+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
26+
"k8s.io/apimachinery/pkg/runtime"
27+
"k8s.io/apimachinery/pkg/runtime/schema"
28+
"k8s.io/apimachinery/pkg/types"
29+
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
30+
ctrl "sigs.k8s.io/controller-runtime"
31+
"sigs.k8s.io/controller-runtime/pkg/client"
32+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
33+
34+
v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
35+
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
36+
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
37+
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
38+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
39+
utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
40+
)
41+
42+
var (
43+
poolForRewrite = utiltest.MakeInferencePool("test-pool1").Namespace("ns1").ObjRef()
44+
rewrite1 = makeInferenceModelRewrite("rewrite1").
45+
Namespace(poolForRewrite.Namespace).
46+
PoolName(poolForRewrite.Name).
47+
CreationTimestamp(metav1.Unix(1000, 0)).
48+
ObjRef()
49+
rewrite1Pool2 = makeInferenceModelRewrite(rewrite1.Name).
50+
Namespace(rewrite1.Namespace).
51+
PoolName("test-pool2").
52+
CreationTimestamp(metav1.Unix(1001, 0)).
53+
ObjRef()
54+
rewrite1Updated = makeInferenceModelRewrite(rewrite1.Name).
55+
Namespace(rewrite1.Namespace).
56+
PoolName(poolForRewrite.Name).
57+
CreationTimestamp(metav1.Unix(1003, 0)).
58+
Rules([]v1alpha2.InferenceModelRewriteRule{{}}).
59+
ObjRef()
60+
rewrite1Deleted = makeInferenceModelRewrite(rewrite1.Name).
61+
Namespace(rewrite1.Namespace).
62+
PoolName(poolForRewrite.Name).
63+
CreationTimestamp(metav1.Unix(1004, 0)).
64+
DeletionTimestamp().
65+
ObjRef()
66+
rewrite2 = makeInferenceModelRewrite("rewrite2").
67+
Namespace(poolForRewrite.Namespace).
68+
PoolName(poolForRewrite.Name).
69+
CreationTimestamp(metav1.Unix(1000, 0)).
70+
ObjRef()
71+
)
72+
73+
type inferenceModelRewriteBuilder struct {
74+
*v1alpha2.InferenceModelRewrite
75+
}
76+
77+
func makeInferenceModelRewrite(name string) *inferenceModelRewriteBuilder {
78+
return &inferenceModelRewriteBuilder{
79+
&v1alpha2.InferenceModelRewrite{
80+
ObjectMeta: metav1.ObjectMeta{
81+
Name: name,
82+
},
83+
},
84+
}
85+
}
86+
87+
func (b *inferenceModelRewriteBuilder) Namespace(ns string) *inferenceModelRewriteBuilder {
88+
b.ObjectMeta.Namespace = ns
89+
return b
90+
}
91+
92+
func (b *inferenceModelRewriteBuilder) PoolName(name string) *inferenceModelRewriteBuilder {
93+
b.Spec.PoolRef.Name = v1alpha2.ObjectName(name)
94+
return b
95+
}
96+
97+
func (b *inferenceModelRewriteBuilder) CreationTimestamp(t metav1.Time) *inferenceModelRewriteBuilder {
98+
b.ObjectMeta.CreationTimestamp = t
99+
return b
100+
}
101+
102+
func (b *inferenceModelRewriteBuilder) DeletionTimestamp() *inferenceModelRewriteBuilder {
103+
now := metav1.Now()
104+
b.ObjectMeta.DeletionTimestamp = &now
105+
return b
106+
}
107+
108+
func (b *inferenceModelRewriteBuilder) Rules(rules []v1alpha2.InferenceModelRewriteRule) *inferenceModelRewriteBuilder {
109+
b.Spec.Rules = rules
110+
return b
111+
}
112+
113+
func (b *inferenceModelRewriteBuilder) ObjRef() *v1alpha2.InferenceModelRewrite {
114+
return b.InferenceModelRewrite
115+
}
116+
117+
func TestInferenceModelRewriteReconciler(t *testing.T) {
118+
tests := []struct {
119+
name string
120+
rewritesInStore []*v1alpha2.InferenceModelRewrite
121+
rewritesInAPIServer []*v1alpha2.InferenceModelRewrite
122+
rewrite *v1alpha2.InferenceModelRewrite
123+
incomingReq *types.NamespacedName
124+
wantRewrites []*v1alpha2.InferenceModelRewrite
125+
wantResult ctrl.Result
126+
}{
127+
{
128+
name: "Empty store, add new rewrite",
129+
rewrite: rewrite1,
130+
wantRewrites: []*v1alpha2.InferenceModelRewrite{rewrite1},
131+
},
132+
{
133+
name: "Existing rewrite changed pools",
134+
rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1},
135+
rewrite: rewrite1Pool2,
136+
wantRewrites: []*v1alpha2.InferenceModelRewrite{},
137+
},
138+
{
139+
name: "Not found, delete existing rewrite",
140+
rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1},
141+
incomingReq: &types.NamespacedName{Name: rewrite1.Name, Namespace: rewrite1.Namespace},
142+
wantRewrites: []*v1alpha2.InferenceModelRewrite{},
143+
},
144+
{
145+
name: "Deletion timestamp set, delete existing rewrite",
146+
rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1},
147+
rewrite: rewrite1Deleted,
148+
incomingReq: &types.NamespacedName{Name: rewrite1Deleted.Name, Namespace: rewrite1Deleted.Namespace},
149+
wantRewrites: []*v1alpha2.InferenceModelRewrite{},
150+
},
151+
{
152+
name: "Rewrite updated",
153+
rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1},
154+
rewrite: rewrite1Updated,
155+
wantRewrites: []*v1alpha2.InferenceModelRewrite{rewrite1Updated},
156+
},
157+
{
158+
name: "Rewrite not found, no matching existing rewrite to delete",
159+
rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1},
160+
incomingReq: &types.NamespacedName{Name: "non-existent-rewrite", Namespace: poolForRewrite.Namespace},
161+
wantRewrites: []*v1alpha2.InferenceModelRewrite{rewrite1},
162+
},
163+
{
164+
name: "Add to existing",
165+
rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1},
166+
rewrite: rewrite2,
167+
wantRewrites: []*v1alpha2.InferenceModelRewrite{rewrite1, rewrite2},
168+
},
169+
}
170+
for _, test := range tests {
171+
t.Run(test.name, func(t *testing.T) {
172+
scheme := runtime.NewScheme()
173+
_ = clientgoscheme.AddToScheme(scheme)
174+
_ = v1alpha2.Install(scheme)
175+
_ = v1.Install(scheme)
176+
initObjs := []client.Object{}
177+
if test.rewrite != nil && test.rewrite.DeletionTimestamp.IsZero() {
178+
initObjs = append(initObjs, test.rewrite)
179+
}
180+
for _, r := range test.rewritesInAPIServer {
181+
initObjs = append(initObjs, r)
182+
}
183+
fakeClient := fake.NewClientBuilder().
184+
WithScheme(scheme).
185+
WithObjects(initObjs...).
186+
Build()
187+
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
188+
ds := datastore.NewDatastore(t.Context(), pmf, 0)
189+
for _, r := range test.rewritesInStore {
190+
ds.RewriteSet(r)
191+
}
192+
_ = ds.PoolSet(context.Background(), fakeClient, poolForRewrite)
193+
reconciler := &InferenceModelRewriteReconciler{
194+
Reader: fakeClient,
195+
Datastore: ds,
196+
PoolGKNN: common.GKNN{
197+
NamespacedName: types.NamespacedName{Name: poolForRewrite.Name, Namespace: poolForRewrite.Namespace},
198+
GroupKind: schema.GroupKind{Group: poolForRewrite.GroupVersionKind().Group, Kind: poolForRewrite.GroupVersionKind().Kind},
199+
},
200+
}
201+
if test.incomingReq == nil {
202+
test.incomingReq = &types.NamespacedName{Name: test.rewrite.Name, Namespace: test.rewrite.Namespace}
203+
}
204+
205+
result, err := reconciler.Reconcile(context.Background(), ctrl.Request{NamespacedName: *test.incomingReq})
206+
if err != nil {
207+
t.Fatalf("expected no error, got %v", err)
208+
}
209+
210+
if diff := cmp.Diff(result, test.wantResult); diff != "" {
211+
t.Errorf("Unexpected result diff (+got/-want): %s", diff)
212+
}
213+
214+
if len(test.wantRewrites) != len(ds.RewriteGetAll()) {
215+
t.Errorf("Unexpected number of rewrites; want: %d, got:%d", len(test.wantRewrites), len(ds.RewriteGetAll()))
216+
}
217+
218+
if diff := diffStoreRewrites(ds, test.wantRewrites); diff != "" {
219+
t.Errorf("Unexpected diff (+got/-want): %s", diff)
220+
}
221+
})
222+
}
223+
}
224+
225+
func diffStoreRewrites(ds datastore.Datastore, wantRewrites []*v1alpha2.InferenceModelRewrite) string {
226+
if wantRewrites == nil {
227+
wantRewrites = []*v1alpha2.InferenceModelRewrite{}
228+
}
229+
230+
gotRewrites := ds.RewriteGetAll()
231+
if diff := cmp.Diff(wantRewrites, gotRewrites); diff != "" {
232+
return "rewrites:" + diff
233+
}
234+
return ""
235+
}

pkg/epp/datastore/datastore.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ type Datastore interface {
6060
ObjectiveDelete(namespacedName types.NamespacedName)
6161
ObjectiveGetAll() []*v1alpha2.InferenceObjective
6262

63+
// InferenceModelRewrite operations
64+
RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite)
65+
RewriteDelete(namespacedName types.NamespacedName)
66+
RewriteGetAll() []*v1alpha2.InferenceModelRewrite
67+
6368
// PodList lists pods matching the given predicate.
6469
PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
6570
PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool
@@ -74,6 +79,7 @@ func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory
7479
parentCtx: parentCtx,
7580
poolAndObjectivesMu: sync.RWMutex{},
7681
objectives: make(map[string]*v1alpha2.InferenceObjective),
82+
rewrites: make(map[types.NamespacedName]*v1alpha2.InferenceModelRewrite),
7783
pods: &sync.Map{},
7884
modelServerMetricsPort: modelServerMetricsPort,
7985
epf: epFactory,
@@ -89,6 +95,8 @@ type datastore struct {
8995
pool *v1.InferencePool
9096
// key: InferenceObjective.Spec.ModelName, value: *InferenceObjective
9197
objectives map[string]*v1alpha2.InferenceObjective
98+
// key: types.NamespacedName, value: *v1alpha2.InferenceModelRewrite
99+
rewrites map[types.NamespacedName]*v1alpha2.InferenceModelRewrite
92100
// key: types.NamespacedName, value: backendmetrics.PodMetrics
93101
pods *sync.Map
94102
// modelServerMetricsPort metrics port from EPP command line argument
@@ -102,6 +110,7 @@ func (ds *datastore) Clear() {
102110
defer ds.poolAndObjectivesMu.Unlock()
103111
ds.pool = nil
104112
ds.objectives = make(map[string]*v1alpha2.InferenceObjective)
113+
ds.rewrites = make(map[types.NamespacedName]*v1alpha2.InferenceModelRewrite)
105114
// stop all pods go routines before clearing the pods map.
106115
ds.pods.Range(func(_, v any) bool {
107116
ds.epf.ReleaseEndpoint(v.(backendmetrics.PodMetrics))
@@ -197,6 +206,28 @@ func (ds *datastore) ObjectiveGetAll() []*v1alpha2.InferenceObjective {
197206
return res
198207
}
199208

209+
func (ds *datastore) RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) {
210+
ds.poolAndObjectivesMu.Lock()
211+
defer ds.poolAndObjectivesMu.Unlock()
212+
ds.rewrites[types.NamespacedName{Name: infModelRewrite.Name, Namespace: infModelRewrite.Namespace}] = infModelRewrite
213+
}
214+
215+
func (ds *datastore) RewriteDelete(namespacedName types.NamespacedName) {
216+
ds.poolAndObjectivesMu.Lock()
217+
defer ds.poolAndObjectivesMu.Unlock()
218+
delete(ds.rewrites, namespacedName)
219+
}
220+
221+
func (ds *datastore) RewriteGetAll() []*v1alpha2.InferenceModelRewrite {
222+
ds.poolAndObjectivesMu.RLock()
223+
defer ds.poolAndObjectivesMu.RUnlock()
224+
res := []*v1alpha2.InferenceModelRewrite{}
225+
for _, v := range ds.rewrites {
226+
res = append(res, v)
227+
}
228+
return res
229+
}
230+
200231
// /// Pods/endpoints APIs ///
201232
// TODO: add a flag for callers to specify the staleness threshold for metrics.
202233
// ref: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/1046#discussion_r2246351694

0 commit comments

Comments
 (0)