Skip to content

Commit d80798a

Browse files
feat: core extended
1 parent c90ec02 commit d80798a

File tree

15 files changed

+167
-127
lines changed

15 files changed

+167
-127
lines changed

core/bifrost.go

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type Bifrost struct {
4646
logger schemas.Logger // logger instance, default logger is used if not provided
4747
mcpManager *MCPManager // MCP integration manager (nil if MCP not configured)
4848
dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
49+
keySelector schemas.KeySelector // Custom key selector function
4950
}
5051

5152
// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation.
@@ -86,10 +87,15 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
8687
plugins: atomic.Pointer[[]schemas.Plugin]{},
8788
requestQueues: sync.Map{},
8889
waitGroups: sync.Map{},
90+
keySelector: config.KeySelector,
8991
}
9092
bifrost.plugins.Store(&config.Plugins)
9193
bifrost.dropExcessRequests.Store(config.DropExcessRequests)
9294

95+
if bifrost.keySelector == nil {
96+
bifrost.keySelector = WeightedRandomKeySelector
97+
}
98+
9399
// Initialize object pools
94100
bifrost.channelMessagePool = sync.Pool{
95101
New: func() interface{} {
@@ -626,12 +632,12 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
626632
return bifrost.prepareProvider(providerKey, providerConfig)
627633
}
628634

629-
oldQueue := oldQueueValue.(chan ChannelMessage)
635+
oldQueue := oldQueueValue.(chan *ChannelMessage)
630636

631637
bifrost.logger.Debug("gracefully stopping existing workers for provider %s", providerKey)
632638

633639
// Step 1: Create new queue with updated buffer size
634-
newQueue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
640+
newQueue := make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
635641

636642
// Step 2: Transfer any buffered requests from old queue to new queue
637643
// This prevents request loss during the transition
@@ -647,7 +653,7 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
647653
// New queue is full, handle this request in a goroutine
648654
// This is unlikely with proper buffer sizing but provides safety
649655
transferWaitGroup.Add(1)
650-
go func(m ChannelMessage) {
656+
go func(m *ChannelMessage) {
651657
defer transferWaitGroup.Done()
652658
select {
653659
case newQueue <- m:
@@ -1011,7 +1017,7 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
10111017
return fmt.Errorf("failed to get config for provider: %v", err)
10121018
}
10131019

1014-
queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
1020+
queue := make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
10151021

10161022
bifrost.requestQueues.Store(providerKey, queue)
10171023

@@ -1038,13 +1044,13 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
10381044
// If the queue doesn't exist, it creates one at runtime and initializes the provider,
10391045
// given the provider config is provided in the account interface implementation.
10401046
// This function uses read locks to prevent race conditions during provider updates.
1041-
func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) {
1047+
func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan *ChannelMessage, error) {
10421048
// Use read lock to allow concurrent reads but prevent concurrent updates
10431049
providerMutex := bifrost.getProviderMutex(providerKey)
10441050
providerMutex.RLock()
10451051

10461052
if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
1047-
queue := queueValue.(chan ChannelMessage)
1053+
queue := queueValue.(chan *ChannelMessage)
10481054
providerMutex.RUnlock()
10491055
return queue, nil
10501056
}
@@ -1057,7 +1063,7 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
10571063

10581064
// Double-check after acquiring write lock (another goroutine might have created it)
10591065
if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
1060-
queue := queueValue.(chan ChannelMessage)
1066+
queue := queueValue.(chan *ChannelMessage)
10611067
return queue, nil
10621068
}
10631069

@@ -1073,7 +1079,7 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
10731079
}
10741080

10751081
queueValue, _ := bifrost.requestQueues.Load(providerKey)
1076-
queue := queueValue.(chan ChannelMessage)
1082+
queue := queueValue.(chan *ChannelMessage)
10771083

10781084
return queue, nil
10791085
}
@@ -1335,9 +1341,8 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
13351341

13361342
msg := bifrost.getChannelMessage(*preReq)
13371343
msg.Context = ctx
1338-
startTime := time.Now()
13391344
select {
1340-
case queue <- *msg:
1345+
case queue <- msg:
13411346
// Message was sent successfully
13421347
case <-ctx.Done():
13431348
bifrost.releaseChannelMessage(msg)
@@ -1349,7 +1354,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
13491354
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
13501355
}
13511356
select {
1352-
case queue <- *msg:
1357+
case queue <- msg:
13531358
// Message was sent successfully
13541359
case <-ctx.Done():
13551360
bifrost.releaseChannelMessage(msg)
@@ -1362,11 +1367,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
13621367
pluginCount := len(*bifrost.plugins.Load())
13631368
select {
13641369
case result = <-msg.Response:
1365-
latency := time.Since(startTime).Milliseconds()
1366-
if result.ExtraFields.Latency == nil {
1367-
result.ExtraFields.Latency = Ptr(float64(latency))
1368-
}
1369-
resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, pluginCount)
1370+
resp, bifrostErr := pipeline.RunPostHooks(&msg.Context, result, nil, pluginCount)
13701371
if bifrostErr != nil {
13711372
bifrost.releaseChannelMessage(msg)
13721373
return nil, bifrostErr
@@ -1375,7 +1376,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
13751376
return resp, nil
13761377
case bifrostErrVal := <-msg.Err:
13771378
bifrostErrPtr := &bifrostErrVal
1378-
resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, pluginCount)
1379+
resp, bifrostErrPtr = pipeline.RunPostHooks(&msg.Context, nil, bifrostErrPtr, pluginCount)
13791380
bifrost.releaseChannelMessage(msg)
13801381
if bifrostErrPtr != nil {
13811382
return nil, bifrostErrPtr
@@ -1457,7 +1458,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
14571458
msg.Context = ctx
14581459

14591460
select {
1460-
case queue <- *msg:
1461+
case queue <- msg:
14611462
// Message was sent successfully
14621463
case <-ctx.Done():
14631464
bifrost.releaseChannelMessage(msg)
@@ -1469,7 +1470,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
14691470
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
14701471
}
14711472
select {
1472-
case queue <- *msg:
1473+
case queue <- msg:
14731474
// Message was sent successfully
14741475
case <-ctx.Done():
14751476
bifrost.releaseChannelMessage(msg)
@@ -1500,7 +1501,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex
15001501

15011502
// requestWorker handles incoming requests from the queue for a specific provider.
15021503
// It manages retries, error handling, and response processing.
1503-
func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan ChannelMessage) {
1504+
func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, queue chan *ChannelMessage) {
15041505
defer func() {
15051506
if waitGroupValue, ok := bifrost.waitGroups.Load(provider.GetProviderKey()); ok {
15061507
waitGroup := waitGroupValue.(*sync.WaitGroup)
@@ -1535,6 +1536,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
15351536
}
15361537
continue
15371538
}
1539+
req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKey, key.ID)
15381540
}
15391541

15401542
// Track attempts
@@ -1570,12 +1572,12 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
15701572

15711573
// Attempt the request
15721574
if IsStreamRequestType(req.RequestType) {
1573-
stream, bifrostError = handleProviderStreamRequest(provider, &req, key, postHookRunner)
1575+
stream, bifrostError = handleProviderStreamRequest(provider, req, key, postHookRunner)
15741576
if bifrostError != nil && !bifrostError.IsBifrostError {
15751577
break // Don't retry client errors
15761578
}
15771579
} else {
1578-
result, bifrostError = handleProviderRequest(provider, &req, key)
1580+
result, bifrostError = handleProviderRequest(provider, req, key)
15791581
if bifrostError != nil {
15801582
break // Don't retry client errors
15811583
}
@@ -1924,9 +1926,19 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
19241926
return supportedKeys[0], nil
19251927
}
19261928

1929+
selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model)
1930+
if err != nil {
1931+
return schemas.Key{}, err
1932+
}
1933+
1934+
return selectedKey, nil
1935+
1936+
}
1937+
1938+
func WeightedRandomKeySelector(ctx *context.Context, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
19271939
// Use a weighted random selection based on key weights
19281940
totalWeight := 0
1929-
for _, key := range supportedKeys {
1941+
for _, key := range keys {
19301942
totalWeight += int(key.Weight * 100) // Convert float to int for better performance
19311943
}
19321944

@@ -1936,15 +1948,15 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
19361948

19371949
// Select key based on weight
19381950
currentWeight := 0
1939-
for _, key := range supportedKeys {
1951+
for _, key := range keys {
19401952
currentWeight += int(key.Weight * 100)
19411953
if randomValue < currentWeight {
19421954
return key, nil
19431955
}
19441956
}
19451957

19461958
// Fallback to first key if something goes wrong
1947-
return supportedKeys[0], nil
1959+
return keys[0], nil
19481960
}
19491961

19501962
// Shutdown gracefully stops all workers when triggered.
@@ -1954,7 +1966,7 @@ func (bifrost *Bifrost) Shutdown() {
19541966

19551967
// Close all provider queues to signal workers to stop
19561968
bifrost.requestQueues.Range(func(key, value interface{}) bool {
1957-
close(value.(chan ChannelMessage))
1969+
close(value.(chan *ChannelMessage))
19581970
return true
19591971
})
19601972

core/providers/anthropic.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,11 @@ func (provider *AnthropicProvider) GetProviderKey() schemas.ModelProvider {
124124
// completeRequest sends a request to Anthropic's API and handles the response.
125125
// It constructs the API URL, sets up authentication, and processes the response.
126126
// Returns the response body or an error if the request fails.
127-
func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody interface{}, url string, key string) ([]byte, *schemas.BifrostError) {
127+
func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody interface{}, url string, key string) ([]byte, time.Duration, *schemas.BifrostError) {
128128
// Marshal the request body
129129
jsonData, err := sonic.Marshal(requestBody)
130130
if err != nil {
131-
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, provider.GetProviderKey())
131+
return nil, 0, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, provider.GetProviderKey())
132132
}
133133

134134
// Create the request with the JSON body
@@ -149,9 +149,9 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB
149149
req.SetBody(jsonData)
150150

151151
// Send the request
152-
bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp)
152+
latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp)
153153
if bifrostErr != nil {
154-
return nil, bifrostErr
154+
return nil, latency, bifrostErr
155155
}
156156

157157
// Handle error response
@@ -164,13 +164,13 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB
164164
bifrostErr.Error.Type = &errorResp.Error.Type
165165
bifrostErr.Error.Message = errorResp.Error.Message
166166

167-
return nil, bifrostErr
167+
return nil, latency, bifrostErr
168168
}
169169

170170
// Read the response body
171171
body := resp.Body()
172172

173-
return body, nil
173+
return body, latency, nil
174174
}
175175

176176
// TextCompletion performs a text completion request to Anthropic's API.
@@ -188,7 +188,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, key schem
188188
}
189189

190190
// Use struct directly for JSON marshaling
191-
responseBody, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/complete", key.Value)
191+
responseBody, latency, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/complete", key.Value)
192192
if err != nil {
193193
return nil, err
194194
}
@@ -208,6 +208,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, key schem
208208
bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
209209
bifrostResponse.ExtraFields.ModelRequested = request.Model
210210
bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest
211+
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
211212

212213
// Set raw response if enabled
213214
if provider.sendBackRawResponse {
@@ -239,7 +240,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schem
239240
}
240241

241242
// Use struct directly for JSON marshaling
242-
responseBody, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value)
243+
responseBody, latency, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value)
243244
if err != nil {
244245
return nil, err
245246
}
@@ -260,6 +261,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schem
260261
bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
261262
bifrostResponse.ExtraFields.ModelRequested = request.Model
262263
bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest
264+
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
263265

264266
// Set raw response if enabled
265267
if provider.sendBackRawResponse {
@@ -284,7 +286,7 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke
284286
}
285287

286288
// Use struct directly for JSON marshaling
287-
responseBody, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value)
289+
responseBody, latency, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value)
288290
if err != nil {
289291
return nil, err
290292
}
@@ -305,6 +307,7 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke
305307
bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
306308
bifrostResponse.ExtraFields.ModelRequested = request.Model
307309
bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest
310+
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
308311

309312
// Set raw response if enabled
310313
if provider.sendBackRawResponse {

0 commit comments

Comments
 (0)