Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 223 additions & 0 deletions anthropic_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,235 @@ package aibridge
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"

"cdr.dev/slog"
"github.com/anthropics/anthropic-sdk-go"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)

// noopRecorder is a no-op implementation of Recorder for testing
type noopRecorder struct{}

func (n *noopRecorder) RecordInterception(ctx context.Context, req *InterceptionRecord) error {
return nil
}

func (n *noopRecorder) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error {
return nil
}

func (n *noopRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) error {
return nil
}

func (n *noopRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) error {
return nil
}

func (n *noopRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) error {
return nil
}

func TestCustomHeadersIntegration(t *testing.T) {
t.Parallel()

tests := []struct {
name string
customHeaders map[string]string
expectedInReq map[string]string
}{
{
name: "single custom header",
customHeaders: map[string]string{
"X-Custom-Header": "test-value",
},
expectedInReq: map[string]string{
"X-Custom-Header": "test-value",
},
},
{
name: "multiple custom headers",
customHeaders: map[string]string{
"X-Custom-Header-1": "value1",
"X-Custom-Header-2": "value2",
},
expectedInReq: map[string]string{
"X-Custom-Header-1": "value1",
"X-Custom-Header-2": "value2",
},
},
{
name: "no custom headers",
customHeaders: nil,
expectedInReq: map[string]string{},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

// Track which headers were received
receivedHeaders := make(map[string]string)
var headerMu sync.Mutex

// Create a mock server that captures headers
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headerMu.Lock()
defer headerMu.Unlock()

// Capture custom headers
for key := range tt.expectedInReq {
if val := r.Header.Get(key); val != "" {
receivedHeaders[key] = val
}
}

// Return a minimal valid response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
// Return a simple JSON response that matches Anthropic's Message structure
_, _ = w.Write([]byte(`{
"id": "msg_123",
"type": "message",
"role": "assistant",
"content": [{"type": "text", "text": "test"}],
"model": "claude-3-5-sonnet-20241022",
"usage": {"input_tokens": 10, "output_tokens": 20}
}`))
}))
defer srv.Close()

// Create config with custom headers
cfg := AnthropicConfig{
ProviderConfig: ProviderConfig{
BaseURL: srv.URL,
Key: "test-key",
},
CustomHeaders: tt.customHeaders,
}

// Create a simple message request
req := &MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Model: "claude-3-5-sonnet-20241022",
MaxTokens: 100,
Messages: []anthropic.MessageParam{
{
Role: anthropic.MessageParamRoleUser,
Content: []anthropic.ContentBlockParamUnion{
anthropic.NewTextBlock("test"),
},
},
},
},
}

interception := NewAnthropicMessagesBlockingInterception(uuid.New(), req, cfg, nil)

// Make request
w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/anthropic/v1/messages", nil)

logger := slog.Make()
// Use a no-op recorder to avoid nil pointer issues
recorder := &noopRecorder{}
interception.Setup(logger, recorder, nil)

err := interception.ProcessRequest(w, r)
require.NoError(t, err)

// Verify headers were sent
headerMu.Lock()
defer headerMu.Unlock()
require.Equal(t, tt.expectedInReq, receivedHeaders)
})
}
}

func TestParseCustomHeaders(t *testing.T) {
t.Parallel()

tests := []struct {
name string
input string
expected map[string]string
}{
{
name: "empty string",
input: "",
expected: nil,
},
{
name: "single header",
input: "X-Custom-Header: value123",
expected: map[string]string{"X-Custom-Header": "value123"},
},
{
name: "multiple headers",
input: "X-Custom-Header: value1\nX-Another-Header: value2",
expected: map[string]string{"X-Custom-Header": "value1", "X-Another-Header": "value2"},
},
{
name: "header with URL value",
input: "X-Callback-URL: https://example.com:8080/path",
expected: map[string]string{"X-Callback-URL": "https://example.com:8080/path"},
},
{
name: "header with leading/trailing spaces",
input: " X-Custom-Header : value with spaces ",
expected: map[string]string{"X-Custom-Header": "value with spaces"},
},
{
name: "empty lines ignored",
input: "X-Header-1: value1\n\n\nX-Header-2: value2\n",
expected: map[string]string{"X-Header-1": "value1", "X-Header-2": "value2"},
},
{
name: "malformed header without colon",
input: "InvalidHeader\nX-Valid-Header: value",
expected: map[string]string{"X-Valid-Header": "value"},
},
{
name: "header with empty key ignored",
input: ": value\nX-Valid-Header: value",
expected: map[string]string{"X-Valid-Header": "value"},
},
{
name: "header with empty value allowed",
input: "X-Empty-Header:",
expected: map[string]string{"X-Empty-Header": ""},
},
{
name: "multiple colons in value",
input: "X-JSON: {\"key\":\"value\"}",
expected: map[string]string{"X-JSON": "{\"key\":\"value\"}"},
},
{
name: "all malformed headers returns nil",
input: "NoColonHere\nAnotherInvalid",
expected: nil,
},
{
name: "mixed valid and invalid headers",
input: "X-Valid: value1\nInvalidLine\nX-Another: value2",
expected: map[string]string{"X-Valid": "value1", "X-Another": "value2"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseCustomHeaders(tt.input)
require.Equal(t, tt.expected, result)
})
}
}

func TestConvertStringContentToArray(t *testing.T) {
t.Parallel()

Expand Down
14 changes: 8 additions & 6 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ func TestSimple(t *testing.T) {
fixture: oaiSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, mcp.NewServerProxyManager(nil))
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil))
},
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
if streaming {
Expand Down Expand Up @@ -650,7 +650,7 @@ func TestFallthrough(t *testing.T) {
fixture: oaiFallthrough,
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))
provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
require.NoError(t, err)
return provider, bridge
Expand Down Expand Up @@ -838,7 +838,7 @@ func TestOpenAIInjectedTools(t *testing.T) {

configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
}

// Build the requirements & make the assertions which are common to all providers.
Expand Down Expand Up @@ -1047,7 +1047,7 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
},
responseHandlerFn: func(streaming bool, resp *http.Response) {
if streaming {
Expand Down Expand Up @@ -1359,7 +1359,9 @@ func openaiCfg(url, key string) aibridge.OpenAIConfig {

func anthropicCfg(url, key string) aibridge.AnthropicConfig {
return aibridge.AnthropicConfig{
BaseURL: url,
Key: key,
ProviderConfig: aibridge.ProviderConfig{
BaseURL: url,
Key: key,
},
}
}
10 changes: 6 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ type ProviderConfig struct {
BaseURL, Key string
}

type (
OpenAIConfig ProviderConfig
AnthropicConfig ProviderConfig
)
type OpenAIConfig ProviderConfig

type AnthropicConfig struct {
ProviderConfig
CustomHeaders map[string]string
}

type AWSBedrockConfig struct {
Region string
Expand Down
5 changes: 5 additions & 0 deletions intercept_anthropic_messages_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ func (i *AnthropicMessagesInterceptionBase) newAnthropicClient(ctx context.Conte
opts = append(opts, option.WithAPIKey(i.cfg.Key))
opts = append(opts, option.WithBaseURL(i.cfg.BaseURL))

// Inject custom headers if configured
for key, value := range i.cfg.CustomHeaders {
opts = append(opts, option.WithHeader(key, value))
}

if i.bedrockCfg != nil {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
Expand Down
43 changes: 43 additions & 0 deletions provider_anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"os"
"strings"

"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/shared"
Expand All @@ -28,13 +29,55 @@ const (
routeMessages = "/anthropic/v1/messages" // https://docs.anthropic.com/en/api/messages
)

// parseCustomHeaders parses the ANTHROPIC_CUSTOM_HEADERS environment variable.
// The format is "Name: Value" with one header per line.
// Multiple lines can be separated by newlines (\n).
func parseCustomHeaders(envValue string) map[string]string {
if envValue == "" {
return nil
}

headers := make(map[string]string)
lines := strings.Split(envValue, "\n")

for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}

// Split on the first colon to handle values that contain colons (e.g., URLs)
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
// Skip malformed headers
continue
}

key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])

if key != "" {
headers[key] = value
}
}

if len(headers) == 0 {
return nil
}

return headers
}

func NewAnthropicProvider(cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicProvider {
if cfg.BaseURL == "" {
cfg.BaseURL = "https://api.anthropic.com/"
}
if cfg.Key == "" {
cfg.Key = os.Getenv("ANTHROPIC_API_KEY")
}
if cfg.CustomHeaders == nil {
cfg.CustomHeaders = parseCustomHeaders(os.Getenv("ANTHROPIC_CUSTOM_HEADERS"))
}

return &AnthropicProvider{
cfg: cfg,
Expand Down