
Commit ea8b4c3

Add configure numGPUBlocks for approximate prefix cache. (#1748)
1 parent 176601e

File tree: 7 files changed, +165 -43 lines changed

pkg/epp/backend/metrics/metrics.go

Lines changed: 8 additions & 3 deletions
@@ -38,6 +38,7 @@ const (
     LoraInfoMaxAdaptersMetricName = "max_lora"
 
     CacheConfigBlockSizeInfoMetricName = "block_size"
+    CacheConfigNumGPUBlocksMetricName  = "num_gpu_blocks"
 )
 
 type PodMetricsClientImpl struct {
@@ -148,12 +149,16 @@ func (p *PodMetricsClientImpl) promToPodMetrics(
         errs = multierr.Append(errs, err)
     } else {
         for _, v := range cacheMetrics.GetLabel() {
-            if v.GetName() == CacheConfigBlockSizeInfoMetricName {
+            switch v.GetName() {
+            case CacheConfigBlockSizeInfoMetricName:
                 updated.CacheBlockSize, err = strconv.Atoi(v.GetValue())
                 if err != nil {
                     errs = multierr.Append(errs, err)
-                } else {
-                    break
+                }
+            case CacheConfigNumGPUBlocksMetricName:
+                updated.CacheNumGPUBlocks, err = strconv.Atoi(v.GetValue())
+                if err != nil {
+                    errs = multierr.Append(errs, err)
                 }
             }
         }
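The hunk above replaces the single block_size check with a switch so that block_size and num_gpu_blocks are parsed independently, and a bad value for one label no longer affects the other. A minimal standalone sketch of that label-parsing pattern (hypothetical names, not the project's actual types):

```go
package main

import (
	"fmt"
	"strconv"
)

// cacheConfig mirrors the two values the scraper now extracts from the
// cache-config info metric's labels (illustrative struct, not the repo's).
type cacheConfig struct {
	BlockSize    int
	NumGPUBlocks int
}

// parseCacheConfigLabels converts recognized string labels to ints and
// collects parse errors instead of aborting on the first bad label.
func parseCacheConfigLabels(labels map[string]string) (cacheConfig, []error) {
	var cfg cacheConfig
	var errs []error
	for name, value := range labels {
		switch name {
		case "block_size":
			v, err := strconv.Atoi(value)
			if err != nil {
				errs = append(errs, err)
				continue
			}
			cfg.BlockSize = v
		case "num_gpu_blocks":
			v, err := strconv.Atoi(value)
			if err != nil {
				errs = append(errs, err)
				continue
			}
			cfg.NumGPUBlocks = v
		}
	}
	return cfg, errs
}

func main() {
	cfg, errs := parseCacheConfigLabels(map[string]string{"block_size": "16", "num_gpu_blocks": "1024"})
	fmt.Printf("%+v, errors: %v\n", cfg, errs) // {BlockSize:16 NumGPUBlocks:1024}, errors: []
}
```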

pkg/epp/backend/metrics/metrics_test.go

Lines changed: 89 additions & 0 deletions
@@ -373,6 +373,95 @@ func TestGetLatestLoraMetric(t *testing.T) {
     }
 }
 
+func TestCacheConfigInfoMetrics(t *testing.T) {
+    testCases := []struct {
+        name            string
+        metricFamilies  map[string]*dto.MetricFamily
+        mapping         *MetricMapping
+        existingMetrics *MetricsState
+        expectedMetrics *MetricsState
+        expectedErr     error
+    }{
+        {
+            name: "successful cache config metrics",
+            metricFamilies: map[string]*dto.MetricFamily{
+                "vllm_cache_config": makeMetricFamily("vllm_cache_config",
+                    makeMetric(map[string]string{"block_size": "16", "num_gpu_blocks": "1024"}, 1.0, 1000),
+                ),
+            },
+            mapping: &MetricMapping{
+                CacheConfigInfo: &MetricSpec{MetricName: "vllm_cache_config"},
+            },
+            existingMetrics: &MetricsState{},
+            expectedMetrics: &MetricsState{
+                CacheBlockSize:    16,
+                CacheNumGPUBlocks: 1024,
+            },
+            expectedErr: nil,
+        },
+        {
+            name: "invalid block_size value",
+            metricFamilies: map[string]*dto.MetricFamily{
+                "vllm_cache_config": makeMetricFamily("vllm_cache_config",
+                    makeMetric(map[string]string{"block_size": "invalid", "num_gpu_blocks": "1024"}, 1.0, 1000),
+                ),
+            },
+            mapping: &MetricMapping{
+                CacheConfigInfo: &MetricSpec{MetricName: "vllm_cache_config"},
+            },
+            existingMetrics: &MetricsState{},
+            expectedMetrics: &MetricsState{
+                CacheNumGPUBlocks: 1024,
+            },
+            expectedErr: errors.New("strconv.Atoi: parsing \"invalid\": invalid syntax"),
+        },
+        {
+            name: "invalid num_gpu_blocks value",
+            metricFamilies: map[string]*dto.MetricFamily{
+                "vllm_cache_config": makeMetricFamily("vllm_cache_config",
+                    makeMetric(map[string]string{"block_size": "16", "num_gpu_blocks": "invalid"}, 1.0, 1000),
+                ),
+            },
+            mapping: &MetricMapping{
+                CacheConfigInfo: &MetricSpec{MetricName: "vllm_cache_config"},
+            },
+            existingMetrics: &MetricsState{},
+            expectedMetrics: &MetricsState{
+                CacheBlockSize: 16,
+            },
+            expectedErr: errors.New("strconv.Atoi: parsing \"invalid\": invalid syntax"),
+        },
+        {
+            name: "no cache config if not in MetricMapping",
+            metricFamilies: map[string]*dto.MetricFamily{
+                "vllm_cache_config": makeMetricFamily("vllm_cache_config",
+                    makeMetric(map[string]string{"block_size": "16", "num_gpu_blocks": "1024"}, 1.0, 1000),
+                ),
+            },
+            mapping:         &MetricMapping{}, // No CacheConfigInfo defined
+            existingMetrics: &MetricsState{},
+            expectedMetrics: &MetricsState{},
+            expectedErr:     nil,
+        },
+    }
+
+    for _, tc := range testCases {
+        t.Run(tc.name, func(t *testing.T) {
+            p := &PodMetricsClientImpl{MetricMapping: tc.mapping}
+            updated, err := p.promToPodMetrics(tc.metricFamilies, tc.existingMetrics)
+
+            if tc.expectedErr != nil {
+                assert.Error(t, err)
+                assert.Contains(t, err.Error(), tc.expectedErr.Error())
+            } else {
+                assert.NoError(t, err)
+                assert.Equal(t, tc.expectedMetrics.CacheBlockSize, updated.CacheBlockSize, "CacheBlockSize mismatch")
+                assert.Equal(t, tc.expectedMetrics.CacheNumGPUBlocks, updated.CacheNumGPUBlocks, "CacheNumGPUBlocks mismatch")
+            }
+        })
+    }
+}
+
 func TestPromToPodMetrics(t *testing.T) {
     tests := []struct {
         name string
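The table-driven test relies on makeMetricFamily and makeMetric helpers defined elsewhere in this package's test code. As a rough illustration only (not the repository's exact implementation), such helpers could be built on the Prometheus client_model (dto) types along these lines:

```go
package main

import (
	"fmt"

	dto "github.com/prometheus/client_model/go"
	"google.golang.org/protobuf/proto"
)

// makeMetric builds a gauge sample whose labels carry the cache-config strings.
func makeMetric(labels map[string]string, value float64, timestampMs int64) *dto.Metric {
	labelPairs := make([]*dto.LabelPair, 0, len(labels))
	for name, val := range labels {
		labelPairs = append(labelPairs, &dto.LabelPair{Name: proto.String(name), Value: proto.String(val)})
	}
	return &dto.Metric{
		Label:       labelPairs,
		Gauge:       &dto.Gauge{Value: proto.Float64(value)},
		TimestampMs: proto.Int64(timestampMs),
	}
}

// makeMetricFamily wraps the samples in a named gauge family, mirroring the
// shape of a parsed /metrics scrape.
func makeMetricFamily(name string, metrics ...*dto.Metric) *dto.MetricFamily {
	return &dto.MetricFamily{
		Name:   proto.String(name),
		Type:   dto.MetricType_GAUGE.Enum(),
		Metric: metrics,
	}
}

func main() {
	fam := makeMetricFamily("vllm_cache_config",
		makeMetric(map[string]string{"block_size": "16", "num_gpu_blocks": "1024"}, 1.0, 1000),
	)
	fmt.Println(fam.GetName(), len(fam.GetMetric()[0].GetLabel())) // vllm_cache_config 2
}
```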

pkg/epp/datalayer/metrics.go

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,8 @@ type Metrics struct {
     KVCacheUsagePercent     float64
     KvCacheMaxTokenCapacity int
     CacheBlockSize          int
+    // Number of GPU blocks in the model server for KV Cache.
+    CacheNumGPUBlocks int
 
     // UpdateTime records the last time when the metrics were updated.
     UpdateTime time.Time
@@ -77,6 +79,7 @@ func (m *Metrics) Clone() *Metrics {
         KVCacheUsagePercent:     m.KVCacheUsagePercent,
         KvCacheMaxTokenCapacity: m.KvCacheMaxTokenCapacity,
         CacheBlockSize:          m.CacheBlockSize,
+        CacheNumGPUBlocks:       m.CacheNumGPUBlocks,
         UpdateTime:              m.UpdateTime,
     }
 }

pkg/epp/datalayer/metrics/extractor.go

Lines changed: 11 additions & 2 deletions
@@ -40,6 +40,7 @@ const (
     LoraInfoMaxAdaptersMetricName = "max_lora"
 
     CacheConfigBlockSizeInfoMetricName = "block_size"
+    CacheConfigNumGPUBlocksMetricName  = "num_gpu_blocks"
 )
 
 // Extractor implements the metrics extraction based on the model
@@ -173,11 +174,19 @@ func populateLoRAMetrics(clone *datalayer.Metrics, metric *dto.Metric, errs *[]e
 func populateCacheInfoMetrics(clone *datalayer.Metrics, metric *dto.Metric, errs *[]error) {
     clone.CacheBlockSize = 0
     for _, label := range metric.GetLabel() {
-        if label.GetName() == CacheConfigBlockSizeInfoMetricName {
+        switch label.GetName() {
+        case CacheConfigBlockSizeInfoMetricName:
             if label.GetValue() != "" {
                 if val, err := strconv.Atoi(label.GetValue()); err == nil {
                     clone.CacheBlockSize = val
-                    break
+                } else {
+                    *errs = append(*errs, err)
+                }
+            }
+        case CacheConfigNumGPUBlocksMetricName:
+            if label.GetValue() != "" {
+                if val, err := strconv.Atoi(label.GetValue()); err == nil {
+                    clone.CacheNumGPUBlocks = val
                 } else {
                     *errs = append(*errs, err)
                 }

pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go

Lines changed: 22 additions & 18 deletions
@@ -31,32 +31,36 @@ import (
 // An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that
 // prefix cached.
 type indexer struct {
-    mu         sync.RWMutex
-    hashToPods map[BlockHash]podSet                         // the lookup data structure to find pods that have the BlockHash cached
-    podToLRU   map[ServerID]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache
-    maxLRUSize int
+    mu             sync.RWMutex
+    hashToPods     map[BlockHash]podSet                         // the lookup data structure to find pods that have the BlockHash cached
+    podToLRU       map[ServerID]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache
+    defaultLRUSize int
 }
 
 // newIndexer initializes an indexer with size limits and starts cache size reporting.
-func newIndexer(ctx context.Context, maxLRUSize int) *indexer {
+func newIndexer(ctx context.Context, defaultLRUSize int) *indexer {
     indexer := &indexer{
-        hashToPods: make(map[BlockHash]podSet),
-        podToLRU:   make(map[ServerID]*lru.Cache[BlockHash, struct{}]),
-        maxLRUSize: maxLRUSize,
+        hashToPods:     make(map[BlockHash]podSet),
+        podToLRU:       make(map[ServerID]*lru.Cache[BlockHash, struct{}]),
+        defaultLRUSize: defaultLRUSize,
     }
 
     go indexer.reportLRUSize(ctx, time.Second)
     return indexer
 }
 
 // Add adds a list of prefix hashes to the cache, tied to the server.
-func (i *indexer) Add(hashes []BlockHash, pod ServerID) {
+func (i *indexer) Add(hashes []BlockHash, pod Server) {
     i.mu.Lock()
     // Check if the LRU pod exist
-    lruForPod, exists := i.podToLRU[pod]
+    lruForPod, exists := i.podToLRU[pod.ServerID]
     if !exists {
-        newLRU, _ := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod))
-        i.podToLRU[pod] = newLRU
+        lruSize := pod.numOfGPUBlocks
+        if lruSize <= 0 {
+            lruSize = i.defaultLRUSize
+        }
+        newLRU, _ := lru.NewWithEvict(lruSize, i.makeEvictionFn(pod.ServerID))
+        i.podToLRU[pod.ServerID] = newLRU
         lruForPod = newLRU
     }
 
@@ -70,12 +74,12 @@ func (i *indexer) Add(hashes []BlockHash, pod ServerID) {
     // Update hashToPods once under lock
     i.mu.Lock()
     for _, hash := range hashes {
-        pods := i.hashToPods[hash]
-        if pods == nil {
-            pods = make(podSet)
+        podIDs := i.hashToPods[hash]
+        if podIDs == nil {
+            podIDs = make(podSet)
         }
-        pods[pod] = struct{}{}
-        i.hashToPods[hash] = pods
+        podIDs[pod.ServerID] = struct{}{}
+        i.hashToPods[hash] = podIDs
     }
 
     i.mu.Unlock()
@@ -143,7 +147,7 @@ func (i *indexer) reportLRUSize(ctx context.Context, interval time.Duration) {
         "avg entries per pod", avg,
         "pod with max cache", maxPodName,
         "max pod size", maxPodEntries,
-        "global max LRU cache capacity per pod", i.maxLRUSize,
+        "global max LRU cache capacity per pod", i.defaultLRUSize,
     )
 
     i.mu.RUnlock()
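The behavioral change here is in how a pod's LRU is sized on first Add: a positive pod-reported num_gpu_blocks wins, otherwise the indexer falls back to the configured default. A small self-contained sketch of that rule, using the same hashicorp/golang-lru/v2 package (the newPodLRU helper is hypothetical):

```go
package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

type blockHash uint64

// newPodLRU sizes a per-pod LRU from the pod's reported GPU block count,
// falling back to defaultSize when the pod did not report one (<= 0).
func newPodLRU(numGPUBlocks, defaultSize int) (*lru.Cache[blockHash, struct{}], error) {
	size := numGPUBlocks
	if size <= 0 {
		size = defaultSize
	}
	return lru.NewWithEvict[blockHash, struct{}](size, func(key blockHash, _ struct{}) {
		// The real indexer's eviction hook removes the pod from hashToPods here.
		fmt.Println("evicted hash", key)
	})
}

func main() {
	cache, err := newPodLRU(0, 4) // no reported num_gpu_blocks -> default capacity 4
	if err != nil {
		panic(err)
	}
	for i := 0; i < 6; i++ {
		cache.Add(blockHash(i), struct{}{}) // hashes 0 and 1 are evicted once capacity is hit
	}
	fmt.Println("entries:", cache.Len()) // entries: 4
}
```

Sizing each pod's LRU to its reported GPU block count keeps the approximate prefix cache from tracking more blocks per pod than the model server can actually hold.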

pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go

Lines changed: 21 additions & 18 deletions
@@ -23,25 +23,28 @@ import (
 )
 
 func TestIndexer_AddAndGet(t *testing.T) {
-    i := newIndexer(context.Background(), 2)
+    server := Server{
+        ServerID:       ServerID{Namespace: "default", Name: "server1"},
+        numOfGPUBlocks: 2,
+    }
+    i := newIndexer(context.Background(), 3) // Initialize with an lruSize greater than server.numOfGPUBlocks to verify server-defined limits take precedence.
 
     hash1 := BlockHash(1)
-    server := ServerID{Namespace: "default", Name: "server1"}
     // Add an entry to the cache
     i.Add([]BlockHash{hash1}, server)
 
     // Retrieve the entry
-    assert.Equal(t, 1, i.podToLRU[server].Len(), "Cache size should be 1 after adding an entry")
+    assert.Equal(t, 1, i.podToLRU[server.ServerID].Len(), "Cache size should be 1 after adding an entry")
     servers := i.Get(hash1)
-    assert.Contains(t, servers, server, "Cache should contain the added server")
+    assert.Contains(t, servers, server.ServerID, "Cache should contain the added server")
 
     // Add another entry to the cache, the cache size should be incremented to 2.
     i.Add([]BlockHash{BlockHash(2)}, server)
-    assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should be 2 after adding an entry")
+    assert.Equal(t, 2, i.podToLRU[server.ServerID].Len(), "Cache size should be 2 after adding an entry")
 
     // Add another entry to the cache, which should evict the first one due to max size.
     i.Add([]BlockHash{BlockHash(3)}, server)
-    assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should still be 2 after adding an entry")
+    assert.Equal(t, 2, i.podToLRU[server.ServerID].Len(), "Cache size should still be 2 after adding an entry")
 
     servers = i.Get(BlockHash(4))
     assert.Empty(t, servers, "Cache should not contain non-existent hash")
@@ -52,8 +55,8 @@ func TestIndexer_RemovePodAndEviction(t *testing.T) {
 
     i := newIndexer(context.Background(), indexerSize)
 
-    server1 := ServerID{Namespace: "default", Name: "server1"}
-    server2 := ServerID{Namespace: "default", Name: "server2"}
+    server1 := Server{ServerID: ServerID{Namespace: "default", Name: "server1"}}
+    server2 := Server{ServerID: ServerID{Namespace: "default", Name: "server2"}}
 
     // Add indexerSize hashes to both servers
     var hashes []BlockHash
@@ -65,15 +68,15 @@ func TestIndexer_RemovePodAndEviction(t *testing.T) {
     }
 
     // Ensure all entries are added
-    assert.Equal(t, indexerSize, i.podToLRU[server1].Len(), "server1 should have 10 entries")
-    assert.Equal(t, indexerSize, i.podToLRU[server2].Len(), "server2 should have 10 entries")
+    assert.Equal(t, indexerSize, i.podToLRU[server1.ServerID].Len(), "server1 should have 10 entries")
+    assert.Equal(t, indexerSize, i.podToLRU[server2.ServerID].Len(), "server2 should have 10 entries")
 
     // Ensure each hash in hashToPods maps to both server1 and server2
     for _, h := range hashes {
         pods := i.hashToPods[h]
         assert.Len(t, pods, 2, "Each hash should be associated with exactly 2 pods")
-        assert.Contains(t, pods, server1, "hash should be associated with server1")
-        assert.Contains(t, pods, server2, "hash should be associated with server2")
+        assert.Contains(t, pods, server1.ServerID, "hash should be associated with server1")
+        assert.Contains(t, pods, server2.ServerID, "hash should be associated with server2")
     }
 
     // Add indexerSize hash to server1 → should evict BlockHash(0)
@@ -82,25 +85,25 @@ func TestIndexer_RemovePodAndEviction(t *testing.T) {
     i.Add([]BlockHash{newHash}, server1)
 
     // server1 LRU should still be at max capacity
-    assert.Equal(t, indexerSize, i.podToLRU[server1].Len(), "server1 LRU should maintain max size")
+    assert.Equal(t, indexerSize, i.podToLRU[server1.ServerID].Len(), "server1 LRU should maintain max size")
 
     // BlockHash(0) should no longer have server1 in hashToPods
     pods := i.Get(evictedHash)
-    assert.NotContains(t, pods, server1, "server1 should be evicted from hashToPods for hash 0")
-    assert.Contains(t, pods, server2, "server2 should still have hash 0")
+    assert.NotContains(t, pods, server1.ServerID, "server1 should be evicted from hashToPods for hash 0")
+    assert.Contains(t, pods, server2.ServerID, "server2 should still have hash 0")
 
     // Remove server2
-    i.RemovePod(server2)
+    i.RemovePod(server2.ServerID)
 
     // hashToPods for hash 0 should now be empty
    pods = i.Get(evictedHash)
-    assert.NotContains(t, pods, server2, "server2 should be removed from hash 0")
+    assert.NotContains(t, pods, server2.ServerID, "server2 should be removed from hash 0")
     assert.Empty(t, pods, "hash 0 should have no pods after both eviction and removal")
 
     // All remaining hashes should map only to server1
     for hash, pods := range i.hashToPods {
         assert.Len(t, pods, 1, "hash %v should have only 1 pod after server2 removal", hash)
-        assert.Contains(t, pods, server1, "hash %v should only contain server1", hash)
+        assert.Contains(t, pods, server1.ServerID, "hash %v should only contain server1", hash)
     }
 
     // Ensure hashToPods contains exactly indexerSize hashes (post-eviction and server2 removal)

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 11 additions & 2 deletions
@@ -96,14 +96,19 @@ type podSet map[ServerID]struct{}
 
 type Indexer interface {
     Get(hash BlockHash) podSet
-    Add(hashes []BlockHash, server ServerID)
+    Add(hashes []BlockHash, server Server)
     RemovePod(server ServerID)
     Pods() []ServerID
 }
 
 // BlockHash is a hash of the block of request body.
 type BlockHash uint64
 
+type Server struct {
+    ServerID
+    numOfGPUBlocks int
+}
+
 type ServerID k8stypes.NamespacedName
 
 func (s ServerID) String() string {
@@ -224,6 +229,7 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
 func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) {
     primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
     targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
+    gpuBlocks := primaryProfileResult.TargetPods[0].GetMetrics().CacheNumGPUBlocks
 
     state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
     p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
@@ -238,7 +244,10 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
     // WaitGroup is added to the Plugin struct to allow waiting in tests.
     p.wg.Add(1)
     go func() {
-        p.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
+        p.indexer.Add(state.PrefixHashes, Server{
+            ServerID(targetPod.NamespacedName),
+            gpuBlocks,
+        })
         p.wg.Done()
     }()
 