Skip to content

Commit 9782009

Browse files
feat: modelcatalog enhancements
1 parent 531577e commit 9782009

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+149
-62
lines changed

core/bifrost.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1986,8 +1986,8 @@ func executeRequestWithRetries[T any](
19861986
// Retry if status code or error object indicates rate limiting
19871987
if (bifrostError.StatusCode != nil && retryableStatusCodes[*bifrostError.StatusCode]) ||
19881988
(bifrostError.Error != nil &&
1989-
(isRateLimitError(bifrostError.Error.Message) ||
1990-
(bifrostError.Error.Type != nil && isRateLimitError(*bifrostError.Error.Type)))) {
1989+
(IsRateLimitErrorMessage(bifrostError.Error.Message) ||
1990+
(bifrostError.Error.Type != nil && IsRateLimitErrorMessage(*bifrostError.Error.Type)))) {
19911991
shouldRetry = true
19921992
logger.Debug("detected rate limit error in message, will retry: %s", bifrostError.Error.Message)
19931993
}

core/bifrost_test.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ func TestCalculateBackoff_MaxBackoffCap(t *testing.T) {
353353
}
354354
}
355355

356-
// Test isRateLimitError - all patterns
356+
// Test IsRateLimitErrorMessage - all patterns
357357
func TestIsRateLimitError_AllPatterns(t *testing.T) {
358358
// Test all patterns from rateLimitPatterns
359359
patterns := []string{
@@ -382,36 +382,36 @@ func TestIsRateLimitError_AllPatterns(t *testing.T) {
382382
for _, pattern := range patterns {
383383
t.Run(fmt.Sprintf("Pattern_%s", strings.ReplaceAll(pattern, " ", "_")), func(t *testing.T) {
384384
// Test exact match
385-
if !isRateLimitError(pattern) {
385+
if !IsRateLimitErrorMessage(pattern) {
386386
t.Errorf("Pattern '%s' should be detected as rate limit error", pattern)
387387
}
388388

389389
// Test case insensitive - uppercase
390-
if !isRateLimitError(strings.ToUpper(pattern)) {
390+
if !IsRateLimitErrorMessage(strings.ToUpper(pattern)) {
391391
t.Errorf("Uppercase pattern '%s' should be detected as rate limit error", strings.ToUpper(pattern))
392392
}
393393

394394
// Test case insensitive - mixed case
395-
if !isRateLimitError(strings.Title(pattern)) {
395+
if !IsRateLimitErrorMessage(strings.Title(pattern)) {
396396
t.Errorf("Title case pattern '%s' should be detected as rate limit error", strings.Title(pattern))
397397
}
398398

399399
// Test as part of larger message
400400
message := fmt.Sprintf("Error: %s occurred", pattern)
401-
if !isRateLimitError(message) {
401+
if !IsRateLimitErrorMessage(message) {
402402
t.Errorf("Pattern '%s' in message '%s' should be detected", pattern, message)
403403
}
404404

405405
// Test with prefix and suffix
406406
message = fmt.Sprintf("API call failed due to %s - please retry later", pattern)
407-
if !isRateLimitError(message) {
407+
if !IsRateLimitErrorMessage(message) {
408408
t.Errorf("Pattern '%s' in complex message should be detected", pattern)
409409
}
410410
})
411411
}
412412
}
413413

414-
// Test isRateLimitError - negative cases
414+
// Test IsRateLimitErrorMessage - negative cases
415415
func TestIsRateLimitError_NegativeCases(t *testing.T) {
416416
negativeCases := []string{
417417
"",
@@ -431,31 +431,31 @@ func TestIsRateLimitError_NegativeCases(t *testing.T) {
431431

432432
for _, testCase := range negativeCases {
433433
t.Run(fmt.Sprintf("Negative_%s", strings.ReplaceAll(testCase, " ", "_")), func(t *testing.T) {
434-
if isRateLimitError(testCase) {
434+
if IsRateLimitErrorMessage(testCase) {
435435
t.Errorf("Message '%s' should NOT be detected as rate limit error", testCase)
436436
}
437437
})
438438
}
439439
}
440440

441-
// Test isRateLimitError - edge cases
441+
// Test IsRateLimitErrorMessage - edge cases
442442
func TestIsRateLimitError_EdgeCases(t *testing.T) {
443443
t.Run("EmptyString", func(t *testing.T) {
444-
if isRateLimitError("") {
444+
if IsRateLimitErrorMessage("") {
445445
t.Error("Empty string should not be detected as rate limit error")
446446
}
447447
})
448448

449449
t.Run("OnlyWhitespace", func(t *testing.T) {
450-
if isRateLimitError(" \t\n ") {
450+
if IsRateLimitErrorMessage(" \t\n ") {
451451
t.Error("Whitespace-only string should not be detected as rate limit error")
452452
}
453453
})
454454

455455
t.Run("UnicodeCharacters", func(t *testing.T) {
456456
// Test with unicode characters that might affect case conversion
457457
message := "RATE LIMIT exceeded 🚫"
458-
if !isRateLimitError(message) {
458+
if !IsRateLimitErrorMessage(message) {
459459
t.Error("Message with unicode should still detect rate limit pattern")
460460
}
461461
})
@@ -544,7 +544,7 @@ func BenchmarkCalculateBackoff(b *testing.B) {
544544
}
545545
}
546546

547-
// Benchmark isRateLimitError performance
547+
// Benchmark IsRateLimitErrorMessage performance
548548
func BenchmarkIsRateLimitError(b *testing.B) {
549549
messages := []string{
550550
"rate limit exceeded",
@@ -559,7 +559,7 @@ func BenchmarkIsRateLimitError(b *testing.B) {
559559

560560
b.ResetTimer()
561561
for i := 0; i < b.N; i++ {
562-
isRateLimitError(messages[i%len(messages)])
562+
IsRateLimitErrorMessage(messages[i%len(messages)])
563563
}
564564
}
565565

core/changelog.md

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,4 @@
11
<!-- The pattern we follow here is to keep the changelog for the latest version -->
22
<!-- Old changelogs are automatically attached to the GitHub releases -->
33

4-
- feat: Use all keys for list models request
5-
- refactor: Cohere provider to use completeRequest and response pooling for all requests
6-
- chore: Added id, object, and model fields to Chat Completion responses from Bedrock and Cohere providers
7-
- feat: Add request level control for adding extra headers, url path, skipping key selection, and sending back raw response
8-
- feat: Added support for gzip decompression of response bodies from all providers
9-
- feat: Added support for anthropic's meta data and thinking signature fields
10-
- feat: Moved all streaming calls to use fasthttp client for efficiency
11-
- refactor: Moved all convertors from schemas/providers to separate provider packages in providers directory
4+
- refactor: minor util changes

core/utils.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ func validateRequest(req *schemas.BifrostRequest) *schemas.BifrostError {
9292
return nil
9393
}
9494

95-
// isRateLimitError checks if an error message indicates a rate limit issue
96-
func isRateLimitError(errorMessage string) bool {
95+
// IsRateLimitErrorMessage checks if an error message indicates a rate limit issue
96+
func IsRateLimitErrorMessage(errorMessage string) bool {
9797
if errorMessage == "" {
9898
return false
9999
}

core/version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.2.17
1+
1.2.18

framework/changelog.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
<!-- The pattern we follow here is to keep the changelog for the latest version -->
22
<!-- Old changelogs are automatically attached to the GitHub releases -->
33

4-
- chore: Upgrades core to 1.2.17
5-
- feat: Adds dynamic plugins support
6-
- feat: Adds auth tables in config store
4+
- chore: Upgrades core to 1.2.18
5+
- enhancement: provider lookup enhancements

framework/modelcatalog/main.go

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"context"
66
"fmt"
77
"slices"
8+
"strings"
89
"sync"
910
"time"
1011

@@ -16,7 +17,7 @@ import (
1617
// Default sync interval and config key
1718
const (
1819
DefaultPricingSyncInterval = 24 * time.Hour
19-
ConfigLastPricingSyncKey = "LastModelPricingSync"
20+
ConfigLastPricingSyncKey = "LastModelPricingSync"
2021
DefaultPricingURL = "https://getbifrost.ai/datasheet"
2122
TokenTierAbove128K = 128000
2223
)
@@ -215,6 +216,51 @@ func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvid
215216
providers = append(providers, provider)
216217
}
217218
}
219+
220+
// Handle special provider cases
221+
// 1. Handle openrouter models
222+
if !slices.Contains(providers, schemas.OpenRouter) {
223+
for _, provider := range providers {
224+
if openRouterModels, ok := mc.modelPool[schemas.OpenRouter]; ok {
225+
if slices.Contains(openRouterModels, string(provider)+"/"+model) {
226+
providers = append(providers, schemas.OpenRouter)
227+
}
228+
}
229+
}
230+
}
231+
232+
// 2. Handle vertex models
233+
if !slices.Contains(providers, schemas.Vertex) {
234+
for _, provider := range providers {
235+
if vertexModels, ok := mc.modelPool[schemas.Vertex]; ok {
236+
if slices.Contains(vertexModels, string(provider)+"/"+model) {
237+
providers = append(providers, schemas.Vertex)
238+
}
239+
}
240+
}
241+
}
242+
243+
// 3. Handle openai models for groq
244+
if !slices.Contains(providers, schemas.Groq) && strings.Contains(model, "gpt-") {
245+
if groqModels, ok := mc.modelPool[schemas.Groq]; ok {
246+
if slices.Contains(groqModels, "openai/"+model) {
247+
providers = append(providers, schemas.Groq)
248+
}
249+
}
250+
}
251+
252+
// 4. Handle anthropic models for bedrock
253+
if !slices.Contains(providers, schemas.Bedrock) && strings.Contains(model, "claude") {
254+
if bedrockModels, ok := mc.modelPool[schemas.Bedrock]; ok {
255+
for _, bedrockModel := range bedrockModels {
256+
if strings.Contains(bedrockModel, model) {
257+
providers = append(providers, schemas.Bedrock)
258+
break
259+
}
260+
}
261+
}
262+
}
263+
218264
return providers
219265
}
220266

@@ -235,6 +281,18 @@ func (mc *ModelCatalog) AddModelDataToPool(modelData *schemas.BifrostListModelsR
235281
}
236282
}
237283

284+
// RefineModelForProvider refines the model for a given provider.
285+
// Currently only "gpt-oss-120b" for the groq provider is rewritten (to "openai/gpt-oss-120b"); every other provider/model pair is returned unchanged.
286+
func (mc *ModelCatalog) RefineModelForProvider(provider schemas.ModelProvider, model string) string {
287+
switch provider {
288+
case schemas.Groq:
289+
if model == "gpt-oss-120b" {
290+
return "openai/" + model
291+
}
292+
}
293+
return model
294+
}
295+
238296
// populateModelPoolFromPricingData populates the model pool with all available models per provider (thread-safe)
239297
func (mc *ModelCatalog) populateModelPoolFromPricingData() {
240298
// Acquire write lock for the entire rebuild operation

framework/version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.1.20
1+
1.1.21

plugins/governance/changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
<!-- The pattern we follow here is to keep the changelog for the latest version -->
22
<!-- Old changelogs are automatically attached to the GitHub releases -->
33

4-
- chore: version update core to 1.2.17 and framework to 1.1.19
4+
- chore: version update core to 1.2.18 and framework to 1.1.21

plugins/governance/version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.3.21
1+
1.3.22

0 commit comments

Comments
 (0)