Skip to content

Commit 7e6b09b

Browse files
feat: model pricing added to list models endpoint
1 parent d334a40 commit 7e6b09b

File tree

7 files changed

+231
-101
lines changed

7 files changed

+231
-101
lines changed

core/bifrost.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ type Bifrost struct {
5454
responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init
5555
pluginPipelinePool sync.Pool // Pool for PluginPipeline objects
5656
bifrostRequestPool sync.Pool // Pool for BifrostRequest objects
57+
pricingData sync.Map // pricing data for each model
5758
logger schemas.Logger // logger instance, default logger is used if not provided
5859
mcpManager *MCPManager // MCP integration manager (nil if MCP not configured)
5960
dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
@@ -98,6 +99,7 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
9899
plugins: atomic.Pointer[[]schemas.Plugin]{},
99100
requestQueues: sync.Map{},
100101
waitGroups: sync.Map{},
102+
pricingData: sync.Map{},
101103
keySelector: config.KeySelector,
102104
logger: config.Logger,
103105
}
@@ -287,6 +289,8 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
287289
}
288290
return nil, bifrostErr
289291
}
292+
// Add pricing data to the response
293+
response.AddPricing(bifrost.GetPricingDataForModel)
290294
return response, nil
291295
}
292296

@@ -369,6 +373,9 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr
369373
break
370374
}
371375

376+
// Add pricing data to the response
377+
response.AddPricing(bifrost.GetPricingDataForModel)
378+
372379
providerModels = append(providerModels, response.Data...)
373380

374381
// Check if there are more pages
@@ -822,6 +829,38 @@ func (bifrost *Bifrost) ReloadPlugin(plugin schemas.Plugin) error {
822829
}
823830
}
824831

832+
// SetPricingData sets pricing data for all the models.
833+
// This is used to set pricing data for all the models at once.
834+
//
835+
// Parameters:
836+
// - pricingData: A map of model names to pricing data
837+
func (bifrost *Bifrost) SetPricingData(pricingData map[string]schemas.DataSheetPricingEntry) {
838+
for model, pricing := range pricingData {
839+
bifrost.pricingData.Store(pricing.Provider+"/"+model, pricing)
840+
}
841+
}
842+
843+
// GetPricingDataForModel returns pricing data for a model.
844+
// This is used to get pricing data for a model.
845+
//
846+
// Parameters:
847+
// - model: The model to get pricing data for
848+
// - provider: The provider to get pricing data for
849+
//
850+
// Returns:
851+
// - pricing: The pricing data for the model, nil if not found
852+
func (bifrost *Bifrost) GetPricingDataForModel(model string, provider schemas.ModelProvider) *schemas.DataSheetPricingEntry {
853+
pricing, ok := bifrost.pricingData.Load(string(provider) + "/" + model)
854+
if !ok {
855+
return nil
856+
}
857+
if pricing, ok := pricing.(schemas.DataSheetPricingEntry); ok {
858+
return &pricing
859+
}
860+
return nil
861+
}
862+
863+
// GetConfiguredProviders returns a configured providers list.
825864
func (bifrost *Bifrost) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
826865
providers := bifrost.providers.Load()
827866
if providers == nil {

core/schemas/models.go

Lines changed: 121 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package schemas
33
import (
44
"encoding/base64"
55
"fmt"
6+
"strings"
67

78
"github.com/bytedance/sonic"
89
)
@@ -44,9 +45,95 @@ type BifrostListModelsResponse struct {
4445
HasMore *bool `json:"-"`
4546
}
4647

48+
// ApplyPagination applies offset-based pagination to a BifrostListModelsResponse.
49+
// Uses opaque tokens with LastID validation to ensure cursor integrity.
50+
// Returns the paginated response with properly set NextPageToken.
51+
func (response *BifrostListModelsResponse) ApplyPagination(pageSize int, pageToken string) *BifrostListModelsResponse {
52+
if response == nil {
53+
return nil
54+
}
55+
56+
totalItems := len(response.Data)
57+
58+
if pageSize <= 0 {
59+
return response
60+
}
61+
62+
cursor := decodePaginationCursor(pageToken)
63+
offset := cursor.Offset
64+
65+
// Validate cursor integrity if LastID is present
66+
if cursor.LastID != "" && !validatePaginationCursor(cursor, response.Data) {
67+
// Invalid cursor: reset to beginning
68+
offset = 0
69+
}
70+
71+
if offset >= totalItems {
72+
// Return empty page, no next token
73+
return &BifrostListModelsResponse{
74+
Data: []Model{},
75+
ExtraFields: response.ExtraFields,
76+
NextPageToken: "",
77+
}
78+
}
79+
80+
endIndex := offset + pageSize
81+
if endIndex > totalItems {
82+
endIndex = totalItems
83+
}
84+
85+
paginatedData := response.Data[offset:endIndex]
86+
87+
paginatedResponse := &BifrostListModelsResponse{
88+
Data: paginatedData,
89+
ExtraFields: response.ExtraFields,
90+
}
91+
92+
if endIndex < totalItems {
93+
// Get the last item ID for cursor validation
94+
var lastID string
95+
if len(paginatedData) > 0 {
96+
lastID = paginatedData[len(paginatedData)-1].ID
97+
}
98+
99+
nextToken, err := encodePaginationCursor(endIndex, lastID)
100+
if err == nil {
101+
paginatedResponse.NextPageToken = nextToken
102+
}
103+
} else {
104+
paginatedResponse.NextPageToken = ""
105+
}
106+
107+
return paginatedResponse
108+
}
109+
110+
type PricingFetcher func(model string, provider ModelProvider) *DataSheetPricingEntry
111+
112+
// AddPricing adds pricing data to the response.
113+
// This is used to add pricing data to the response.
114+
//
115+
// Parameters:
116+
// - fetcher: The pricing fetcher function
117+
//
118+
// Returns:
119+
// - response: The response with pricing data
120+
func (response *BifrostListModelsResponse) AddPricing(fetcher PricingFetcher) {
121+
for i, modelData := range response.Data {
122+
model := strings.TrimPrefix(modelData.ID, string(response.ExtraFields.Provider)+"/")
123+
pricing := fetcher(model, response.ExtraFields.Provider)
124+
if pricing != nil {
125+
if response.Data[i].Pricing == nil {
126+
response.Data[i].Pricing = &Pricing{}
127+
}
128+
response.Data[i].Pricing.DataSheetPricingEntry = pricing
129+
}
130+
}
131+
}
132+
47133
type Model struct {
48134
ID string `json:"id"`
49135
CanonicalSlug *string `json:"canonical_slug,omitempty"`
136+
DeploymentName *string `json:"deployment_name,omitempty"`
50137
Name *string `json:"name,omitempty"`
51138
Created *int64 `json:"created,omitempty"`
52139
ContextLength *int `json:"context_length,omitempty"`
@@ -82,6 +169,8 @@ type Pricing struct {
82169
InternalReasoning *string `json:"internal_reasoning,omitempty"`
83170
InputCacheRead *string `json:"input_cache_read,omitempty"`
84171
InputCacheWrite *string `json:"input_cache_write,omitempty"`
172+
173+
*DataSheetPricingEntry
85174
}
86175

87176
type TopProvider struct {
@@ -107,6 +196,38 @@ type paginationCursor struct {
107196
LastID string `json:"l,omitempty"`
108197
}
109198

199+
// DataSheetPricingEntry represents a single model's pricing information.
// The basic fields are always serialized; every optional price is a pointer
// and is omitted from the JSON output when nil.
200+
type DataSheetPricingEntry struct {
201+
// Basic pricing
202+
InputCostPerToken float64 `json:"input_cost_per_token"`
203+
OutputCostPerToken float64 `json:"output_cost_per_token"`
204+
Provider string `json:"provider"`
205+
Mode string `json:"mode"`
206+
207+
// Additional pricing for media
208+
InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"`
209+
InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"`
210+
InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"`
211+
212+
// Character-based pricing
213+
InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"`
214+
OutputCostPerCharacter *float64 `json:"output_cost_per_character,omitempty"`
215+
216+
// Pricing above 128k tokens
217+
InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"`
218+
InputCostPerCharacterAbove128kTokens *float64 `json:"input_cost_per_character_above_128k_tokens,omitempty"`
219+
InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"`
220+
InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"`
221+
InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"`
222+
OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"`
223+
OutputCostPerCharacterAbove128kTokens *float64 `json:"output_cost_per_character_above_128k_tokens,omitempty"`
224+
225+
// Cache and batch pricing
226+
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"`
227+
InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"`
228+
OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"`
229+
}
230+
110231
// encodePaginationCursor creates an opaque base64-encoded page token from cursor data.
111232
// Returns empty string if offset is 0 or negative.
112233
func encodePaginationCursor(offset int, lastID string) (string, error) {
@@ -172,65 +293,3 @@ func validatePaginationCursor(cursor paginationCursor, data []Model) bool {
172293

173294
return true
174295
}
175-
176-
// ApplyPagination applies offset-based pagination to a BifrostListModelsResponse.
177-
// Uses opaque tokens with LastID validation to ensure cursor integrity.
178-
// Returns the paginated response with properly set NextPageToken.
179-
func (response *BifrostListModelsResponse) ApplyPagination(pageSize int, pageToken string) *BifrostListModelsResponse {
180-
if response == nil {
181-
return nil
182-
}
183-
184-
totalItems := len(response.Data)
185-
186-
if pageSize <= 0 {
187-
return response
188-
}
189-
190-
cursor := decodePaginationCursor(pageToken)
191-
offset := cursor.Offset
192-
193-
// Validate cursor integrity if LastID is present
194-
if cursor.LastID != "" && !validatePaginationCursor(cursor, response.Data) {
195-
// Invalid cursor: reset to beginning
196-
offset = 0
197-
}
198-
199-
if offset >= totalItems {
200-
// Return empty page, no next token
201-
return &BifrostListModelsResponse{
202-
Data: []Model{},
203-
ExtraFields: response.ExtraFields,
204-
NextPageToken: "",
205-
}
206-
}
207-
208-
endIndex := offset + pageSize
209-
if endIndex > totalItems {
210-
endIndex = totalItems
211-
}
212-
213-
paginatedData := response.Data[offset:endIndex]
214-
215-
paginatedResponse := &BifrostListModelsResponse{
216-
Data: paginatedData,
217-
ExtraFields: response.ExtraFields,
218-
}
219-
220-
if endIndex < totalItems {
221-
// Get the last item ID for cursor validation
222-
var lastID string
223-
if len(paginatedData) > 0 {
224-
lastID = paginatedData[len(paginatedData)-1].ID
225-
}
226-
227-
nextToken, err := encodePaginationCursor(endIndex, lastID)
228-
if err == nil {
229-
paginatedResponse.NextPageToken = nextToken
230-
}
231-
} else {
232-
paginatedResponse.NextPageToken = ""
233-
}
234-
235-
return paginatedResponse
236-
}

framework/modelcatalog/main.go

Lines changed: 19 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ const (
2424

2525
// Config is the model pricing configuration.
2626
type Config struct {
27-
PricingURL *string `json:"pricing_url,omitempty"`
28-
PricingSyncInterval *time.Duration `json:"pricing_sync_interval,omitempty"`
27+
PricingURL *string `json:"pricing_url,omitempty"`
28+
PricingSyncInterval *time.Duration `json:"pricing_sync_interval,omitempty"`
29+
PricingSyncCallback func(map[string]schemas.DataSheetPricingEntry) `json:"pricing_sync_callback,omitempty"`
2930
}
3031

3132
type ModelCatalog struct {
@@ -49,41 +50,9 @@ type ModelCatalog struct {
4950
wg sync.WaitGroup
5051
syncCtx context.Context
5152
syncCancel context.CancelFunc
52-
}
5353

54-
// PricingData represents the structure of the pricing.json file
55-
type PricingData map[string]PricingEntry
56-
57-
// PricingEntry represents a single model's pricing information
58-
type PricingEntry struct {
59-
// Basic pricing
60-
InputCostPerToken float64 `json:"input_cost_per_token"`
61-
OutputCostPerToken float64 `json:"output_cost_per_token"`
62-
Provider string `json:"provider"`
63-
Mode string `json:"mode"`
64-
65-
// Additional pricing for media
66-
InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"`
67-
InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"`
68-
InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"`
69-
70-
// Character-based pricing
71-
InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"`
72-
OutputCostPerCharacter *float64 `json:"output_cost_per_character,omitempty"`
73-
74-
// Pricing above 128k tokens
75-
InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"`
76-
InputCostPerCharacterAbove128kTokens *float64 `json:"input_cost_per_character_above_128k_tokens,omitempty"`
77-
InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"`
78-
InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"`
79-
InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"`
80-
OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"`
81-
OutputCostPerCharacterAbove128kTokens *float64 `json:"output_cost_per_character_above_128k_tokens,omitempty"`
82-
83-
// Cache and batch pricing
84-
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"`
85-
InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"`
86-
OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"`
54+
// Callback after pricing data is synced
55+
pricingSyncCallback func(map[string]schemas.DataSheetPricingEntry)
8756
}
8857

8958
// Init initializes the pricing manager
@@ -105,6 +74,7 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto
10574
pricingData: make(map[string]configstoreTables.TableModelPricing),
10675
modelPool: make(map[schemas.ModelProvider][]string),
10776
done: make(chan struct{}),
77+
pricingSyncCallback: config.PricingSyncCallback,
10878
}
10979

11080
logger.Info("initializing pricing manager...")
@@ -189,6 +159,19 @@ func (mc *ModelCatalog) getPricingSyncInterval() time.Duration {
189159
return mc.pricingSyncInterval
190160
}
191161

162+
// GetPricingData returns the pricing data
163+
func (mc *ModelCatalog) GetPricingData() map[string]schemas.DataSheetPricingEntry {
164+
mc.mu.RLock()
165+
defer mc.mu.RUnlock()
166+
// Make a copy of the pricing data
167+
pricingData := make(map[string]schemas.DataSheetPricingEntry)
168+
for key, pricing := range mc.pricingData {
169+
model, _, _ := splitKey(key)
170+
pricingData[model] = convertTableModelPricingToPricingData(pricing)
171+
}
172+
return pricingData
173+
}
174+
192175
// GetModelsForProvider returns all available models for a given provider (thread-safe)
193176
func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string {
194177
mc.mu.RLock()

0 commit comments

Comments
 (0)