Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type Bifrost struct {
responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init
pluginPipelinePool sync.Pool // Pool for PluginPipeline objects
bifrostRequestPool sync.Pool // Pool for BifrostRequest objects
pricingData sync.Map // pricing data for each model
logger schemas.Logger // logger instance, default logger is used if not provided
mcpManager *MCPManager // MCP integration manager (nil if MCP not configured)
dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
Expand Down Expand Up @@ -98,6 +99,7 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
plugins: atomic.Pointer[[]schemas.Plugin]{},
requestQueues: sync.Map{},
waitGroups: sync.Map{},
pricingData: sync.Map{},
keySelector: config.KeySelector,
logger: config.Logger,
}
Expand Down Expand Up @@ -287,6 +289,8 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
}
return nil, bifrostErr
}
// Add pricing data to the response
response.AddPricing(bifrost.GetPricingDataForModel)
return response, nil
}

Expand Down Expand Up @@ -369,6 +373,9 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr
break
}

// Add pricing data to the response
response.AddPricing(bifrost.GetPricingDataForModel)

providerModels = append(providerModels, response.Data...)

// Check if there are more pages
Expand Down Expand Up @@ -822,6 +829,38 @@ func (bifrost *Bifrost) ReloadPlugin(plugin schemas.Plugin) error {
}
}

// SetPricingData sets pricing data for all the models.
// This is used to set pricing data for all the models at once.
//
// Parameters:
// - pricingData: A map of model names to pricing data
func (bifrost *Bifrost) SetPricingData(pricingData map[string]schemas.DataSheetPricingEntry) {
for model, pricing := range pricingData {
bifrost.pricingData.Store(pricing.Provider+"/"+model, pricing)
}
}
Comment on lines +838 to +841
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Normalize model keys to prevent double provider prefixes.

If incoming map keys already include "/", we end up storing "provider/provider/model". Strip the prefix before storing.

Apply this diff:

-func (bifrost *Bifrost) SetPricingData(pricingData map[string]schemas.DataSheetPricingEntry) {
-  for model, pricing := range pricingData {
-    bifrost.pricingData.Store(pricing.Provider+"/"+model, pricing)
-  }
-}
+func (bifrost *Bifrost) SetPricingData(pricingData map[string]schemas.DataSheetPricingEntry) {
+  for model, pricing := range pricingData {
+    // Normalize model: strip "<provider>/" if present
+    if strings.HasPrefix(model, pricing.Provider+"/") {
+      model = strings.TrimPrefix(model, pricing.Provider+"/")
+    }
+    bifrost.pricingData.Store(pricing.Provider+"/"+model, pricing)
+  }
+}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for model, pricing := range pricingData {
bifrost.pricingData.Store(pricing.Provider+"/"+model, pricing)
}
}
func (bifrost *Bifrost) SetPricingData(pricingData map[string]schemas.DataSheetPricingEntry) {
for model, pricing := range pricingData {
// Normalize model: strip "<provider>/" if present
if strings.HasPrefix(model, pricing.Provider+"/") {
model = strings.TrimPrefix(model, pricing.Provider+"/")
}
bifrost.pricingData.Store(pricing.Provider+"/"+model, pricing)
}
}
🤖 Prompt for AI Agents
In core/bifrost.go around lines 838 to 841, the loop stores pricing entries
using pricing.Provider+"/"+model which can duplicate the provider prefix if
model already starts with "provider/"; update the loop to normalize keys by
checking if model begins with pricing.Provider+"/" (or any "<provider>/") and
strip that leading "<provider>/" before constructing the stored key so you
always store a single "provider/model" entry.


// GetPricingDataForModel returns pricing data for a model.
// This is used to get pricing data for a model.
//
// Parameters:
// - model: The model to get pricing data for
// - provider: The provider to get pricing data for
//
// Returns:
// - pricing: The pricing data for the model, nil if not found
func (bifrost *Bifrost) GetPricingDataForModel(model string, provider schemas.ModelProvider) *schemas.DataSheetPricingEntry {
pricing, ok := bifrost.pricingData.Load(string(provider) + "/" + model)
if !ok {
return nil
}
if pricing, ok := pricing.(schemas.DataSheetPricingEntry); ok {
return &pricing
}
return nil
}

// GetConfiguredProviders returns a configured providers list.
func (bifrost *Bifrost) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
providers := bifrost.providers.Load()
if providers == nil {
Expand Down
183 changes: 121 additions & 62 deletions core/schemas/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package schemas
import (
"encoding/base64"
"fmt"
"strings"

"github.com/bytedance/sonic"
)
Expand Down Expand Up @@ -44,9 +45,95 @@ type BifrostListModelsResponse struct {
HasMore *bool `json:"-"`
}

// ApplyPagination applies offset-based pagination to a BifrostListModelsResponse.
// Uses opaque tokens with LastID validation to ensure cursor integrity.
// Returns the paginated response with properly set NextPageToken.
func (response *BifrostListModelsResponse) ApplyPagination(pageSize int, pageToken string) *BifrostListModelsResponse {
if response == nil {
return nil
}

totalItems := len(response.Data)

if pageSize <= 0 {
return response
}

cursor := decodePaginationCursor(pageToken)
offset := cursor.Offset

// Validate cursor integrity if LastID is present
if cursor.LastID != "" && !validatePaginationCursor(cursor, response.Data) {
// Invalid cursor: reset to beginning
offset = 0
}

if offset >= totalItems {
// Return empty page, no next token
return &BifrostListModelsResponse{
Data: []Model{},
ExtraFields: response.ExtraFields,
NextPageToken: "",
}
}

endIndex := offset + pageSize
if endIndex > totalItems {
endIndex = totalItems
}

paginatedData := response.Data[offset:endIndex]

paginatedResponse := &BifrostListModelsResponse{
Data: paginatedData,
ExtraFields: response.ExtraFields,
}

if endIndex < totalItems {
// Get the last item ID for cursor validation
var lastID string
if len(paginatedData) > 0 {
lastID = paginatedData[len(paginatedData)-1].ID
}

nextToken, err := encodePaginationCursor(endIndex, lastID)
if err == nil {
paginatedResponse.NextPageToken = nextToken
}
} else {
paginatedResponse.NextPageToken = ""
}

return paginatedResponse
}

type PricingFetcher func(model string, provider ModelProvider) *DataSheetPricingEntry

// AddPricing adds pricing data to the response.
// This is used to add pricing data to the response.
//
// Parameters:
// - fetcher: The pricing fetcher function
//
// Returns:
// - response: The response with pricing data
Comment on lines +118 to +119
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix documentation: method doesn't return a value.

The documentation indicates the method returns the response, but the method signature on line 120 shows func (response *BifrostListModelsResponse) AddPricing(fetcher PricingFetcher) with no return value. The method modifies the response in-place.

Apply this diff to fix the documentation:

 // AddPricing adds pricing data to the response.
-// This is used to add pricing data to the response.
+// This method modifies the response in-place by enriching each model with pricing information.
 //
 // Parameters:
 //   - fetcher: The pricing fetcher function
-//
-// Returns:
-//   - response: The response with pricing data

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In core/schemas/models.go around lines 118 to 119, the method comment
incorrectly states the method returns the response while the signature func
(response *BifrostListModelsResponse) AddPricing(fetcher PricingFetcher) has no
return; update the doc comment to reflect that AddPricing mutates the response
in-place (for example change "Returns: - response: The response with pricing
data" to "Modifies the response in-place to include pricing data" or remove the
Returns section entirely), keeping wording concise and accurate.

func (response *BifrostListModelsResponse) AddPricing(fetcher PricingFetcher) {
for i, modelData := range response.Data {
model := strings.TrimPrefix(modelData.ID, string(response.ExtraFields.Provider)+"/")
pricing := fetcher(model, response.ExtraFields.Provider)
if pricing != nil {
if response.Data[i].Pricing == nil {
response.Data[i].Pricing = &Pricing{}
}
response.Data[i].Pricing.DataSheetPricingEntry = pricing
}
}
}
Comment on lines +120 to +131
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Add nil check for fetcher parameter to prevent panic.

If fetcher is nil, line 123 will cause a nil pointer dereference panic. Add a defensive check at the start of the method.

Apply this diff to add the nil check:

 func (response *BifrostListModelsResponse) AddPricing(fetcher PricingFetcher) {
+	if fetcher == nil {
+		return
+	}
 	for i, modelData := range response.Data {
🤖 Prompt for AI Agents
In core/schemas/models.go around lines 120 to 131, the method AddPricing can
panic when the PricingFetcher parameter is nil because fetcher(...) is called
unguarded; add a defensive nil-check at the top of the method (e.g., if fetcher
== nil { return }) to return early when no fetcher is provided, keeping the rest
of the logic unchanged so you avoid a nil pointer dereference.


type Model struct {
ID string `json:"id"`
CanonicalSlug *string `json:"canonical_slug,omitempty"`
DeploymentName *string `json:"deployment_name,omitempty"`
Name *string `json:"name,omitempty"`
Created *int64 `json:"created,omitempty"`
ContextLength *int `json:"context_length,omitempty"`
Expand Down Expand Up @@ -82,6 +169,8 @@ type Pricing struct {
InternalReasoning *string `json:"internal_reasoning,omitempty"`
InputCacheRead *string `json:"input_cache_read,omitempty"`
InputCacheWrite *string `json:"input_cache_write,omitempty"`

*DataSheetPricingEntry
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Verify the intended API surface with dual pricing representations.

Embedding *DataSheetPricingEntry promotes all its fields to the Pricing JSON output, creating two sets of pricing information:

  • Existing string-based fields: Prompt, Completion, Request, etc.
  • New detailed float64-based fields: InputCostPerToken, OutputCostPerToken, etc.

This dual representation might be intentional for backward compatibility, but could confuse API consumers about which pricing data to use.

Run the following script to check if there's documentation or tests explaining the relationship between these two pricing representations:


🏁 Script executed:

#!/bin/bash
# Search for documentation or usage patterns explaining the dual pricing format

# Look for comments or docs explaining Pricing struct usage
rg -nP -C5 'type Pricing struct|Pricing.*\{' --type=go

# Look for tests that populate both old and new pricing fields
rg -nP -C3 'Pricing\s*\{' --type=go -g '*_test.go'

# Search for any migration notes or comments about pricing changes
rg -nP -C3 'pricing.*backward|backward.*pricing|pricing.*compat|legacy.*pricing' --type=go -i

Length of output: 40613


Now let me search the codebase directly to inspect the struct definition and any custom marshaling:

I need to inspect the actual codebase directly. Let me extract the specific struct definitions and check for custom marshaling:

#!/bin/bash
# Get the exact Pricing struct definition including all fields and tags
cat core/schemas/models.go | sed -n '163,180p'

# Check if Pricing or DataSheetPricingEntry have MarshalJSON methods
echo "=== Checking for MarshalJSON on Pricing ==="
rg -n "func.*Pricing.*MarshalJSON" --type=go

echo "=== Checking for MarshalJSON on DataSheetPricingEntry ==="
rg -n "func.*DataSheetPricingEntry.*MarshalJSON" --type=go

# Search for any JSON tag patterns on DataSheetPricingEntry
echo "=== DataSheetPricingEntry struct definition ==="
rg -n -A20 "type DataSheetPricingEntry struct" --type=go

#!/bin/bash
# Extract Pricing struct complete definition
echo "=== Pricing struct definition (lines 163-180) ==="
sed -n '163,180p' core/schemas/models.go

# Extract DataSheetPricingEntry struct complete definition
echo -e "\n=== DataSheetPricingEntry struct definition ==="
rg -A30 "type DataSheetPricingEntry struct" core/schemas/models.go

# Check for MarshalJSON methods
echo -e "\n=== Checking MarshalJSON implementations ==="
rg -n "func.*MarshalJSON" core/schemas/models.go

🏁 Script executed:

#!/bin/bash
# Extract Pricing struct complete definition around line 163
echo "=== Pricing struct definition ==="
sed -n '163,195p' core/schemas/models.go

# Extract DataSheetPricingEntry struct complete definition
echo -e "\n=== DataSheetPricingEntry struct definition ==="
sed -n '200,240p' core/schemas/models.go

# Check for MarshalJSON methods in the file
echo -e "\n=== MarshalJSON implementations ==="
grep -n "func.*MarshalJSON" core/schemas/models.go || echo "No MarshalJSON found"

Length of output: 3529


🏁 Script executed:

#!/bin/bash
# Search for tests that create Pricing structs and show JSON output expectations
echo "=== Tests with Pricing struct ==="
rg -n "Pricing\s*{" core/schemas/models_test.go 2>/dev/null || rg -n "Pricing" --type=go -g '*_test.go' -C2 | head -50

# Search for comments explaining the dual pricing design
echo -e "\n=== Comments near Pricing struct ==="
sed -n '155,165p' core/schemas/models.go

# Search for JSON marshaling tests
echo -e "\n=== JSON marshaling or response format documentation ==="
rg -n "json\." --type=go -C2 core/schemas/models.go | head -40

# Check if there are actual API response examples or integration tests
echo -e "\n=== Integration tests or response examples ==="
find . -name "*.go" -type f -exec grep -l "BifrostListModelsResponse.*json" {} \; 2>/dev/null | head -5

Length of output: 637


Add documentation explaining the dual pricing representation in the Pricing struct.

The embedded *DataSheetPricingEntry (line 173) promotes its fields to the Pricing JSON output due to Go's embedding behavior, creating dual pricing information: legacy string-based fields (Prompt, Completion, etc.) alongside detailed float64-based fields (InputCostPerToken, OutputCostPerToken, etc.). While this design is intentional for backward compatibility, it lacks documentation explaining the relationship between the two field sets, which could confuse API consumers about which pricing data to use.

Add inline comments to the Pricing struct explaining:

  • Why both field sets coexist
  • Which represents the current/preferred format
  • When/if legacy fields are deprecated
🤖 Prompt for AI Agents
In core/schemas/models.go around line 173, the Pricing struct embeds
*DataSheetPricingEntry which promotes legacy string-based fields into the JSON
output alongside the newer float64-based fields, causing dual pricing
representations; update the struct by adding concise inline comments that
explain: why both sets coexist (backwards compatibility), that the float64
fields (InputCostPerToken/OutputCostPerToken/etc.) are the current/preferred
format, which legacy string fields map to those values (e.g., Prompt/Completion
-> input/output) and when the legacy fields will be deprecated or removed
(include expected deprecation timeline or “TBD” if not decided), so API
consumers know which to prefer and the migration path.

}

type TopProvider struct {
Expand All @@ -107,6 +196,38 @@ type paginationCursor struct {
LastID string `json:"l,omitempty"`
}

// PricingEntry represents a single model's pricing information
type DataSheetPricingEntry struct {
// Basic pricing
InputCostPerToken float64 `json:"input_cost_per_token"`
OutputCostPerToken float64 `json:"output_cost_per_token"`
Provider string `json:"provider"`
Mode string `json:"mode"`

// Additional pricing for media
InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"`
InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"`
InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"`

// Character-based pricing
InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"`
OutputCostPerCharacter *float64 `json:"output_cost_per_character,omitempty"`

// Pricing above 128k tokens
InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"`
InputCostPerCharacterAbove128kTokens *float64 `json:"input_cost_per_character_above_128k_tokens,omitempty"`
InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"`
InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"`
InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"`
OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"`
OutputCostPerCharacterAbove128kTokens *float64 `json:"output_cost_per_character_above_128k_tokens,omitempty"`

// Cache and batch pricing
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"`
InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"`
OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"`
}

// encodePaginationCursor creates an opaque base64-encoded page token from cursor data.
// Returns empty string if offset is 0 or negative.
func encodePaginationCursor(offset int, lastID string) (string, error) {
Expand Down Expand Up @@ -172,65 +293,3 @@ func validatePaginationCursor(cursor paginationCursor, data []Model) bool {

return true
}

// ApplyPagination applies offset-based pagination to a BifrostListModelsResponse.
// Uses opaque tokens with LastID validation to ensure cursor integrity.
// Returns the paginated response with properly set NextPageToken.
func (response *BifrostListModelsResponse) ApplyPagination(pageSize int, pageToken string) *BifrostListModelsResponse {
if response == nil {
return nil
}

totalItems := len(response.Data)

if pageSize <= 0 {
return response
}

cursor := decodePaginationCursor(pageToken)
offset := cursor.Offset

// Validate cursor integrity if LastID is present
if cursor.LastID != "" && !validatePaginationCursor(cursor, response.Data) {
// Invalid cursor: reset to beginning
offset = 0
}

if offset >= totalItems {
// Return empty page, no next token
return &BifrostListModelsResponse{
Data: []Model{},
ExtraFields: response.ExtraFields,
NextPageToken: "",
}
}

endIndex := offset + pageSize
if endIndex > totalItems {
endIndex = totalItems
}

paginatedData := response.Data[offset:endIndex]

paginatedResponse := &BifrostListModelsResponse{
Data: paginatedData,
ExtraFields: response.ExtraFields,
}

if endIndex < totalItems {
// Get the last item ID for cursor validation
var lastID string
if len(paginatedData) > 0 {
lastID = paginatedData[len(paginatedData)-1].ID
}

nextToken, err := encodePaginationCursor(endIndex, lastID)
if err == nil {
paginatedResponse.NextPageToken = nextToken
}
} else {
paginatedResponse.NextPageToken = ""
}

return paginatedResponse
}
55 changes: 19 additions & 36 deletions framework/modelcatalog/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ const (

// Config is the model pricing configuration.
type Config struct {
PricingURL *string `json:"pricing_url,omitempty"`
PricingSyncInterval *time.Duration `json:"pricing_sync_interval,omitempty"`
PricingURL *string `json:"pricing_url,omitempty"`
PricingSyncInterval *time.Duration `json:"pricing_sync_interval,omitempty"`
PricingSyncCallback func(map[string]schemas.DataSheetPricingEntry) `json:"pricing_sync_callback,omitempty"`
}

type ModelCatalog struct {
Expand All @@ -49,41 +50,9 @@ type ModelCatalog struct {
wg sync.WaitGroup
syncCtx context.Context
syncCancel context.CancelFunc
}

// PricingData represents the structure of the pricing.json file
type PricingData map[string]PricingEntry

// PricingEntry represents a single model's pricing information
type PricingEntry struct {
// Basic pricing
InputCostPerToken float64 `json:"input_cost_per_token"`
OutputCostPerToken float64 `json:"output_cost_per_token"`
Provider string `json:"provider"`
Mode string `json:"mode"`

// Additional pricing for media
InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"`
InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"`
InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"`

// Character-based pricing
InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"`
OutputCostPerCharacter *float64 `json:"output_cost_per_character,omitempty"`

// Pricing above 128k tokens
InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"`
InputCostPerCharacterAbove128kTokens *float64 `json:"input_cost_per_character_above_128k_tokens,omitempty"`
InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"`
InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"`
InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"`
OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"`
OutputCostPerCharacterAbove128kTokens *float64 `json:"output_cost_per_character_above_128k_tokens,omitempty"`

// Cache and batch pricing
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"`
InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"`
OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"`
// Callback after pricing data is synced
pricingSyncCallback func(map[string]schemas.DataSheetPricingEntry)
}

// Init initializes the pricing manager
Expand All @@ -105,6 +74,7 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto
pricingData: make(map[string]configstoreTables.TableModelPricing),
modelPool: make(map[schemas.ModelProvider][]string),
done: make(chan struct{}),
pricingSyncCallback: config.PricingSyncCallback,
}

logger.Info("initializing pricing manager...")
Expand Down Expand Up @@ -189,6 +159,19 @@ func (mc *ModelCatalog) getPricingSyncInterval() time.Duration {
return mc.pricingSyncInterval
}

// GetPricingData returns the pricing data
func (mc *ModelCatalog) GetPricingData() map[string]schemas.DataSheetPricingEntry {
mc.mu.RLock()
defer mc.mu.RUnlock()
// Make a copy of the pricing data
pricingData := make(map[string]schemas.DataSheetPricingEntry)
for key, pricing := range mc.pricingData {
model, _, _ := splitKey(key)
pricingData[model] = convertTableModelPricingToPricingData(pricing)
}
return pricingData
}

// GetModelsForProvider returns all available models for a given provider (thread-safe)
func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string {
mc.mu.RLock()
Expand Down
Loading