From 7e6b09b3342da2916b8bae9f6f3a7a30c4f6d667 Mon Sep 17 00:00:00 2001 From: Pratham-Mishra04 Date: Tue, 4 Nov 2025 14:24:22 +0530 Subject: [PATCH] feat: model pricing added to list models endpoint --- core/bifrost.go | 39 +++++ core/schemas/models.go | 183 +++++++++++++++-------- framework/modelcatalog/main.go | 55 +++---- framework/modelcatalog/sync.go | 10 +- framework/modelcatalog/utils.go | 33 +++- transports/bifrost-http/lib/config.go | 10 ++ transports/bifrost-http/server/server.go | 2 + 7 files changed, 231 insertions(+), 101 deletions(-) diff --git a/core/bifrost.go b/core/bifrost.go index 37958ac70..42d9ffc05 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -54,6 +54,7 @@ type Bifrost struct { responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init pluginPipelinePool sync.Pool // Pool for PluginPipeline objects bifrostRequestPool sync.Pool // Pool for BifrostRequest objects + pricingData sync.Map // pricing data for each model logger schemas.Logger // logger instance, default logger is used if not provided mcpManager *MCPManager // MCP integration manager (nil if MCP not configured) dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. 
@@ -98,6 +99,7 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { plugins: atomic.Pointer[[]schemas.Plugin]{}, requestQueues: sync.Map{}, waitGroups: sync.Map{}, + pricingData: sync.Map{}, keySelector: config.KeySelector, logger: config.Logger, } @@ -287,6 +289,8 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr } return nil, bifrostErr } + // Add pricing data to the response + response.AddPricing(bifrost.GetPricingDataForModel) return response, nil } @@ -369,6 +373,9 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr break } + // Add pricing data to the response + response.AddPricing(bifrost.GetPricingDataForModel) + providerModels = append(providerModels, response.Data...) // Check if there are more pages @@ -822,6 +829,38 @@ func (bifrost *Bifrost) ReloadPlugin(plugin schemas.Plugin) error { } } +// SetPricingData sets pricing data for all the models. +// This is used to set pricing data for all the models at once. +// +// Parameters: +// - pricingData: A map of model names to pricing data +func (bifrost *Bifrost) SetPricingData(pricingData map[string]schemas.DataSheetPricingEntry) { + for model, pricing := range pricingData { + bifrost.pricingData.Store(pricing.Provider+"/"+model, pricing) + } +} + +// GetPricingDataForModel returns pricing data for a model. +// This is used to get pricing data for a model. 
+// +// Parameters: +// - model: The model to get pricing data for +// - provider: The provider to get pricing data for +// +// Returns: +// - pricing: The pricing data for the model, nil if not found +func (bifrost *Bifrost) GetPricingDataForModel(model string, provider schemas.ModelProvider) *schemas.DataSheetPricingEntry { + pricing, ok := bifrost.pricingData.Load(string(provider) + "/" + model) + if !ok { + return nil + } + if pricing, ok := pricing.(schemas.DataSheetPricingEntry); ok { + return &pricing + } + return nil +} + +// GetConfiguredProviders returns a configured providers list. func (bifrost *Bifrost) GetConfiguredProviders() ([]schemas.ModelProvider, error) { providers := bifrost.providers.Load() if providers == nil { diff --git a/core/schemas/models.go b/core/schemas/models.go index 5b50077aa..fcc88cafb 100644 --- a/core/schemas/models.go +++ b/core/schemas/models.go @@ -3,6 +3,7 @@ package schemas import ( "encoding/base64" "fmt" + "strings" "github.com/bytedance/sonic" ) @@ -44,9 +45,95 @@ type BifrostListModelsResponse struct { HasMore *bool `json:"-"` } +// ApplyPagination applies offset-based pagination to a BifrostListModelsResponse. +// Uses opaque tokens with LastID validation to ensure cursor integrity. +// Returns the paginated response with properly set NextPageToken. 
+func (response *BifrostListModelsResponse) ApplyPagination(pageSize int, pageToken string) *BifrostListModelsResponse { + if response == nil { + return nil + } + + totalItems := len(response.Data) + + if pageSize <= 0 { + return response + } + + cursor := decodePaginationCursor(pageToken) + offset := cursor.Offset + + // Validate cursor integrity if LastID is present + if cursor.LastID != "" && !validatePaginationCursor(cursor, response.Data) { + // Invalid cursor: reset to beginning + offset = 0 + } + + if offset >= totalItems { + // Return empty page, no next token + return &BifrostListModelsResponse{ + Data: []Model{}, + ExtraFields: response.ExtraFields, + NextPageToken: "", + } + } + + endIndex := offset + pageSize + if endIndex > totalItems { + endIndex = totalItems + } + + paginatedData := response.Data[offset:endIndex] + + paginatedResponse := &BifrostListModelsResponse{ + Data: paginatedData, + ExtraFields: response.ExtraFields, + } + + if endIndex < totalItems { + // Get the last item ID for cursor validation + var lastID string + if len(paginatedData) > 0 { + lastID = paginatedData[len(paginatedData)-1].ID + } + + nextToken, err := encodePaginationCursor(endIndex, lastID) + if err == nil { + paginatedResponse.NextPageToken = nextToken + } + } else { + paginatedResponse.NextPageToken = "" + } + + return paginatedResponse +} + +type PricingFetcher func(model string, provider ModelProvider) *DataSheetPricingEntry + +// AddPricing adds pricing data to the response. +// This is used to add pricing data to the response. 
+// +// Parameters: +// - fetcher: Called once per model with the ID stripped of its "<provider>/" prefix; a nil result leaves that model unchanged +// +// Returns: +// - nothing; entries of response.Data are updated in place +func (response *BifrostListModelsResponse) AddPricing(fetcher PricingFetcher) { + for i, modelData := range response.Data { + model := strings.TrimPrefix(modelData.ID, string(response.ExtraFields.Provider)+"/") + pricing := fetcher(model, response.ExtraFields.Provider) + if pricing != nil { + if response.Data[i].Pricing == nil { + response.Data[i].Pricing = &Pricing{} + } + response.Data[i].Pricing.DataSheetPricingEntry = pricing + } + } +} + type Model struct { ID string `json:"id"` CanonicalSlug *string `json:"canonical_slug,omitempty"` + DeploymentName *string `json:"deployment_name,omitempty"` Name *string `json:"name,omitempty"` Created *int64 `json:"created,omitempty"` ContextLength *int `json:"context_length,omitempty"` @@ -82,6 +169,8 @@ type Pricing struct { InternalReasoning *string `json:"internal_reasoning,omitempty"` InputCacheRead *string `json:"input_cache_read,omitempty"` InputCacheWrite *string `json:"input_cache_write,omitempty"` + + *DataSheetPricingEntry } type TopProvider struct { @@ -107,6 +196,38 @@ type paginationCursor struct { LastID string `json:"l,omitempty"` } +// DataSheetPricingEntry represents a single model's pricing information +type DataSheetPricingEntry struct { + // Basic pricing + InputCostPerToken float64 `json:"input_cost_per_token"` + OutputCostPerToken float64 `json:"output_cost_per_token"` + Provider string `json:"provider"` + Mode string `json:"mode"` + + // Additional pricing for media + InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"` + InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"` + InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"` + + // Character-based pricing + InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"` + OutputCostPerCharacter *float64
`json:"output_cost_per_character,omitempty"` + + // Pricing above 128k tokens + InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"` + InputCostPerCharacterAbove128kTokens *float64 `json:"input_cost_per_character_above_128k_tokens,omitempty"` + InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"` + InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"` + InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"` + OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"` + OutputCostPerCharacterAbove128kTokens *float64 `json:"output_cost_per_character_above_128k_tokens,omitempty"` + + // Cache and batch pricing + CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"` + InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"` + OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` +} + // encodePaginationCursor creates an opaque base64-encoded page token from cursor data. // Returns empty string if offset is 0 or negative. func encodePaginationCursor(offset int, lastID string) (string, error) { @@ -172,65 +293,3 @@ func validatePaginationCursor(cursor paginationCursor, data []Model) bool { return true } - -// ApplyPagination applies offset-based pagination to a BifrostListModelsResponse. -// Uses opaque tokens with LastID validation to ensure cursor integrity. -// Returns the paginated response with properly set NextPageToken. 
-func (response *BifrostListModelsResponse) ApplyPagination(pageSize int, pageToken string) *BifrostListModelsResponse { - if response == nil { - return nil - } - - totalItems := len(response.Data) - - if pageSize <= 0 { - return response - } - - cursor := decodePaginationCursor(pageToken) - offset := cursor.Offset - - // Validate cursor integrity if LastID is present - if cursor.LastID != "" && !validatePaginationCursor(cursor, response.Data) { - // Invalid cursor: reset to beginning - offset = 0 - } - - if offset >= totalItems { - // Return empty page, no next token - return &BifrostListModelsResponse{ - Data: []Model{}, - ExtraFields: response.ExtraFields, - NextPageToken: "", - } - } - - endIndex := offset + pageSize - if endIndex > totalItems { - endIndex = totalItems - } - - paginatedData := response.Data[offset:endIndex] - - paginatedResponse := &BifrostListModelsResponse{ - Data: paginatedData, - ExtraFields: response.ExtraFields, - } - - if endIndex < totalItems { - // Get the last item ID for cursor validation - var lastID string - if len(paginatedData) > 0 { - lastID = paginatedData[len(paginatedData)-1].ID - } - - nextToken, err := encodePaginationCursor(endIndex, lastID) - if err == nil { - paginatedResponse.NextPageToken = nextToken - } - } else { - paginatedResponse.NextPageToken = "" - } - - return paginatedResponse -} diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index b1065c664..074201189 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -24,8 +24,9 @@ const ( // Config is the model pricing configuration. 
type Config struct { - PricingURL *string `json:"pricing_url,omitempty"` - PricingSyncInterval *time.Duration `json:"pricing_sync_interval,omitempty"` + PricingURL *string `json:"pricing_url,omitempty"` + PricingSyncInterval *time.Duration `json:"pricing_sync_interval,omitempty"` + PricingSyncCallback func(map[string]schemas.DataSheetPricingEntry) `json:"-"` // excluded from JSON: func values cannot be encoded; called with the fetched pricing map after a successful sync } type ModelCatalog struct { @@ -49,41 +50,9 @@ type ModelCatalog struct { wg sync.WaitGroup syncCtx context.Context syncCancel context.CancelFunc -} -// PricingData represents the structure of the pricing.json file -type PricingData map[string]PricingEntry - -// PricingEntry represents a single model's pricing information -type PricingEntry struct { - // Basic pricing - InputCostPerToken float64 `json:"input_cost_per_token"` - OutputCostPerToken float64 `json:"output_cost_per_token"` - Provider string `json:"provider"` - Mode string `json:"mode"` - - // Additional pricing for media - InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"` - InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"` - InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"` - - // Character-based pricing - InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"` - OutputCostPerCharacter *float64 `json:"output_cost_per_character,omitempty"` - - // Pricing above 128k tokens - InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"` - InputCostPerCharacterAbove128kTokens *float64 `json:"input_cost_per_character_above_128k_tokens,omitempty"` - InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"` - InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"` - InputCostPerAudioPerSecondAbove128kTokens *float64
`json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"` - OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"` - OutputCostPerCharacterAbove128kTokens *float64 `json:"output_cost_per_character_above_128k_tokens,omitempty"` - - // Cache and batch pricing - CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"` - InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"` - OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` + // Callback after pricing data is synced + pricingSyncCallback func(map[string]schemas.DataSheetPricingEntry) } // Init initializes the pricing manager @@ -105,6 +74,7 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto pricingData: make(map[string]configstoreTables.TableModelPricing), modelPool: make(map[schemas.ModelProvider][]string), done: make(chan struct{}), + pricingSyncCallback: config.PricingSyncCallback, } logger.Info("initializing pricing manager...") @@ -189,6 +159,19 @@ func (mc *ModelCatalog) getPricingSyncInterval() time.Duration { return mc.pricingSyncInterval } +// GetPricingData returns the pricing data +func (mc *ModelCatalog) GetPricingData() map[string]schemas.DataSheetPricingEntry { + mc.mu.RLock() + defer mc.mu.RUnlock() + // Make a copy of the pricing data + pricingData := make(map[string]schemas.DataSheetPricingEntry) + for key, pricing := range mc.pricingData { + model, _, _ := splitKey(key) + pricingData[model] = convertTableModelPricingToPricingData(pricing) + } + return pricingData +} + // GetModelsForProvider returns all available models for a given provider (thread-safe) func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string { mc.mu.RLock() diff --git a/framework/modelcatalog/sync.go b/framework/modelcatalog/sync.go index 6a909a164..e4aeac109 100644 --- a/framework/modelcatalog/sync.go +++ 
b/framework/modelcatalog/sync.go @@ -8,6 +8,7 @@ import ( "net/http" "time" + "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "gorm.io/gorm" ) @@ -127,12 +128,17 @@ func (mc *ModelCatalog) syncPricing(ctx context.Context) error { return fmt.Errorf("failed to reload pricing cache: %w", err) } + if mc.pricingSyncCallback != nil { + mc.pricingSyncCallback(pricingData) + mc.logger.Debug("pricing sync callback executed") + } + mc.logger.Info("successfully synced %d pricing records", len(pricingData)) return nil } // loadPricingFromURL loads pricing data from the remote URL -func (mc *ModelCatalog) loadPricingFromURL(ctx context.Context) (PricingData, error) { +func (mc *ModelCatalog) loadPricingFromURL(ctx context.Context) (map[string]schemas.DataSheetPricingEntry, error) { // Create HTTP client with timeout client := &http.Client{ Timeout: 30 * time.Second, @@ -160,7 +166,7 @@ func (mc *ModelCatalog) loadPricingFromURL(ctx context.Context) (PricingData, er } // Unmarshal JSON data - var pricingData PricingData + var pricingData map[string]schemas.DataSheetPricingEntry if err := json.Unmarshal(data, &pricingData); err != nil { return nil, fmt.Errorf("failed to unmarshal pricing data: %w", err) } diff --git a/framework/modelcatalog/utils.go b/framework/modelcatalog/utils.go index 2ff8e3a42..f702662b2 100644 --- a/framework/modelcatalog/utils.go +++ b/framework/modelcatalog/utils.go @@ -10,6 +10,12 @@ import ( // makeKey creates a unique key for a model, provider, and mode for pricingData map func makeKey(model, provider, mode string) string { return model + "|" + provider + "|" + mode } +// splitKey splits a makeKey-style key into model, provider, and mode; segments missing from a malformed key are returned as "" instead of panicking +func splitKey(key string) (string, string, string) { + parts := append(strings.Split(key, "|"), "", "") + return parts[0], parts[1], parts[2] +} + // isBatchRequest checks if the request is for batch processing func isBatchRequest(req *schemas.BifrostRequest) bool { // Check for
batch endpoints or batch-specific headers @@ -73,7 +79,7 @@ func normalizeRequestType(reqType schemas.RequestType) string { } // convertPricingDataToTableModelPricing converts the pricing data to a TableModelPricing struct -func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) configstoreTables.TableModelPricing { +func convertPricingDataToTableModelPricing(modelKey string, entry schemas.DataSheetPricingEntry) configstoreTables.TableModelPricing { provider := normalizeProvider(entry.Provider) // Handle provider/model format - extract just the model name @@ -119,6 +125,31 @@ func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) return pricing } +// convertTableModelPricingToPricingData converts the TableModelPricing struct to a DataSheetPricingEntry struct +func convertTableModelPricingToPricingData(pricing configstoreTables.TableModelPricing) schemas.DataSheetPricingEntry { + return schemas.DataSheetPricingEntry{ + Provider: pricing.Provider, + Mode: pricing.Mode, + InputCostPerToken: pricing.InputCostPerToken, + OutputCostPerToken: pricing.OutputCostPerToken, + InputCostPerImage: pricing.InputCostPerImage, + InputCostPerVideoPerSecond: pricing.InputCostPerVideoPerSecond, + InputCostPerAudioPerSecond: pricing.InputCostPerAudioPerSecond, + InputCostPerCharacter: pricing.InputCostPerCharacter, + OutputCostPerCharacter: pricing.OutputCostPerCharacter, + InputCostPerTokenAbove128kTokens: pricing.InputCostPerTokenAbove128kTokens, + InputCostPerCharacterAbove128kTokens: pricing.InputCostPerCharacterAbove128kTokens, + InputCostPerImageAbove128kTokens: pricing.InputCostPerImageAbove128kTokens, + InputCostPerVideoPerSecondAbove128kTokens: pricing.InputCostPerVideoPerSecondAbove128kTokens, + InputCostPerAudioPerSecondAbove128kTokens: pricing.InputCostPerAudioPerSecondAbove128kTokens, + OutputCostPerTokenAbove128kTokens: pricing.OutputCostPerTokenAbove128kTokens, + OutputCostPerCharacterAbove128kTokens: 
pricing.OutputCostPerCharacterAbove128kTokens, + CacheReadInputTokenCost: pricing.CacheReadInputTokenCost, + InputCostPerTokenBatches: pricing.InputCostPerTokenBatches, + OutputCostPerTokenBatches: pricing.OutputCostPerTokenBatches, + } +} + // getSafeFloat64 returns the value of a float64 pointer or fallback if nil func getSafeFloat64(ptr *float64, fallback float64) float64 { if ptr != nil { diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 1aaf27d6f..acd5e8878 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -463,6 +463,11 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { if err != nil { return nil, fmt.Errorf("failed to update framework config: %w", err) } + pricingConfig.PricingSyncCallback = func(pricingData map[string]schemas.DataSheetPricingEntry) { + if config.client != nil { + config.client.SetPricingData(pricingData) + } + } config.FrameworkConfig = &framework.FrameworkConfig{ Pricing: pricingConfig, } @@ -912,6 +917,11 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { syncDuration := time.Duration(*configData.FrameworkConfig.Pricing.PricingSyncInterval) * time.Second pricingConfig.PricingSyncInterval = &syncDuration } + pricingConfig.PricingSyncCallback = func(pricingData map[string]schemas.DataSheetPricingEntry) { + if config.client != nil { + config.client.SetPricingData(pricingData) + } + } // Updating framework config config.FrameworkConfig = &framework.FrameworkConfig{ Pricing: pricingConfig, diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 41d3df4e8..5e433e1f9 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -736,6 +736,8 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { } else { s.Config.PricingManager.AddModelDataToPool(modelData) } + // Add pricing data to the 
client + s.Client.SetPricingData(s.Config.PricingManager.GetPricingData()) logger.Info("models added to catalog") s.Config.SetBifrostClient(s.Client) // Initializing middlewares