@@ -46,6 +46,7 @@ type Bifrost struct {
 	logger             schemas.Logger      // logger instance; the default logger is used if none is provided
 	mcpManager         *MCPManager         // MCP integration manager (nil if MCP is not configured)
 	dropExcessRequests atomic.Bool         // if true, requests are dropped immediately when the queue is full instead of waiting for space
+	keySelector        schemas.KeySelector // custom key selector function (WeightedRandomKeySelector if nil)
 }
 
 // PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation.
@@ -86,10 +87,15 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
 		plugins:       atomic.Pointer[[]schemas.Plugin]{},
 		requestQueues: sync.Map{},
 		waitGroups:    sync.Map{},
+		keySelector:   config.KeySelector,
 	}
 	bifrost.plugins.Store(&config.Plugins)
 	bifrost.dropExcessRequests.Store(config.DropExcessRequests)
 
+	if bifrost.keySelector == nil {
+		bifrost.keySelector = WeightedRandomKeySelector
+	}
+
 	// Initialize object pools
 	bifrost.channelMessagePool = sync.Pool{
 		New: func() interface{} {
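With this change a caller can supply its own key-selection strategy through BifrostConfig.KeySelector; when the field is nil, Init falls back to WeightedRandomKeySelector. Below is a minimal sketch of a custom selector. The import paths, the Account field on BifrostConfig, and the round-robin strategy itself are assumptions for illustration; only the selector signature is taken from the diff (it mirrors WeightedRandomKeySelector further down).

// Sketch: a custom round-robin key selector plugged into Init.
// Assumed, not shown in this diff: the import paths below and the
// Account field on BifrostConfig.
package main

import (
	"context"
	"fmt"
	"sync/atomic"

	bifrost "github.com/maximhq/bifrost/core"
	schemas "github.com/maximhq/bifrost/core/schemas"
)

var rrCounter atomic.Uint64 // shared round-robin position

// roundRobinSelector cycles through the supported keys in order instead of
// picking by weight. The signature mirrors WeightedRandomKeySelector.
func roundRobinSelector(ctx *context.Context, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
	if len(keys) == 0 {
		return schemas.Key{}, fmt.Errorf("no keys available for provider %s", providerKey)
	}
	return keys[rrCounter.Add(1)%uint64(len(keys))], nil
}

func newClient(ctx context.Context, account schemas.Account) (*bifrost.Bifrost, error) {
	return bifrost.Init(ctx, schemas.BifrostConfig{
		Account:     account,
		KeySelector: roundRobinSelector, // nil falls back to WeightedRandomKeySelector
	})
}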
@@ -626,12 +632,12 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
 		return bifrost.prepareProvider(providerKey, providerConfig)
 	}
 
-	oldQueue := oldQueueValue.(chan ChannelMessage)
+	oldQueue := oldQueueValue.(chan *ChannelMessage)
 
 	bifrost.logger.Debug("gracefully stopping existing workers for provider %s", providerKey)
 
 	// Step 1: Create new queue with updated buffer size
-	newQueue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
+	newQueue := make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
 
 	// Step 2: Transfer any buffered requests from old queue to new queue.
 	// This prevents request loss during the transition.
@@ -647,7 +653,7 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
 			// New queue is full, handle this request in a goroutine.
 			// This is unlikely with proper buffer sizing but provides safety.
 			transferWaitGroup.Add(1)
-			go func(m ChannelMessage) {
+			go func(m *ChannelMessage) {
 				defer transferWaitGroup.Done()
 				select {
 				case newQueue <- m:
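The transfer logic above is a standard non-blocking channel drain: pull everything buffered in the old queue, attempt a non-blocking send into the new one, and fall back to a goroutine tracked by a WaitGroup when the new buffer is momentarily full. The same pattern in isolation (a sketch using generics, so Go 1.18+; Bifrost's actual loop differs in details not shown in these hunks):

// drainInto moves every message buffered in old into next without dropping
// any. Sends that fit in next's buffer happen inline; overflow sends block
// in goroutines tracked by wg so the caller can wait for the transfer.
func drainInto[T any](old, next chan T, wg *sync.WaitGroup) {
	for {
		select {
		case m, ok := <-old:
			if !ok {
				return // old queue closed and fully drained
			}
			select {
			case next <- m: // fits in the new buffer
			default: // buffer full: hand off so draining never blocks
				wg.Add(1)
				go func(m T) {
					defer wg.Done()
					next <- m
				}(m)
			}
		default:
			return // nothing left buffered in old
		}
	}
}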
@@ -1011,7 +1017,7 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
 		return fmt.Errorf("failed to get config for provider: %v", err)
 	}
 
-	queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
+	queue := make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
 
 	bifrost.requestQueues.Store(providerKey, queue)
@@ -1038,13 +1044,13 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
 // If the queue doesn't exist, it creates one at runtime and initializes the provider,
 // provided the provider config is supplied by the account interface implementation.
 // This function uses read locks to prevent race conditions during provider updates.
-func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) {
+func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan *ChannelMessage, error) {
 	// Use a read lock to allow concurrent reads but prevent concurrent updates
 	providerMutex := bifrost.getProviderMutex(providerKey)
 	providerMutex.RLock()
 
 	if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
-		queue := queueValue.(chan ChannelMessage)
+		queue := queueValue.(chan *ChannelMessage)
 		providerMutex.RUnlock()
 		return queue, nil
 	}
@@ -1057,7 +1063,7 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
 
 	// Double-check after acquiring write lock (another goroutine might have created it)
 	if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
-		queue := queueValue.(chan ChannelMessage)
+		queue := queueValue.(chan *ChannelMessage)
 		return queue, nil
 	}
@@ -1073,7 +1079,7 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
 	}
 
 	queueValue, _ := bifrost.requestQueues.Load(providerKey)
-	queue := queueValue.(chan ChannelMessage)
+	queue := queueValue.(chan *ChannelMessage)
 
 	return queue, nil
 }
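getProviderQueue follows the classic double-checked locking idiom: a fast path under a read lock, then a second Load under the write lock so a queue created by a competing goroutine is reused rather than clobbered. Reduced to its essentials (a sketch; a plain map and RWMutex stand in for the sync.Map and per-provider mutexes used above):

// registry is a minimal stand-in for Bifrost's queue bookkeeping.
type registry struct {
	mu     sync.RWMutex
	queues map[string]chan *ChannelMessage
}

// getOrCreate returns the queue for key, creating it at most once.
func (r *registry) getOrCreate(key string, bufferSize int) chan *ChannelMessage {
	// Fast path: read lock only, so concurrent lookups don't serialize.
	r.mu.RLock()
	if q, ok := r.queues[key]; ok {
		r.mu.RUnlock()
		return q
	}
	r.mu.RUnlock()

	// Slow path: take the write lock and re-check, because another
	// goroutine may have created the queue between the two lock holds.
	r.mu.Lock()
	defer r.mu.Unlock()
	if q, ok := r.queues[key]; ok {
		return q
	}
	q := make(chan *ChannelMessage, bufferSize)
	r.queues[key] = q
	return q
}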
@@ -1335,9 +1341,8 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
 
 	msg := bifrost.getChannelMessage(*preReq)
 	msg.Context = ctx
-	startTime := time.Now()
 	select {
-	case queue <- *msg:
+	case queue <- msg:
 		// Message was sent successfully
 	case <-ctx.Done():
 		bifrost.releaseChannelMessage(msg)
@@ -1349,7 +1354,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
 		return nil, newBifrostErrorFromMsg("request dropped: queue is full")
 	}
 	select {
-	case queue <- *msg:
+	case queue <- msg:
 		// Message was sent successfully
 	case <-ctx.Done():
 		bifrost.releaseChannelMessage(msg)
@@ -1362,11 +1367,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
 	pluginCount := len(*bifrost.plugins.Load())
 	select {
 	case result = <-msg.Response:
-		latency := time.Since(startTime).Milliseconds()
-		if result.ExtraFields.Latency == nil {
-			result.ExtraFields.Latency = Ptr(float64(latency))
-		}
-		resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, pluginCount)
+		resp, bifrostErr := pipeline.RunPostHooks(&msg.Context, result, nil, pluginCount)
 		if bifrostErr != nil {
 			bifrost.releaseChannelMessage(msg)
 			return nil, bifrostErr
@@ -1375,7 +1376,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
 		return resp, nil
 	case bifrostErrVal := <-msg.Err:
 		bifrostErrPtr := &bifrostErrVal
-		resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, pluginCount)
+		resp, bifrostErrPtr = pipeline.RunPostHooks(&msg.Context, nil, bifrostErrPtr, pluginCount)
 		bifrost.releaseChannelMessage(msg)
 		if bifrostErrPtr != nil {
 			return nil, bifrostErrPtr
@@ -1457,7 +1458,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
 	msg.Context = ctx
 
 	select {
-	case queue <- *msg:
+	case queue <- msg:
 		// Message was sent successfully
 	case <-ctx.Done():
 		bifrost.releaseChannelMessage(msg)
@@ -1469,7 +1470,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
 		return nil, newBifrostErrorFromMsg("request dropped: queue is full")
 	}
 	select {
-	case queue <- *msg:
+	case queue <- msg:
 		// Message was sent successfully
 	case <-ctx.Done():
 		bifrost.releaseChannelMessage(msg)
@@ -1500,7 +1501,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
 
 // requestWorker handles incoming requests from the queue for a specific provider.
 // It manages retries, error handling, and response processing.
-func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan ChannelMessage) {
+func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan *ChannelMessage) {
 	defer func() {
 		if waitGroupValue, ok := bifrost.waitGroups.Load(provider.GetProviderKey()); ok {
 			waitGroup := waitGroupValue.(*sync.WaitGroup)
@@ -1535,6 +1536,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 				}
 				continue
 			}
+			req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKey, key.ID)
 		}
 
 		// Track attempts
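With the added line above, the worker stamps the chosen key's ID into the request context under schemas.BifrostContextKeySelectedKey, so post-hooks and other downstream code can tell which key actually served a request. A sketch of reading it back; the hook shape and the key ID being a string are assumptions, since neither is defined in this diff:

// logSelectedKey reports which API key served the request, if the worker
// recorded one. The string assertion assumes key.ID is a string.
func logSelectedKey(ctx context.Context, logger schemas.Logger) {
	if id, ok := ctx.Value(schemas.BifrostContextKeySelectedKey).(string); ok {
		logger.Debug("request served with key %s", id)
	}
}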
@@ -1570,12 +1572,12 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 
 		// Attempt the request
 		if IsStreamRequestType(req.RequestType) {
-			stream, bifrostError = handleProviderStreamRequest(provider, &req, key, postHookRunner)
+			stream, bifrostError = handleProviderStreamRequest(provider, req, key, postHookRunner)
 			if bifrostError != nil && !bifrostError.IsBifrostError {
 				break // Don't retry client errors
 			}
 		} else {
-			result, bifrostError = handleProviderRequest(provider, &req, key)
+			result, bifrostError = handleProviderRequest(provider, req, key)
 			if bifrostError != nil {
 				break // Don't retry client errors
 			}
@@ -1924,9 +1926,19 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
 		return supportedKeys[0], nil
 	}
 
+	selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model)
+	if err != nil {
+		return schemas.Key{}, err
+	}
+
+	return selectedKey, nil
+}
+
+// WeightedRandomKeySelector picks a key at random, with probability proportional to each key's weight.
+func WeightedRandomKeySelector(ctx *context.Context, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
 	// Use a weighted random selection based on key weights
 	totalWeight := 0
-	for _, key := range supportedKeys {
+	for _, key := range keys {
 		totalWeight += int(key.Weight * 100) // Convert float to int for better performance
 	}
@@ -1936,15 +1948,15 @@
 
 	// Select key based on weight
 	currentWeight := 0
-	for _, key := range supportedKeys {
+	for _, key := range keys {
 		currentWeight += int(key.Weight * 100)
 		if randomValue < currentWeight {
 			return key, nil
 		}
 	}
 
 	// Fallback to first key if something goes wrong
-	return supportedKeys[0], nil
+	return keys[0], nil
 }
 
 // Shutdown gracefully stops all workers when triggered.
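To make the weight math concrete: with keys weighted 0.5, 0.3, and 0.2, totalWeight is 100, and a randomValue drawn in [0, totalWeight) (computed in the lines elided between the two hunks) selects the first key for 0-49, the second for 50-79, and the third for 80-99, so selection frequency tracks the configured weights; the *100 scaling effectively quantizes weights to two decimal places. A rough distribution check (a sketch; it assumes schemas.Key exposes the ID and Weight fields used above, that ID is a string, and that ModelProvider is a string type, with placeholder provider/model values):

func main() {
	ctx := context.Background()
	keys := []schemas.Key{
		{ID: "a", Weight: 0.5},
		{ID: "b", Weight: 0.3},
		{ID: "c", Weight: 0.2},
	}
	counts := map[string]int{}
	for i := 0; i < 100_000; i++ {
		// The default selector ignores providerKey and model, so
		// placeholder values are fine here.
		k, _ := WeightedRandomKeySelector(&ctx, keys, "openai", "some-model")
		counts[k.ID]++
	}
	fmt.Println(counts) // expect roughly a:50000, b:30000, c:20000
}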
@@ -1954,7 +1966,7 @@ func (bifrost *Bifrost) Shutdown() {
 
 	// Close all provider queues to signal workers to stop
 	bifrost.requestQueues.Range(func(key, value interface{}) bool {
-		close(value.(chan ChannelMessage))
+		close(value.(chan *ChannelMessage))
 		return true
 	})
 