From 41ed9c46cef794bb54eb10b5750645cfef723cee Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Mon, 10 Nov 2025 20:27:55 +0000 Subject: [PATCH] feat: add support for ANTHROPIC_CUSTOM_HEADERS environment variable This change allows users to inject custom HTTP headers into all Anthropic API requests via the ANTHROPIC_CUSTOM_HEADERS environment variable. The format is 'Name: Value' with multiple headers separated by newlines. Changes: - Added CustomHeaders field to AnthropicConfig - Implemented parseCustomHeaders() to parse the environment variable - Modified newAnthropicClient() to inject headers using option.WithHeader() - Added comprehensive tests (15 test cases) for parsing and integration - Fixed test helpers for new config structure The implementation handles edge cases like URLs with colons, whitespace, empty lines, and malformed headers gracefully. All tests pass including race detection. --- anthropic_internal_test.go | 223 +++++++++++++++++++++++++++ bridge_integration_test.go | 14 +- config.go | 10 +- intercept_anthropic_messages_base.go | 5 + provider_anthropic.go | 43 ++++++ 5 files changed, 285 insertions(+), 10 deletions(-) diff --git a/anthropic_internal_test.go b/anthropic_internal_test.go index 6efc4f1..77c88a9 100644 --- a/anthropic_internal_test.go +++ b/anthropic_internal_test.go @@ -3,12 +3,235 @@ package aibridge import ( "context" "encoding/json" + "net/http" + "net/http/httptest" + "sync" "testing" + "cdr.dev/slog" "github.com/anthropics/anthropic-sdk-go" + "github.com/google/uuid" "github.com/stretchr/testify/require" ) +// noopRecorder is a no-op implementation of Recorder for testing +type noopRecorder struct{} + +func (n *noopRecorder) RecordInterception(ctx context.Context, req *InterceptionRecord) error { + return nil +} + +func (n *noopRecorder) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error { + return nil +} + +func (n *noopRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) error { + return nil +} + +func (n *noopRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) error { + return nil +} + +func (n *noopRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) error { + return nil +} + +func TestCustomHeadersIntegration(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + customHeaders map[string]string + expectedInReq map[string]string + }{ + { + name: "single custom header", + customHeaders: map[string]string{ + "X-Custom-Header": "test-value", + }, + expectedInReq: map[string]string{ + "X-Custom-Header": "test-value", + }, + }, + { + name: "multiple custom headers", + customHeaders: map[string]string{ + "X-Custom-Header-1": "value1", + "X-Custom-Header-2": "value2", + }, + expectedInReq: map[string]string{ + "X-Custom-Header-1": "value1", + "X-Custom-Header-2": "value2", + }, + }, + { + name: "no custom headers", + customHeaders: nil, + expectedInReq: map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Track which headers were received + receivedHeaders := make(map[string]string) + var headerMu sync.Mutex + + // Create a mock server that captures headers + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headerMu.Lock() + defer headerMu.Unlock() + + // Capture custom headers + for key := range tt.expectedInReq { + if val := r.Header.Get(key); val != "" { + receivedHeaders[key] = val + } + } + + // Return a minimal valid response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + // Return a simple JSON response that matches Anthropic's Message structure + _, _ = w.Write([]byte(`{ + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "test"}], + "model": "claude-3-5-sonnet-20241022", + "usage": {"input_tokens": 10, "output_tokens": 20} + }`)) + })) + defer srv.Close() + + // Create config with custom headers + cfg := AnthropicConfig{ + ProviderConfig: ProviderConfig{ + BaseURL: srv.URL, + Key: "test-key", + }, + CustomHeaders: tt.customHeaders, + } + + // Create a simple message request + req := &MessageNewParamsWrapper{ + MessageNewParams: anthropic.MessageNewParams{ + Model: "claude-3-5-sonnet-20241022", + MaxTokens: 100, + Messages: []anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleUser, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock("test"), + }, + }, + }, + }, + } + + interception := NewAnthropicMessagesBlockingInterception(uuid.New(), req, cfg, nil) + + // Make request + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/anthropic/v1/messages", nil) + + logger := slog.Make() + // Use a no-op recorder to avoid nil pointer issues + recorder := &noopRecorder{} + interception.Setup(logger, recorder, nil) + + err := interception.ProcessRequest(w, r) + require.NoError(t, err) + + // Verify headers were sent + headerMu.Lock() + defer headerMu.Unlock() + require.Equal(t, tt.expectedInReq, receivedHeaders) + }) + } +} + +func TestParseCustomHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected map[string]string + }{ + { + name: "empty string", + input: "", + expected: nil, + }, + { + name: "single header", + input: "X-Custom-Header: value123", + expected: map[string]string{"X-Custom-Header": "value123"}, + }, + { + name: "multiple headers", + input: "X-Custom-Header: value1\nX-Another-Header: value2", + expected: map[string]string{"X-Custom-Header": "value1", "X-Another-Header": "value2"}, + }, + { + name: "header with URL value", + input: "X-Callback-URL: https://example.com:8080/path", + expected: map[string]string{"X-Callback-URL": "https://example.com:8080/path"}, + }, + { + name: "header with leading/trailing spaces", + input: " X-Custom-Header : value with spaces ", + expected: map[string]string{"X-Custom-Header": "value with spaces"}, + }, + { + name: "empty lines ignored", + input: "X-Header-1: value1\n\n\nX-Header-2: value2\n", + expected: map[string]string{"X-Header-1": "value1", "X-Header-2": "value2"}, + }, + { + name: "malformed header without colon", + input: "InvalidHeader\nX-Valid-Header: value", + expected: map[string]string{"X-Valid-Header": "value"}, + }, + { + name: "header with empty key ignored", + input: ": value\nX-Valid-Header: value", + expected: map[string]string{"X-Valid-Header": "value"}, + }, + { + name: "header with empty value allowed", + input: "X-Empty-Header:", + expected: map[string]string{"X-Empty-Header": ""}, + }, + { + name: "multiple colons in value", + input: "X-JSON: {\"key\":\"value\"}", + expected: map[string]string{"X-JSON": "{\"key\":\"value\"}"}, + }, + { + name: "all malformed headers returns nil", + input: "NoColonHere\nAnotherInvalid", + expected: nil, + }, + { + name: "mixed valid and invalid headers", + input: "X-Valid: value1\nInvalidLine\nX-Another: value2", + expected: map[string]string{"X-Valid": "value1", "X-Another": "value2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseCustomHeaders(tt.input) + require.Equal(t, tt.expected, result) + }) + } +} + func TestConvertStringContentToArray(t *testing.T) { t.Parallel() diff --git a/bridge_integration_test.go b/bridge_integration_test.go index c606340..d8b2e0a 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -499,7 +499,7 @@ func TestSimple(t *testing.T) { fixture: oaiSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, mcp.NewServerProxyManager(nil)) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil)) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -650,7 +650,7 @@ func TestFallthrough(t *testing.T) { fixture: oaiFallthrough, configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey))) + provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey)) bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil)) require.NoError(t, err) return provider, bridge @@ -838,7 +838,7 @@ func TestOpenAIInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr) } // Build the requirements & make the assertions which are common to all providers. @@ -1047,7 +1047,7 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr) }, responseHandlerFn: func(streaming bool, resp *http.Response) { if streaming { @@ -1359,7 +1359,9 @@ func openaiCfg(url, key string) aibridge.OpenAIConfig { func anthropicCfg(url, key string) aibridge.AnthropicConfig { return aibridge.AnthropicConfig{ - BaseURL: url, - Key: key, + ProviderConfig: aibridge.ProviderConfig{ + BaseURL: url, + Key: key, + }, } } diff --git a/config.go b/config.go index 8dc6f1d..59597bc 100644 --- a/config.go +++ b/config.go @@ -4,10 +4,12 @@ type ProviderConfig struct { BaseURL, Key string } -type ( - OpenAIConfig ProviderConfig - AnthropicConfig ProviderConfig -) +type OpenAIConfig ProviderConfig + +type AnthropicConfig struct { + ProviderConfig + CustomHeaders map[string]string +} type AWSBedrockConfig struct { Region string diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 8459618..7fd1789 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -99,6 +99,11 @@ func (i *AnthropicMessagesInterceptionBase) newAnthropicClient(ctx context.Conte opts = append(opts, option.WithAPIKey(i.cfg.Key)) opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) + // Inject custom headers if configured + for key, value := range i.cfg.CustomHeaders { + opts = append(opts, option.WithHeader(key, value)) + } + if i.bedrockCfg != nil { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() diff --git a/provider_anthropic.go b/provider_anthropic.go index 7e9c99f..c55701c 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "os" + "strings" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/shared" @@ -28,6 +29,45 @@ const ( routeMessages = "/anthropic/v1/messages" // https://docs.anthropic.com/en/api/messages ) +// parseCustomHeaders parses the ANTHROPIC_CUSTOM_HEADERS environment variable. +// The format is "Name: Value" with one header per line. +// Multiple lines can be separated by newlines (\n). +func parseCustomHeaders(envValue string) map[string]string { + if envValue == "" { + return nil + } + + headers := make(map[string]string) + lines := strings.Split(envValue, "\n") + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Split on the first colon to handle values that contain colons (e.g., URLs) + parts := strings.SplitN(line, ":", 2) + if len(parts) != 2 { + // Skip malformed headers + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + if key != "" { + headers[key] = value + } + } + + if len(headers) == 0 { + return nil + } + + return headers +} + func NewAnthropicProvider(cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicProvider { if cfg.BaseURL == "" { cfg.BaseURL = "https://api.anthropic.com/" @@ -35,6 +75,9 @@ func NewAnthropicProvider(cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *An if cfg.Key == "" { cfg.Key = os.Getenv("ANTHROPIC_API_KEY") } + if cfg.CustomHeaders == nil { + cfg.CustomHeaders = parseCustomHeaders(os.Getenv("ANTHROPIC_CUSTOM_HEADERS")) + } return &AnthropicProvider{ cfg: cfg,