@@ -23,6 +23,7 @@ import (
2323 "fmt"
2424 "math/rand"
2525 "net"
26+ "sort"
2627 "strings"
2728 "time"
2829
@@ -46,6 +47,7 @@ type Datastore interface {
4647 PoolGet () (* v1.InferencePool , error )
4748 ObjectiveGet (modelName string ) * v1alpha2.InferenceObjective
4849 PodList (predicate func (backendmetrics.PodMetrics ) bool ) []backendmetrics.PodMetrics
50+ RewriteGetAll () []* v1alpha2.InferenceModelRewrite
4951}
5052
5153// Scheduler defines the interface required by the Director for scheduling.
@@ -106,6 +108,9 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
106108 // Default to incoming model name
107109 reqCtx .TargetModelName = reqCtx .IncomingModelName
108110 }
111+
112+ d .applyWeightedModelRewrite (reqCtx )
113+
109114 reqCtx .Request .Body ["model" ] = reqCtx .TargetModelName
110115
111116 requestBody , err := requtil .ExtractRequestBody (reqCtx .Request .Body )
@@ -166,6 +171,56 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
166171 return reqCtx , nil
167172}
168173
174+ func (d * Director ) applyWeightedModelRewrite (reqCtx * handlers.RequestContext ) {
175+ rewrites := d .datastore .RewriteGetAll ()
176+ if len (rewrites ) == 0 {
177+ return
178+ }
179+
180+ sort .Slice (rewrites , func (i , j int ) bool {
181+ return rewrites [i ].CreationTimestamp .Before (& rewrites [j ].CreationTimestamp )
182+ })
183+
184+ for _ , rewrite := range rewrites {
185+ for _ , rule := range rewrite .Spec .Rules {
186+ for _ , match := range rule .Matches {
187+ if match .Model != nil && match .Model .Value == reqCtx .IncomingModelName {
188+ reqCtx .TargetModelName = d .selectWeightedModel (rule .Targets )
189+ return
190+ }
191+ }
192+ }
193+ }
194+ }
195+
196+ func (d * Director ) selectWeightedModel (models []v1alpha2.TargetModel ) string {
197+ if len (models ) == 0 {
198+ return ""
199+ }
200+
201+ var totalWeight int32
202+ for _ , model := range models {
203+ totalWeight += model .Weight
204+ }
205+
206+ if totalWeight == 0 {
207+ // If total weight is 0, distribute evenly
208+ return models [rand .Intn (len (models ))].ModelRewrite
209+ }
210+
211+ randomNum := rand .Intn (int (totalWeight ))
212+ var currentWeight int32
213+ for _ , model := range models {
214+ currentWeight += model .Weight
215+ if randomNum < int (currentWeight ) {
216+ return model .ModelRewrite
217+ }
218+ }
219+
220+ // Should not happen
221+ return models [len (models )- 1 ].ModelRewrite
222+ }
223+
169224// getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore.
170225// According to the EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies
171226// a subset of endpoints, only these endpoints will be considered as candidates for the scheduler.
0 commit comments