diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 96007a152..dc209c81f 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -12,7 +12,7 @@ import ( "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/auth/factory" vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client" "github.com/stacklok/toolhive/pkg/vmcp/config" vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router" @@ -213,8 +213,15 @@ func runServe(cmd *cobra.Command, _ []string) error { return fmt.Errorf("failed to create groups manager: %w", err) } + // Create outgoing authentication registry from configuration + logger.Info("Initializing outgoing authentication") + outgoingRegistry, err := factory.NewOutgoingAuthRegistry(ctx, cfg.OutgoingAuth) + if err != nil { + return fmt.Errorf("failed to create outgoing authentication registry: %w", err) + } + // Create backend discoverer - discoverer := aggregator.NewCLIBackendDiscoverer(workloadsManager, groupsManager) + discoverer := aggregator.NewCLIBackendDiscoverer(workloadsManager, groupsManager, cfg.OutgoingAuth) // Discover backends from the configured group logger.Infof("Discovering backends in group: %s", cfg.GroupRef) @@ -230,7 +237,10 @@ func runServe(cmd *cobra.Command, _ []string) error { logger.Infof("Discovered %d backends", len(backends)) // Create backend client - backendClient := vmcpclient.NewHTTPBackendClient() + backendClient, err := vmcpclient.NewHTTPBackendClient(outgoingRegistry) + if err != nil { + return fmt.Errorf("failed to create backend client: %w", err) + } // Create conflict resolver based on configuration // Use the factory method that handles all strategies @@ -264,7 +274,7 @@ func runServe(cmd *cobra.Command, _ []string) error { // Setup authentication middleware logger.Infof("Setting up incoming authentication (type: %s)", cfg.IncomingAuth.Type) - authMiddleware, authInfoHandler, err := vmcpauth.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth) + authMiddleware, authInfoHandler, err := factory.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth) if err != nil { return fmt.Errorf("failed to create authentication middleware: %w", err) } diff --git a/pkg/vmcp/aggregator/cli_discoverer.go b/pkg/vmcp/aggregator/cli_discoverer.go index b96350b53..c1dec6b41 100644 --- a/pkg/vmcp/aggregator/cli_discoverer.go +++ b/pkg/vmcp/aggregator/cli_discoverer.go @@ -8,6 +8,7 @@ import ( "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" "github.com/stacklok/toolhive/pkg/workloads" ) @@ -16,14 +17,23 @@ import ( type cliBackendDiscoverer struct { workloadsManager workloads.Manager groupsManager groups.Manager + authConfig *config.OutgoingAuthConfig } // NewCLIBackendDiscoverer creates a new CLI-based backend discoverer. // It discovers workloads from Docker/Podman containers managed by ToolHive. -func NewCLIBackendDiscoverer(workloadsManager workloads.Manager, groupsManager groups.Manager) BackendDiscoverer { +// +// The authConfig parameter configures authentication for discovered backends. +// If nil, backends will have no authentication configured. +func NewCLIBackendDiscoverer( + workloadsManager workloads.Manager, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) BackendDiscoverer { return &cliBackendDiscoverer{ workloadsManager: workloadsManager, groupsManager: groupsManager, + authConfig: authConfig, } } @@ -92,6 +102,16 @@ func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([ Metadata: make(map[string]string), } + // Apply authentication configuration if provided + if d.authConfig != nil { + authStrategy, authMetadata := d.resolveAuthConfig(name) + backend.AuthStrategy = authStrategy + backend.AuthMetadata = authMetadata + if authStrategy != "" { + logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy) + } + } + // Copy user labels to metadata first for k, v := range workload.Labels { backend.Metadata[k] = v @@ -116,6 +136,29 @@ func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([ return backends, nil } +// resolveAuthConfig determines the authentication strategy and metadata for a backend. +// It checks for backend-specific configuration first, then falls back to default. +func (d *cliBackendDiscoverer) resolveAuthConfig(backendID string) (string, map[string]any) { + if d.authConfig == nil { + return "", nil + } + + // Check for backend-specific configuration + if strategy, exists := d.authConfig.Backends[backendID]; exists && strategy != nil { + logger.Debugf("Using backend-specific auth strategy for %s: %s", backendID, strategy.Type) + return strategy.Type, strategy.Metadata + } + + // Fall back to default configuration + if d.authConfig.Default != nil { + logger.Debugf("Using default auth strategy for %s: %s", backendID, d.authConfig.Default.Type) + return d.authConfig.Default.Type, d.authConfig.Default.Metadata + } + + // No authentication configured + return "", nil +} + // mapWorkloadStatusToHealth converts a workload status to a backend health status. func mapWorkloadStatusToHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { switch status { diff --git a/pkg/vmcp/aggregator/cli_discoverer_test.go b/pkg/vmcp/aggregator/cli_discoverer_test.go index 19e1de944..9c3402fad 100644 --- a/pkg/vmcp/aggregator/cli_discoverer_test.go +++ b/pkg/vmcp/aggregator/cli_discoverer_test.go @@ -45,7 +45,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -79,7 +79,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(stoppedWorkload, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -108,7 +108,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workloadWithURL, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workloadWithoutURL, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -133,7 +133,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -150,7 +150,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), "nonexistent-group") require.Error(t, err) @@ -168,7 +168,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.Error(t, err) @@ -187,7 +187,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), "empty-group") require.NoError(t, err) @@ -214,7 +214,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped1").Return(stoppedWorkload, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "error1").Return(errorWorkload, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -240,7 +240,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). Return(core.Workload{}, errors.New("workload query failed")) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) diff --git a/pkg/vmcp/auth/auth.go b/pkg/vmcp/auth/auth.go index 76f9626eb..455b6e71c 100644 --- a/pkg/vmcp/auth/auth.go +++ b/pkg/vmcp/auth/auth.go @@ -1,7 +1,7 @@ // Package auth provides authentication for Virtual MCP Server. // // This package defines: -// - OutgoingAuthenticator: Authenticates vMCP to backend servers +// - OutgoingAuthRegistry: Registry for managing backend authentication strategies // - Strategy: Pluggable authentication strategies for backends // // Incoming authentication uses pkg/auth middleware (OIDC, local, anonymous) @@ -17,24 +17,39 @@ import ( "github.com/stacklok/toolhive/pkg/auth" ) -// OutgoingAuthenticator handles authentication to backend MCP servers. -// This is responsible for obtaining and injecting appropriate credentials -// for each backend based on its authentication strategy. +// OutgoingAuthRegistry manages authentication strategies for outgoing requests to backend MCP servers. +// This is a registry that stores and retrieves Strategy implementations. // -// The specific authentication strategies and their behavior will be defined -// during implementation based on the design decisions documented in the -// Virtual MCP Server proposal. -type OutgoingAuthenticator interface { - // AuthenticateRequest adds authentication to an outgoing backend request. - // The strategy and metadata are provided in the BackendTarget. - AuthenticateRequest(ctx context.Context, req *http.Request, strategy string, metadata map[string]any) error - - // GetStrategy returns the authentication strategy handler for a given strategy name. - // This enables extensibility - new strategies can be registered. +// The registry supports dynamic strategy registration, allowing custom authentication +// strategies to be added at runtime. Once registered, strategies can be retrieved +// by name and used to authenticate requests to backends. +// +// Responsibilities: +// - Maintain registry of available strategies +// - Retrieve strategies by name +// - Register new strategies dynamically +// +// This registry does NOT perform authentication itself. Authentication is performed +// by Strategy implementations retrieved from this registry. +// +// Usage Pattern: +// 1. Register strategies during application initialization +// 2. Resolve strategy once at client creation time (cold path) +// 3. Call strategy.Authenticate() directly per-request (hot path) +// +// Thread-safety: Implementations must be safe for concurrent access. +type OutgoingAuthRegistry interface { + // GetStrategy retrieves an authentication strategy by name. + // Returns an error if the strategy is not found. GetStrategy(name string) (Strategy, error) // RegisterStrategy registers a new authentication strategy. - // This allows custom auth strategies to be added at runtime. + // The strategy name must match the name returned by strategy.Name(). + // Returns an error if: + // - name is empty + // - strategy is nil + // - a strategy with the same name is already registered + // - strategy.Name() does not match the registration name RegisterStrategy(name string, strategy Strategy) error } diff --git a/pkg/vmcp/auth/incoming_factory.go b/pkg/vmcp/auth/factory/incoming.go similarity index 99% rename from pkg/vmcp/auth/incoming_factory.go rename to pkg/vmcp/auth/factory/incoming.go index 479876d2d..edb09a6cd 100644 --- a/pkg/vmcp/auth/incoming_factory.go +++ b/pkg/vmcp/auth/factory/incoming.go @@ -1,4 +1,4 @@ -package auth +package factory import ( "context" diff --git a/pkg/vmcp/auth/incoming_factory_test.go b/pkg/vmcp/auth/factory/incoming_test.go similarity index 99% rename from pkg/vmcp/auth/incoming_factory_test.go rename to pkg/vmcp/auth/factory/incoming_test.go index e3f7a22cb..10bc65344 100644 --- a/pkg/vmcp/auth/incoming_factory_test.go +++ b/pkg/vmcp/auth/factory/incoming_test.go @@ -1,4 +1,4 @@ -package auth +package factory import ( "context" diff --git a/pkg/vmcp/auth/factory/outgoing.go b/pkg/vmcp/auth/factory/outgoing.go new file mode 100644 index 000000000..1c7cf7254 --- /dev/null +++ b/pkg/vmcp/auth/factory/outgoing.go @@ -0,0 +1,166 @@ +// Copyright 2025 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package factory provides factory functions for creating vMCP authentication components. +package factory + +import ( + "context" + "fmt" + "strings" + + "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" + "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +// NewOutgoingAuthRegistry creates an OutgoingAuthRegistry from configuration. +// It registers all strategies found in the configuration (both default and backend-specific). +// +// The factory ALWAYS registers the "unauthenticated" strategy as a default fallback, +// ensuring that backends without explicit authentication configuration can function. +// This makes empty/nil configuration safe: the registry will have at least one +// usable strategy. +// +// Strategy Registration: +// - "unauthenticated" is always registered (default fallback) +// - Additional strategies are registered based on configuration +// - Each strategy is instantiated once and shared across backends +// - Strategies are stateless (except token_exchange which has internal caching) +// +// Parameters: +// - ctx: Context for any initialization that requires it +// - cfg: The outgoing authentication configuration (may be nil) +// +// Returns: +// - auth.OutgoingAuthRegistry: Configured registry with registered strategies +// - error: Any error during strategy initialization or registration +func NewOutgoingAuthRegistry(_ context.Context, cfg *config.OutgoingAuthConfig) (auth.OutgoingAuthRegistry, error) { + registry := auth.NewDefaultOutgoingAuthRegistry() + + // ALWAYS register the unauthenticated strategy as the default fallback. + if err := registerUnauthenticatedStrategy(registry); err != nil { + return nil, err + } + + // Handle nil config gracefully - return registry with unauthenticated strategy + if cfg == nil { + return registry, nil + } + + // Validate configuration structure + if err := validateConfig(cfg); err != nil { + return nil, err + } + + // Collect and register all unique strategy types from configuration + strategyTypes := collectStrategyTypes(cfg) + if err := registerStrategies(registry, strategyTypes); err != nil { + return nil, err + } + + return registry, nil +} + +// registerUnauthenticatedStrategy registers the default unauthenticated strategy. +func registerUnauthenticatedStrategy(registry auth.OutgoingAuthRegistry) error { + unauthStrategy := strategies.NewUnauthenticatedStrategy() + if err := registry.RegisterStrategy("unauthenticated", unauthStrategy); err != nil { + return fmt.Errorf("failed to register default unauthenticated strategy: %w", err) + } + return nil +} + +// validateConfig validates the configuration structure. +func validateConfig(cfg *config.OutgoingAuthConfig) error { + if cfg.Default != nil && strings.TrimSpace(cfg.Default.Type) == "" { + return fmt.Errorf("default auth strategy type cannot be empty") + } + + for backendID, backendCfg := range cfg.Backends { + if backendCfg != nil && strings.TrimSpace(backendCfg.Type) == "" { + return fmt.Errorf("backend %q has empty auth strategy type", backendID) + } + } + + return nil +} + +// collectStrategyTypes collects all unique strategy types from configuration. +func collectStrategyTypes(cfg *config.OutgoingAuthConfig) map[string]struct{} { + strategyTypes := make(map[string]struct{}) + + // Add default strategy type if present + if cfg.Default != nil && cfg.Default.Type != "" { + strategyTypes[cfg.Default.Type] = struct{}{} + } + + // Add all backend strategy types + for _, backendCfg := range cfg.Backends { + if backendCfg != nil && backendCfg.Type != "" { + strategyTypes[backendCfg.Type] = struct{}{} + } + } + + return strategyTypes +} + +// registerStrategies instantiates and registers each unique strategy type. +func registerStrategies(registry auth.OutgoingAuthRegistry, strategyTypes map[string]struct{}) error { + for strategyType := range strategyTypes { + // Skip "unauthenticated" - already registered + if strategyType == "unauthenticated" { + continue + } + + strategy, err := createStrategy(strategyType) + if err != nil { + return fmt.Errorf("failed to create strategy %q: %w", strategyType, err) + } + + if err := registry.RegisterStrategy(strategyType, strategy); err != nil { + return fmt.Errorf("failed to register strategy %q: %w", strategyType, err) + } + } + + return nil +} + +// createStrategy instantiates a strategy based on its type. +// +// Each strategy instance is stateless (except token_exchange which has internal caching). +// This function validates that the strategy type is not empty and returns an appropriate +// error for unknown strategy types. +// +// Parameters: +// - strategyType: The type identifier of the strategy to create +// +// Returns: +// - auth.Strategy: The instantiated strategy +// - error: Any error during strategy creation or validation +func createStrategy(strategyType string) (auth.Strategy, error) { + // Validate strategy type is not empty + if strings.TrimSpace(strategyType) == "" { + return nil, fmt.Errorf("strategy type cannot be empty") + } + + switch strategyType { + case "header_injection": + return strategies.NewHeaderInjectionStrategy(), nil + case "unauthenticated": + return strategies.NewUnauthenticatedStrategy(), nil + default: + return nil, fmt.Errorf("unknown strategy type: %s", strategyType) + } +} diff --git a/pkg/vmcp/auth/outgoing_authenticator.go b/pkg/vmcp/auth/outgoing_authenticator.go deleted file mode 100644 index 6498f68dd..000000000 --- a/pkg/vmcp/auth/outgoing_authenticator.go +++ /dev/null @@ -1,130 +0,0 @@ -package auth - -import ( - "context" - "errors" - "fmt" - "net/http" - "sync" -) - -// DefaultOutgoingAuthenticator is a thread-safe implementation of OutgoingAuthenticator -// that maintains a registry of authentication strategies. -// -// Thread-safety: Safe for concurrent calls to RegisterStrategy and AuthenticateRequest. -// Strategy implementations must be thread-safe as they are called concurrently. -// It uses sync.RWMutex for thread-safety as HTTP servers are inherently concurrent. -// -// This authenticator supports dynamic registration of strategies and dispatches -// authentication requests to the appropriate strategy based on the strategy name. -// -// Example usage: -// -// auth := NewDefaultOutgoingAuthenticator() -// auth.RegisterStrategy("bearer", NewBearerStrategy()) -// err := auth.AuthenticateRequest(ctx, req, "bearer", metadata) -type DefaultOutgoingAuthenticator struct { - strategies map[string]Strategy - mu sync.RWMutex -} - -// NewDefaultOutgoingAuthenticator creates a new DefaultOutgoingAuthenticator -// with an empty strategy registry. -// -// Strategies must be registered using RegisterStrategy before they can be used -// for authentication. -func NewDefaultOutgoingAuthenticator() *DefaultOutgoingAuthenticator { - return &DefaultOutgoingAuthenticator{ - strategies: make(map[string]Strategy), - } -} - -// RegisterStrategy registers a new authentication strategy. -// -// This method is thread-safe and validates that: -// - name is not empty -// - strategy is not nil -// - no strategy is already registered with the same name -// -// Parameters: -// - name: The unique identifier for this strategy -// - strategy: The Strategy implementation to register -// -// Returns an error if validation fails or a strategy with the same name -// already exists. -func (a *DefaultOutgoingAuthenticator) RegisterStrategy(name string, strategy Strategy) error { - if name == "" { - return errors.New("strategy name cannot be empty") - } - if strategy == nil { - return errors.New("strategy cannot be nil") - } - - a.mu.Lock() - defer a.mu.Unlock() - - if _, exists := a.strategies[name]; exists { - return fmt.Errorf("strategy %q is already registered", name) - } - - a.strategies[name] = strategy - return nil -} - -// GetStrategy retrieves an authentication strategy by name. -// -// This method is thread-safe for concurrent reads. It returns the strategy -// if found, or an error if no strategy is registered with the given name. -// -// Parameters: -// - name: The identifier of the strategy to retrieve -// -// Returns: -// - Strategy: The registered strategy -// - error: An error if the strategy is not found -func (a *DefaultOutgoingAuthenticator) GetStrategy(name string) (Strategy, error) { - a.mu.RLock() - defer a.mu.RUnlock() - - strategy, exists := a.strategies[name] - if !exists { - return nil, fmt.Errorf("strategy %q not found", name) - } - - return strategy, nil -} - -// AuthenticateRequest adds authentication to an outgoing backend request. -// -// This method retrieves the specified strategy and delegates authentication -// to it. The strategy modifies the request by adding appropriate headers, -// tokens, or other authentication artifacts. -// -// Parameters: -// - ctx: Request context (may contain identity for pass-through auth) -// - req: The HTTP request to authenticate -// - strategyName: The name of the strategy to use -// - metadata: Strategy-specific configuration -// -// Returns an error if: -// - The strategy is not found -// - The metadata validation fails -// - The strategy's Authenticate method fails -func (a *DefaultOutgoingAuthenticator) AuthenticateRequest( - ctx context.Context, - req *http.Request, - strategyName string, - metadata map[string]any, -) error { - strategy, err := a.GetStrategy(strategyName) - if err != nil { - return err - } - - // Validate metadata before using it - if err := strategy.Validate(metadata); err != nil { - return fmt.Errorf("invalid metadata for strategy %q: %w", strategyName, err) - } - - return strategy.Authenticate(ctx, req, metadata) -} diff --git a/pkg/vmcp/auth/outgoing_authenticator_test.go b/pkg/vmcp/auth/outgoing_authenticator_test.go deleted file mode 100644 index 43073bc7d..000000000 --- a/pkg/vmcp/auth/outgoing_authenticator_test.go +++ /dev/null @@ -1,455 +0,0 @@ -package auth - -import ( - "context" - "errors" - "net/http" - "net/http/httptest" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks" -) - -type testContextKey struct{} - -var testKey = testContextKey{} - -func TestDefaultOutgoingAuthenticator_RegisterStrategy(t *testing.T) { - t.Parallel() - t.Run("register valid strategy succeeds", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - - err := auth.RegisterStrategy("bearer", strategy) - - require.NoError(t, err) - // Verify strategy was registered - retrieved, err := auth.GetStrategy("bearer") - require.NoError(t, err) - assert.Equal(t, strategy, retrieved) - }) - - t.Run("register empty name fails", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - - err := auth.RegisterStrategy("", strategy) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "strategy name cannot be empty") - }) - - t.Run("register nil strategy fails", func(t *testing.T) { - t.Parallel() - auth := NewDefaultOutgoingAuthenticator() - - err := auth.RegisterStrategy("bearer", nil) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "strategy cannot be nil") - }) - - t.Run("register duplicate name fails", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy1 := mocks.NewMockStrategy(ctrl) - strategy1.EXPECT().Name().Return("bearer").AnyTimes() - strategy2 := mocks.NewMockStrategy(ctrl) - - err := auth.RegisterStrategy("bearer", strategy1) - require.NoError(t, err) - - err = auth.RegisterStrategy("bearer", strategy2) - assert.Error(t, err) - assert.Contains(t, err.Error(), "already registered") - assert.Contains(t, err.Error(), "bearer") - }) - - t.Run("register multiple different strategies succeeds", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - bearer := mocks.NewMockStrategy(ctrl) - bearer.EXPECT().Name().Return("bearer").AnyTimes() - basic := mocks.NewMockStrategy(ctrl) - basic.EXPECT().Name().Return("basic").AnyTimes() - apiKey := mocks.NewMockStrategy(ctrl) - apiKey.EXPECT().Name().Return("api-key").AnyTimes() - - require.NoError(t, auth.RegisterStrategy("bearer", bearer)) - require.NoError(t, auth.RegisterStrategy("basic", basic)) - require.NoError(t, auth.RegisterStrategy("api-key", apiKey)) - - // Verify all strategies are registered - s1, err := auth.GetStrategy("bearer") - require.NoError(t, err) - assert.Equal(t, bearer, s1) - - s2, err := auth.GetStrategy("basic") - require.NoError(t, err) - assert.Equal(t, basic, s2) - - s3, err := auth.GetStrategy("api-key") - require.NoError(t, err) - assert.Equal(t, apiKey, s3) - }) -} - -func TestDefaultOutgoingAuthenticator_GetStrategy(t *testing.T) { - t.Parallel() - t.Run("get existing strategy succeeds", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - retrieved, err := auth.GetStrategy("bearer") - - require.NoError(t, err) - assert.Equal(t, strategy, retrieved) - }) - - t.Run("get non-existent strategy fails", func(t *testing.T) { - t.Parallel() - auth := NewDefaultOutgoingAuthenticator() - - retrieved, err := auth.GetStrategy("non-existent") - - assert.Error(t, err) - assert.Nil(t, retrieved) - assert.Contains(t, err.Error(), "not found") - assert.Contains(t, err.Error(), "non-existent") - }) - - t.Run("get from empty registry fails", func(t *testing.T) { - t.Parallel() - auth := NewDefaultOutgoingAuthenticator() - - retrieved, err := auth.GetStrategy("bearer") - - assert.Error(t, err) - assert.Nil(t, retrieved) - }) -} - -func TestDefaultOutgoingAuthenticator_AuthenticateRequest(t *testing.T) { - t.Parallel() - t.Run("authenticates with valid strategy", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - strategy.EXPECT().Validate(gomock.Any()).Return(nil) - strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, req *http.Request, _ map[string]any) error { - // Add a header to verify the request was modified - req.Header.Set("Authorization", "Bearer token123") - return nil - }, - ) - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - metadata := map[string]any{"token": "token123"} - err := auth.AuthenticateRequest(context.Background(), req, "bearer", metadata) - - require.NoError(t, err) - assert.Equal(t, "Bearer token123", req.Header.Get("Authorization")) - }) - - t.Run("fails with non-existent strategy", func(t *testing.T) { - t.Parallel() - auth := NewDefaultOutgoingAuthenticator() - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - - err := auth.AuthenticateRequest(context.Background(), req, "non-existent", nil) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "not found") - }) - - t.Run("returns error from strategy", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategyErr := errors.New("authentication failed") - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - strategy.EXPECT().Validate(gomock.Any()).Return(nil) - strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).Return(strategyErr) - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - err := auth.AuthenticateRequest(context.Background(), req, "bearer", nil) - - assert.Error(t, err) - assert.Equal(t, strategyErr, err) - }) - - t.Run("passes context and metadata to strategy", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - - var receivedCtx context.Context - var receivedMetadata map[string]any - - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - strategy.EXPECT().Validate(gomock.Any()).Return(nil) - strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(ctx context.Context, _ *http.Request, metadata map[string]any) error { - receivedCtx = ctx - receivedMetadata = metadata - return nil - }, - ) - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - ctx := context.WithValue(context.Background(), testKey, "test-value") - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - metadata := map[string]any{ - "token": "abc123", - "scopes": []string{"read", "write"}, - } - - err := auth.AuthenticateRequest(ctx, req, "bearer", metadata) - - require.NoError(t, err) - assert.NotNil(t, receivedCtx) - assert.Equal(t, "test-value", receivedCtx.Value(testKey)) - assert.Equal(t, metadata, receivedMetadata) - }) - - t.Run("validates metadata before authentication", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("test-strategy").AnyTimes() - - require.NoError(t, auth.RegisterStrategy("test-strategy", strategy)) - - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - metadata := map[string]any{"invalid": "data"} - - // Expect Validate to be called and return error - strategy.EXPECT(). - Validate(metadata). - Return(errors.New("invalid metadata")) - - // Authenticate should NOT be called if validation fails - // (no EXPECT for Authenticate) - - err := auth.AuthenticateRequest(context.Background(), req, "test-strategy", metadata) - - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid metadata for strategy") - assert.Contains(t, err.Error(), "test-strategy") - }) -} - -func TestDefaultOutgoingAuthenticator_ConcurrentAccess(t *testing.T) { - t.Parallel() - t.Run("concurrent GetStrategy calls are thread-safe", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - - // Register multiple strategies - strategies := []string{"bearer", "basic", "api-key", "oauth2", "jwt"} - for _, name := range strategies { - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return(name).AnyTimes() - require.NoError(t, auth.RegisterStrategy(name, strategy)) - } - - // Test concurrent reads with -race detector - const numGoroutines = 100 - const numOperations = 1000 - - var wg sync.WaitGroup - wg.Add(numGoroutines) - - errs := make(chan error, numGoroutines*numOperations) - - for i := 0; i < numGoroutines; i++ { - go func(_ int) { - defer wg.Done() - for j := 0; j < numOperations; j++ { - // Rotate through strategies - strategyName := strategies[j%len(strategies)] - strategy, err := auth.GetStrategy(strategyName) - if err != nil { - errs <- err - return - } - if strategy.Name() != strategyName { - errs <- errors.New("strategy name mismatch") - return - } - } - }(i) - } - - wg.Wait() - close(errs) - - // Check for errors - var collectedErrors []error - for err := range errs { - collectedErrors = append(collectedErrors, err) - } - - if len(collectedErrors) > 0 { - t.Fatalf("concurrent access produced errors: %v", collectedErrors) - } - }) - - t.Run("concurrent AuthenticateRequest calls are thread-safe", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - - // Counter to verify all authentications happen - var authCount int64 - var authMu sync.Mutex - - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - strategy.EXPECT().Validate(gomock.Any()).Return(nil).AnyTimes() - strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, req *http.Request, _ map[string]any) error { - authMu.Lock() - authCount++ - authMu.Unlock() - req.Header.Set("Authorization", "Bearer test") - return nil - }, - ).AnyTimes() - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - const numGoroutines = 100 - - var wg sync.WaitGroup - wg.Add(numGoroutines) - - errs := make(chan error, numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func() { - defer wg.Done() - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - err := auth.AuthenticateRequest(context.Background(), req, "bearer", nil) - if err != nil { - errs <- err - } - }() - } - - wg.Wait() - close(errs) - - // Check for errors - var collectedErrors []error - for err := range errs { - collectedErrors = append(collectedErrors, err) - } - - if len(collectedErrors) > 0 { - t.Fatalf("concurrent AuthenticateRequest produced errors: %v", collectedErrors) - } - - // Verify all authentications completed - assert.Equal(t, int64(numGoroutines), authCount) - }) - - t.Run("concurrent RegisterStrategy and GetStrategy are thread-safe", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - - const numRegister = 50 - const numGet = 50 - - var wg sync.WaitGroup - wg.Add(numRegister + numGet) - - errs := make(chan error, numRegister+numGet) - - // Goroutines registering strategies - for i := 0; i < numRegister; i++ { - go func(id int) { - defer wg.Done() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("strategy").AnyTimes() - strategyName := "strategy-" + string(rune('A'+id%26)) + string(rune('0'+id/26)) - err := auth.RegisterStrategy(strategyName, strategy) - if err != nil { - errs <- err - } - }(i) - } - - // Goroutines reading strategies (will mostly fail, but shouldn't race) - for i := 0; i < numGet; i++ { - go func(id int) { - defer wg.Done() - strategyName := "strategy-" + string(rune('A'+id%26)) + string(rune('0'+id/26)) - // GetStrategy may return error if not registered yet, that's OK - _, _ = auth.GetStrategy(strategyName) - }(i) - } - - wg.Wait() - close(errs) - - // Check for unexpected errors (registration errors are not expected) - var collectedErrors []error - for err := range errs { - collectedErrors = append(collectedErrors, err) - } - - if len(collectedErrors) > 0 { - t.Fatalf("concurrent RegisterStrategy/GetStrategy produced errors: %v", collectedErrors) - } - }) -} diff --git a/pkg/vmcp/auth/outgoing_registry.go b/pkg/vmcp/auth/outgoing_registry.go new file mode 100644 index 000000000..04f2513a3 --- /dev/null +++ b/pkg/vmcp/auth/outgoing_registry.go @@ -0,0 +1,103 @@ +package auth + +import ( + "errors" + "fmt" + "sync" +) + +// DefaultOutgoingAuthRegistry is a thread-safe implementation of OutgoingAuthRegistry +// that maintains a registry of authentication strategies. +// +// Thread-safety: Safe for concurrent calls to RegisterStrategy and GetStrategy. +// Strategy implementations must be thread-safe as they are called concurrently. +// It uses sync.RWMutex for thread-safety as HTTP servers are inherently concurrent. +// +// This registry supports dynamic registration of strategies and retrieval by name. +// It does not perform authentication itself - that is done by the Strategy implementations. +// +// Example usage: +// +// registry := NewDefaultOutgoingAuthRegistry() +// registry.RegisterStrategy("header_injection", NewHeaderInjectionStrategy()) +// strategy, err := registry.GetStrategy("header_injection") +// if err == nil { +// err = strategy.Authenticate(ctx, req, metadata) +// } +type DefaultOutgoingAuthRegistry struct { + strategies map[string]Strategy + mu sync.RWMutex +} + +// NewDefaultOutgoingAuthRegistry creates a new DefaultOutgoingAuthRegistry +// with an empty strategy registry. +// +// Strategies must be registered using RegisterStrategy before they can be used +// for authentication. +func NewDefaultOutgoingAuthRegistry() *DefaultOutgoingAuthRegistry { + return &DefaultOutgoingAuthRegistry{ + strategies: make(map[string]Strategy), + } +} + +// RegisterStrategy registers a new authentication strategy. +// +// This method is thread-safe and validates that: +// - name is not empty +// - strategy is not nil +// - strategy.Name() matches the registration name +// - no strategy is already registered with the same name +// +// Parameters: +// - name: The unique identifier for this strategy +// - strategy: The Strategy implementation to register +// +// Returns an error if validation fails or a strategy with the same name +// already exists. +func (r *DefaultOutgoingAuthRegistry) RegisterStrategy(name string, strategy Strategy) error { + if name == "" { + return errors.New("strategy name cannot be empty") + } + if strategy == nil { + return errors.New("strategy cannot be nil") + } + + // Validate that strategy name matches registration name + if name != strategy.Name() { + return fmt.Errorf("strategy name mismatch: registered as %q but strategy.Name() returns %q", + name, strategy.Name()) + } + + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.strategies[name]; exists { + return fmt.Errorf("strategy %q is already registered", name) + } + + r.strategies[name] = strategy + return nil +} + +// GetStrategy retrieves an authentication strategy by name. +// +// This method is thread-safe for concurrent reads. It returns the strategy +// if found, or an error if no strategy is registered with the given name. +// +// Parameters: +// - name: The identifier of the strategy to retrieve +// +// Returns: +// - Strategy: The registered strategy +// - error: An error if the strategy is not found +func (r *DefaultOutgoingAuthRegistry) GetStrategy(name string) (Strategy, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + strategy, exists := r.strategies[name] + if !exists { + return nil, fmt.Errorf("strategy %q not found", name) + } + + return strategy, nil +} diff --git a/pkg/vmcp/auth/outgoing_registry_test.go b/pkg/vmcp/auth/outgoing_registry_test.go new file mode 100644 index 000000000..0c87b48d8 --- /dev/null +++ b/pkg/vmcp/auth/outgoing_registry_test.go @@ -0,0 +1,280 @@ +package auth + +import ( + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks" +) + +func TestDefaultOutgoingAuthRegistry_RegisterStrategy(t *testing.T) { + t.Parallel() + t.Run("register valid strategy succeeds", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return("bearer").AnyTimes() + + err := registry.RegisterStrategy("bearer", strategy) + + require.NoError(t, err) + // Verify strategy was registered + retrieved, err := registry.GetStrategy("bearer") + require.NoError(t, err) + assert.Equal(t, strategy, retrieved) + }) + + t.Run("register empty name fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy := mocks.NewMockStrategy(ctrl) + + err := registry.RegisterStrategy("", strategy) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "strategy name cannot be empty") + }) + + t.Run("register nil strategy fails", func(t *testing.T) { + t.Parallel() + registry := NewDefaultOutgoingAuthRegistry() + + err := registry.RegisterStrategy("bearer", nil) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "strategy cannot be nil") + }) + + t.Run("register strategy name mismatch fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return("actual_name").AnyTimes() + + err := registry.RegisterStrategy("different_name", strategy) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "strategy name mismatch") + assert.Contains(t, err.Error(), "different_name") + assert.Contains(t, err.Error(), "actual_name") + }) + + t.Run("register duplicate name fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy1 := mocks.NewMockStrategy(ctrl) + strategy1.EXPECT().Name().Return("bearer").AnyTimes() + strategy2 := mocks.NewMockStrategy(ctrl) + strategy2.EXPECT().Name().Return("bearer").AnyTimes() + + err := registry.RegisterStrategy("bearer", strategy1) + require.NoError(t, err) + + err = registry.RegisterStrategy("bearer", strategy2) + assert.Error(t, err) + assert.Contains(t, err.Error(), "already registered") + assert.Contains(t, err.Error(), "bearer") + }) + + t.Run("register multiple different strategies succeeds", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + bearer := mocks.NewMockStrategy(ctrl) + bearer.EXPECT().Name().Return("bearer").AnyTimes() + basic := mocks.NewMockStrategy(ctrl) + basic.EXPECT().Name().Return("basic").AnyTimes() + apiKey := mocks.NewMockStrategy(ctrl) + apiKey.EXPECT().Name().Return("api-key").AnyTimes() + + require.NoError(t, registry.RegisterStrategy("bearer", bearer)) + require.NoError(t, registry.RegisterStrategy("basic", basic)) + require.NoError(t, registry.RegisterStrategy("api-key", apiKey)) + + // Verify all strategies are registered + s1, err := registry.GetStrategy("bearer") + require.NoError(t, err) + assert.Equal(t, bearer, s1) + + s2, err := registry.GetStrategy("basic") + require.NoError(t, err) + assert.Equal(t, basic, s2) + + s3, err := registry.GetStrategy("api-key") + require.NoError(t, err) + assert.Equal(t, apiKey, s3) + }) +} + +func TestDefaultOutgoingAuthRegistry_GetStrategy(t *testing.T) { + t.Parallel() + t.Run("get existing strategy succeeds", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return("bearer").AnyTimes() + require.NoError(t, registry.RegisterStrategy("bearer", strategy)) + + retrieved, err := registry.GetStrategy("bearer") + + require.NoError(t, err) + assert.Equal(t, strategy, retrieved) + }) + + t.Run("get non-existent strategy fails", func(t *testing.T) { + t.Parallel() + registry := NewDefaultOutgoingAuthRegistry() + + retrieved, err := registry.GetStrategy("non-existent") + + assert.Error(t, err) + assert.Nil(t, retrieved) + assert.Contains(t, err.Error(), "not found") + assert.Contains(t, err.Error(), "non-existent") + }) + + t.Run("get from empty registry fails", func(t *testing.T) { + t.Parallel() + registry := NewDefaultOutgoingAuthRegistry() + + retrieved, err := registry.GetStrategy("bearer") + + assert.Error(t, err) + assert.Nil(t, retrieved) + }) +} + +func TestDefaultOutgoingAuthRegistry_ConcurrentAccess(t *testing.T) { + t.Parallel() + t.Run("concurrent GetStrategy calls are thread-safe", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + + // Register multiple strategies + strategies := []string{"bearer", "basic", "api-key", "oauth2", "jwt"} + for _, name := range strategies { + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return(name).AnyTimes() + require.NoError(t, registry.RegisterStrategy(name, strategy)) + } + + // Test concurrent reads with -race detector + const numGoroutines = 100 + const numOperations = 1000 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + errs := make(chan error, numGoroutines*numOperations) + + for i := 0; i < numGoroutines; i++ { + go func(_ int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + // Rotate through strategies + strategyName := strategies[j%len(strategies)] + strategy, err := registry.GetStrategy(strategyName) + if err != nil { + errs <- err + return + } + if strategy.Name() != strategyName { + errs <- errors.New("strategy name mismatch") + return + } + } + }(i) + } + + wg.Wait() + close(errs) + + // Check for errors + var collectedErrors []error + for err := range errs { + collectedErrors = append(collectedErrors, err) + } + + if len(collectedErrors) > 0 { + t.Fatalf("concurrent access produced errors: %v", collectedErrors) + } + }) + + t.Run("concurrent RegisterStrategy and GetStrategy are thread-safe", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + + const numRegister = 50 + const numGet = 50 + + var wg sync.WaitGroup + wg.Add(numRegister + numGet) + + errs := make(chan error, numRegister+numGet) + + // Goroutines registering strategies + for i := 0; i < numRegister; i++ { + go func(id int) { + defer wg.Done() + strategyName := "strategy-" + string(rune('A'+id%26)) + string(rune('0'+id/26)) + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return(strategyName).AnyTimes() + err := registry.RegisterStrategy(strategyName, strategy) + if err != nil { + errs <- err + } + }(i) + } + + // Goroutines reading strategies (will mostly fail, but shouldn't race) + for i := 0; i < numGet; i++ { + go func(id int) { + defer wg.Done() + strategyName := "strategy-" + string(rune('A'+id%26)) + string(rune('0'+id/26)) + // GetStrategy may return error if not registered yet, that's OK + _, _ = registry.GetStrategy(strategyName) + }(i) + } + + wg.Wait() + close(errs) + + // Check for unexpected errors (registration errors are not expected) + var collectedErrors []error + for err := range errs { + collectedErrors = append(collectedErrors, err) + } + + if len(collectedErrors) > 0 { + t.Fatalf("concurrent RegisterStrategy/GetStrategy produced errors: %v", collectedErrors) + } + }) +} diff --git a/pkg/vmcp/auth/strategies/header_injection.go b/pkg/vmcp/auth/strategies/header_injection.go new file mode 100644 index 000000000..07fccc084 --- /dev/null +++ b/pkg/vmcp/auth/strategies/header_injection.go @@ -0,0 +1,113 @@ +// Package strategies provides authentication strategy implementations for Virtual MCP Server. +package strategies + +import ( + "context" + "fmt" + "net/http" + + "github.com/stacklok/toolhive/pkg/validation" +) + +// HeaderInjectionStrategy injects a static header value into request headers. +// This is a general-purpose strategy that can inject any header with any value, +// commonly used for API keys, bearer tokens, or custom authentication headers. +// +// The strategy extracts the header name and value from the metadata +// configuration and injects them into the backend request headers. +// +// Required metadata fields: +// - header_name: The HTTP header name to use (e.g., "X-API-Key", "Authorization") +// - api_key: The header value to inject (can be an API key, token, or any value) +// +// This strategy is appropriate when: +// - The backend requires a static header value for authentication +// - The header value is stored securely in the vMCP configuration +// - No dynamic token exchange or user-specific authentication is required +// +// Future enhancements may include: +// - Secret reference resolution (e.g., ${SECRET_REF:...}) +// - Support for multiple header formats (e.g., "Bearer ") +// - Value rotation and refresh mechanisms +type HeaderInjectionStrategy struct{} + +// NewHeaderInjectionStrategy creates a new HeaderInjectionStrategy instance. +func NewHeaderInjectionStrategy() *HeaderInjectionStrategy { + return &HeaderInjectionStrategy{} +} + +// Name returns the strategy identifier. +func (*HeaderInjectionStrategy) Name() string { + return "header_injection" +} + +// Authenticate injects the header value from metadata into the request header. +// +// This method: +// 1. Validates that header_name and api_key are present in metadata +// 2. Sets the specified header with the provided value +// +// Parameters: +// - ctx: Request context (currently unused, reserved for future secret resolution) +// - req: The HTTP request to authenticate +// - metadata: Strategy-specific configuration containing header_name and api_key +// +// Returns an error if: +// - header_name is missing or empty +// - api_key is missing or empty +func (*HeaderInjectionStrategy) Authenticate(_ context.Context, req *http.Request, metadata map[string]any) error { + headerName, ok := metadata["header_name"].(string) + if !ok || headerName == "" { + return fmt.Errorf("header_name required in metadata") + } + + apiKey, ok := metadata["api_key"].(string) + if !ok || apiKey == "" { + return fmt.Errorf("api_key required in metadata") + } + + // TODO: Future enhancement - resolve secret references + // if strings.HasPrefix(apiKey, "${SECRET_REF:") { + // apiKey, err = s.secretResolver.Resolve(ctx, apiKey) + // if err != nil { + // return fmt.Errorf("failed to resolve secret reference: %w", err) + // } + // } + + req.Header.Set(headerName, apiKey) + return nil +} + +// Validate checks if the required metadata fields are present and valid. +// +// This method verifies that: +// - header_name is present and non-empty +// - api_key is present and non-empty +// - header_name is a valid HTTP header name (prevents CRLF injection) +// - api_key is a valid HTTP header value (prevents CRLF injection) +// +// This validation is typically called during configuration parsing to fail fast +// if the strategy is misconfigured. +func (*HeaderInjectionStrategy) Validate(metadata map[string]any) error { + headerName, ok := metadata["header_name"].(string) + if !ok || headerName == "" { + return fmt.Errorf("header_name required in metadata") + } + + apiKey, ok := metadata["api_key"].(string) + if !ok || apiKey == "" { + return fmt.Errorf("api_key required in metadata") + } + + // Validate header name to prevent injection attacks + if err := validation.ValidateHTTPHeaderName(headerName); err != nil { + return fmt.Errorf("invalid header_name: %w", err) + } + + // Validate API key value to prevent injection attacks + if err := validation.ValidateHTTPHeaderValue(apiKey); err != nil { + return fmt.Errorf("invalid api_key: %w", err) + } + + return nil +} diff --git a/pkg/vmcp/auth/strategies/header_injection_test.go b/pkg/vmcp/auth/strategies/header_injection_test.go new file mode 100644 index 000000000..537fd3d86 --- /dev/null +++ b/pkg/vmcp/auth/strategies/header_injection_test.go @@ -0,0 +1,408 @@ +package strategies + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHeaderInjectionStrategy_Name(t *testing.T) { + t.Parallel() + + strategy := NewHeaderInjectionStrategy() + assert.Equal(t, "header_injection", strategy.Name()) +} + +func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata map[string]any + expectError bool + errorContains string + checkHeader func(t *testing.T, req *http.Request) + }{ + { + name: "sets X-API-Key header correctly", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "secret-key-123", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "secret-key-123", req.Header.Get("X-API-Key")) + }, + }, + { + name: "sets Authorization header with API key", + metadata: map[string]any{ + "header_name": "Authorization", + "api_key": "ApiKey my-secret-key", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "ApiKey my-secret-key", req.Header.Get("Authorization")) + }, + }, + { + name: "sets custom header name", + metadata: map[string]any{ + "header_name": "X-Custom-Auth-Token", + "api_key": "custom-token-value", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "custom-token-value", req.Header.Get("X-Custom-Auth-Token")) + }, + }, + { + name: "handles complex API key values", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.test", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.test", + req.Header.Get("X-API-Key")) + }, + }, + { + name: "handles API key with special characters", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "key-with-!@#$%^&*()-_=+[]{}|;:,.<>?", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "key-with-!@#$%^&*()-_=+[]{}|;:,.<>?", req.Header.Get("X-API-Key")) + }, + }, + { + name: "ignores additional metadata fields", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "my-key", + "extra_field": "ignored", + "another": 123, + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "my-key", req.Header.Get("X-API-Key")) + }, + }, + { + name: "returns error when header_name is missing", + metadata: map[string]any{ + "api_key": "my-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is empty string", + metadata: map[string]any{ + "header_name": "", + "api_key": "my-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is not a string", + metadata: map[string]any{ + "header_name": 123, + "api_key": "my-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when api_key is missing", + metadata: map[string]any{ + "header_name": "X-API-Key", + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is empty string", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "", + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is not a string", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": 123, + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when metadata is nil", + metadata: nil, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when metadata is empty", + metadata: map[string]any{}, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when both fields are missing", + metadata: map[string]any{ + "unrelated": "field", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "overwrites existing header value", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "new-key", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + // Verify the new key was set (old-key was already set before Authenticate) + assert.Equal(t, "new-key", req.Header.Get("X-API-Key")) + }, + }, + { + name: "handles very long API keys", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": string(make([]byte, 10000)) + "very-long-key", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + expected := string(make([]byte, 10000)) + "very-long-key" + assert.Equal(t, expected, req.Header.Get("X-API-Key")) + }, + }, + { + name: "handles case-sensitive header names", + metadata: map[string]any{ + "header_name": "x-api-key", // lowercase + "api_key": "my-key", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + // HTTP headers are case-insensitive, but Go normalizes them + assert.Equal(t, "my-key", req.Header.Get("x-api-key")) + assert.Equal(t, "my-key", req.Header.Get("X-Api-Key")) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy := NewHeaderInjectionStrategy() + ctx := context.Background() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + + // Special setup for the "overwrites existing header value" test + if tt.name == "overwrites existing header value" { + req.Header.Set("X-API-Key", "old-key") + } + + err := strategy.Authenticate(ctx, req, tt.metadata) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + return + } + + require.NoError(t, err) + if tt.checkHeader != nil { + tt.checkHeader(t, req) + } + }) + } +} + +func TestHeaderInjectionStrategy_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata map[string]any + expectError bool + errorContains string + }{ + { + name: "valid metadata with all required fields", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "secret-key", + }, + expectError: false, + }, + { + name: "valid with extra metadata fields", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "secret-key", + "extra": "ignored", + "count": 123, + }, + expectError: false, + }, + { + name: "valid with different header name", + metadata: map[string]any{ + "header_name": "Authorization", + "api_key": "Bearer token", + }, + expectError: false, + }, + { + name: "returns error when header_name is missing", + metadata: map[string]any{ + "api_key": "secret-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is empty", + metadata: map[string]any{ + "header_name": "", + "api_key": "secret-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is not a string", + metadata: map[string]any{ + "header_name": 123, + "api_key": "secret-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is a boolean", + metadata: map[string]any{ + "header_name": true, + "api_key": "secret-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when api_key is missing", + metadata: map[string]any{ + "header_name": "X-API-Key", + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is empty", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "", + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is not a string", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": 123, + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is a map", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": map[string]any{"nested": "value"}, + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when metadata is nil", + metadata: nil, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when metadata is empty", + metadata: map[string]any{}, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when both fields are wrong type", + metadata: map[string]any{ + "header_name": 123, + "api_key": false, + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error for whitespace in header_name", + metadata: map[string]any{ + "header_name": "X-Custom Header", + "api_key": "key", + }, + expectError: true, + errorContains: "invalid header_name", + }, + { + name: "accepts unicode in api_key", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "key-with-unicode-日本語", + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy := NewHeaderInjectionStrategy() + err := strategy.Validate(tt.metadata) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/vmcp/auth/strategies/unauthenticated.go b/pkg/vmcp/auth/strategies/unauthenticated.go new file mode 100644 index 000000000..454495c52 --- /dev/null +++ b/pkg/vmcp/auth/strategies/unauthenticated.go @@ -0,0 +1,72 @@ +package strategies + +import ( + "context" + "net/http" +) + +// UnauthenticatedStrategy is a no-op authentication strategy that performs no authentication. +// This strategy is used when a backend MCP server requires no authentication. +// +// Unlike passing a nil authenticator (which is now an error), this strategy makes +// the intent explicit: "this backend intentionally has no authentication". +// +// The strategy performs no modifications to requests and validates all metadata. +// +// This is appropriate when: +// - The backend MCP server is on a trusted network (e.g., localhost) +// - The backend has no authentication requirements +// - Authentication is handled by network-level security (e.g., VPC, firewall) +// +// Security Warning: Only use this strategy when you are certain the backend +// requires no authentication. For production deployments, prefer explicit +// authentication strategies (pass_through, header_injection, token_exchange). +// +// Configuration: No metadata required, but any metadata is accepted and ignored. +// +// Example configuration: +// +// backends: +// local-backend: +// strategy: "unauthenticated" +type UnauthenticatedStrategy struct{} + +// NewUnauthenticatedStrategy creates a new UnauthenticatedStrategy instance. +func NewUnauthenticatedStrategy() *UnauthenticatedStrategy { + return &UnauthenticatedStrategy{} +} + +// Name returns the strategy identifier. +func (*UnauthenticatedStrategy) Name() string { + return "unauthenticated" +} + +// Authenticate performs no authentication and returns immediately. +// +// This method: +// 1. Does not modify the request in any way +// 2. Always returns nil (success) +// +// Parameters: +// - ctx: Request context (unused) +// - req: The HTTP request (not modified) +// - metadata: Strategy-specific configuration (ignored) +// +// Returns nil (always succeeds). +func (*UnauthenticatedStrategy) Authenticate(_ context.Context, _ *http.Request, _ map[string]any) error { + // No-op: intentionally does nothing + return nil +} + +// Validate checks if the strategy configuration is valid. +// +// UnauthenticatedStrategy accepts any metadata (including nil or empty), +// so this always returns nil. +// +// This permissive validation allows the strategy to be used without +// configuration or with arbitrary configuration that may be present +// for documentation purposes. +func (*UnauthenticatedStrategy) Validate(_ map[string]any) error { + // No-op: accepts any metadata + return nil +} diff --git a/pkg/vmcp/auth/strategies/unauthenticated_test.go b/pkg/vmcp/auth/strategies/unauthenticated_test.go new file mode 100644 index 000000000..43ee62bb2 --- /dev/null +++ b/pkg/vmcp/auth/strategies/unauthenticated_test.go @@ -0,0 +1,196 @@ +package strategies + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnauthenticatedStrategy_Name(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + assert.Equal(t, "unauthenticated", strategy.Name()) +} + +func TestUnauthenticatedStrategy_Authenticate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata map[string]any + setupRequest func() *http.Request + checkRequest func(t *testing.T, req *http.Request) + }{ + { + name: "does not modify request with no metadata", + metadata: nil, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + req.Header.Set("X-Custom-Header", "original-value") + return req + }, + checkRequest: func(t *testing.T, req *http.Request) { + t.Helper() + // Original headers should be unchanged + assert.Equal(t, "original-value", req.Header.Get("X-Custom-Header")) + // No auth headers should be added + assert.Empty(t, req.Header.Get("Authorization")) + }, + }, + { + name: "does not modify request with metadata present", + metadata: map[string]any{ + "some_key": "some_value", + "count": 42, + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + req.Header.Set("X-Existing", "existing-value") + return req + }, + checkRequest: func(t *testing.T, req *http.Request) { + t.Helper() + // Original headers should be unchanged + assert.Equal(t, "existing-value", req.Header.Get("X-Existing")) + // No auth headers should be added + assert.Empty(t, req.Header.Get("Authorization")) + }, + }, + { + name: "preserves existing Authorization header", + metadata: nil, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + req.Header.Set("Authorization", "Bearer existing-token") + return req + }, + checkRequest: func(t *testing.T, req *http.Request) { + t.Helper() + // Should not modify existing Authorization header + assert.Equal(t, "Bearer existing-token", req.Header.Get("Authorization")) + }, + }, + { + name: "works with empty request", + metadata: nil, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + }, + checkRequest: func(t *testing.T, req *http.Request) { + t.Helper() + // Request should have no auth headers + assert.Empty(t, req.Header.Get("Authorization")) + // Headers should be empty or minimal + assert.LessOrEqual(t, len(req.Header), 1) // May have Host header + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + req := tt.setupRequest() + ctx := context.Background() + + err := strategy.Authenticate(ctx, req, tt.metadata) + + require.NoError(t, err) + tt.checkRequest(t, req) + }) + } +} + +func TestUnauthenticatedStrategy_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata map[string]any + }{ + { + name: "accepts nil metadata", + metadata: nil, + }, + { + name: "accepts empty metadata", + metadata: map[string]any{}, + }, + { + name: "accepts arbitrary metadata", + metadata: map[string]any{ + "key1": "value1", + "key2": 42, + "key3": []string{"a", "b", "c"}, + "nested": map[string]any{"inner": "value"}, + }, + }, + { + name: "accepts metadata with typical auth fields", + metadata: map[string]any{ + "token_url": "https://example.com/token", + "client_id": "client-123", + "header_name": "X-API-Key", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + err := strategy.Validate(tt.metadata) + + require.NoError(t, err) + }) + } +} + +func TestUnauthenticatedStrategy_IntegrationBehavior(t *testing.T) { + t.Parallel() + + t.Run("strategy can be called multiple times safely", func(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + ctx := context.Background() + + // Call multiple times with different requests + for i := 0; i < 5; i++ { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + err := strategy.Authenticate(ctx, req, nil) + require.NoError(t, err) + assert.Empty(t, req.Header.Get("Authorization")) + } + }) + + t.Run("strategy is safe for concurrent use", func(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + ctx := context.Background() + + // Run authentication concurrently + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + err := strategy.Authenticate(ctx, req, nil) + assert.NoError(t, err) + done <- true + }() + } + + // Wait for all goroutines + for i := 0; i < 10; i++ { + <-done + } + }) +} diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index aaaf9cc59..cd83cd061 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -17,6 +17,7 @@ import ( "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/auth" ) const ( @@ -44,14 +45,30 @@ type httpBackendClient struct { // clientFactory creates MCP clients for backends. // Abstracted as a function to enable testing with mock clients. clientFactory func(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) + + // registry manages authentication strategies for outgoing requests to backend MCP servers. + // Must not be nil - use UnauthenticatedStrategy for no authentication. + registry auth.OutgoingAuthRegistry } // NewHTTPBackendClient creates a new HTTP-based backend client. // This client supports streamable-HTTP and SSE transports. -func NewHTTPBackendClient() vmcp.BackendClient { - return &httpBackendClient{ - clientFactory: defaultClientFactory, +// +// The registry parameter manages authentication strategies for outgoing requests to backend MCP servers. +// It must not be nil. To disable authentication, use a registry configured with the +// "unauthenticated" strategy. +// +// Returns an error if registry is nil. +func NewHTTPBackendClient(registry auth.OutgoingAuthRegistry) (vmcp.BackendClient, error) { + if registry == nil { + return nil, fmt.Errorf("registry cannot be nil; use UnauthenticatedStrategy for no authentication") + } + + c := &httpBackendClient{ + registry: registry, } + c.clientFactory = c.defaultClientFactory + return c, nil } // roundTripperFunc is a function adapter for http.RoundTripper. @@ -62,29 +79,103 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } +// authRoundTripper is an http.RoundTripper that adds authentication to backend requests. +// The authentication strategy and metadata are pre-resolved and validated at client creation time, +// eliminating per-request lookups and validation overhead. +type authRoundTripper struct { + base http.RoundTripper + authStrategy auth.Strategy + authMetadata map[string]any + target *vmcp.BackendTarget +} + +// RoundTrip implements http.RoundTripper by adding authentication headers to requests. +// The authentication strategy was pre-resolved and validated at client creation time, +// so this method simply applies the authentication without any lookups or validation. +func (a *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone request to avoid modifying the original + reqClone := req.Clone(req.Context()) + + // Apply pre-resolved authentication strategy + if err := a.authStrategy.Authenticate(reqClone.Context(), reqClone, a.authMetadata); err != nil { + return nil, fmt.Errorf("authentication failed for backend %s: %w", a.target.WorkloadID, err) + } + + logger.Debugf("Applied authentication strategy %q to backend %s", a.authStrategy.Name(), a.target.WorkloadID) + + return a.base.RoundTrip(reqClone) +} + +// resolveAuthStrategy resolves the authentication strategy for a backend target. +// It handles defaulting to "unauthenticated" when no strategy is specified. +// This method should be called once at client creation time to enable fail-fast +// behavior for invalid authentication configurations. +func (h *httpBackendClient) resolveAuthStrategy(target *vmcp.BackendTarget) (auth.Strategy, error) { + strategyName := target.AuthStrategy + + // Default to unauthenticated if not specified + if strategyName == "" { + strategyName = "unauthenticated" + } + + // Resolve strategy from registry + strategy, err := h.registry.GetStrategy(strategyName) + if err != nil { + return nil, fmt.Errorf("authentication strategy %q not found: %w", strategyName, err) + } + + return strategy, nil +} + // defaultClientFactory creates mark3labs MCP clients for different transport types. -func defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) { - // Create HTTP client with response size limits for DoS protection +func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) { + // Build transport chain: size limit → authentication → HTTP + var baseTransport = http.DefaultTransport + + // Resolve authentication strategy ONCE at client creation time + authStrategy, err := h.resolveAuthStrategy(target) + if err != nil { + return nil, fmt.Errorf("failed to resolve authentication for backend %s: %w", + target.WorkloadID, err) + } + + // Validate metadata ONCE at client creation time + if err := authStrategy.Validate(target.AuthMetadata); err != nil { + return nil, fmt.Errorf("invalid authentication configuration for backend %s: %w", + target.WorkloadID, err) + } + + // Add authentication layer with pre-resolved strategy + baseTransport = &authRoundTripper{ + base: baseTransport, + authStrategy: authStrategy, + authMetadata: target.AuthMetadata, + target: target, + } + + // Add size limit layer for DoS protection + sizeLimitedTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + resp, err := baseTransport.RoundTrip(req) + if err != nil { + return nil, err + } + // Wrap response body with size limit + resp.Body = struct { + io.Reader + io.Closer + }{ + Reader: io.LimitReader(resp.Body, maxResponseSize), + Closer: resp.Body, + } + return resp, nil + }) + + // Create HTTP client with configured transport chain httpClient := &http.Client{ - Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - resp, err := http.DefaultTransport.RoundTrip(req) - if err != nil { - return nil, err - } - // Wrap response body with size limit - resp.Body = struct { - io.Reader - io.Closer - }{ - Reader: io.LimitReader(resp.Body, maxResponseSize), - Closer: resp.Body, - } - return resp, nil - }), + Transport: sizeLimitedTransport, } var c *client.Client - var err error switch target.TransportType { case "streamable-http", "streamable": @@ -93,8 +184,6 @@ func defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*cli transport.WithHTTPTimeout(0), transport.WithContinuousListening(), transport.WithHTTPBasicClient(httpClient), - // TODO: Add authentication header injection via WithHTTPHeaderFunc - // This will be implemented when we add OutgoingAuthenticator support ) if err != nil { return nil, fmt.Errorf("failed to create streamable-http client: %w", err) diff --git a/pkg/vmcp/client/client_test.go b/pkg/vmcp/client/client_test.go index 2a7619cb0..4e1c38837 100644 --- a/pkg/vmcp/client/client_test.go +++ b/pkg/vmcp/client/client_test.go @@ -1,15 +1,23 @@ package client +//go:generate mockgen -destination=mocks/mock_outgoing_registry.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/auth OutgoingAuthRegistry + import ( "context" "errors" + "net/http" + "net/http/httptest" "testing" "github.com/mark3labs/mcp-go/client" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/auth" + authmocks "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks" + "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" ) func TestHTTPBackendClient_ListCapabilities_WithMockFactory(t *testing.T) { @@ -76,7 +84,16 @@ func TestDefaultClientFactory_UnsupportedTransport(t *testing.T) { TransportType: tc.transportType, } - _, err := defaultClientFactory(context.Background(), target) + // Create authenticator with unauthenticated strategy for testing + mockRegistry := auth.NewDefaultOutgoingAuthRegistry() + err := mockRegistry.RegisterStrategy("unauthenticated", &strategies.UnauthenticatedStrategy{}) + require.NoError(t, err) + + backendClient, err := NewHTTPBackendClient(mockRegistry) + require.NoError(t, err) + httpClient := backendClient.(*httpBackendClient) + + _, err = httpClient.defaultClientFactory(context.Background(), target) require.Error(t, err) assert.ErrorIs(t, err, vmcp.ErrUnsupportedTransport) @@ -189,3 +206,410 @@ func TestInitializeClient_ErrorHandling(t *testing.T) { assert.NotNil(t, initializeClient) }) } + +// mockRoundTripper is a test implementation of http.RoundTripper that captures requests +type mockRoundTripper struct { + capturedReq *http.Request + response *http.Response + err error +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + m.capturedReq = req + if m.err != nil { + return nil, m.err + } + return m.response, nil +} + +func TestAuthRoundTripper_RoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + target *vmcp.BackendTarget + setupStrategy func(*gomock.Controller) auth.Strategy + baseTransportResp *http.Response + baseTransportErr error + expectError bool + errorContains string + checkRequest func(t *testing.T, originalReq, capturedReq *http.Request) + checkBaseTransport func(t *testing.T, baseTransport *mockRoundTripper) + }{ + { + name: "successful authentication adds headers and forwards request", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "pass_through", + AuthMetadata: map[string]any{"key": "value"}, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("pass_through"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + map[string]any{"key": "value"}, + ). + DoAndReturn(func(_ context.Context, req *http.Request, _ map[string]any) error { + // Simulate adding auth header + req.Header.Set("Authorization", "Bearer test-token") + return nil + }) + return mockStrategy + }, + baseTransportResp: &http.Response{StatusCode: http.StatusOK}, + expectError: false, + checkRequest: func(t *testing.T, originalReq, capturedReq *http.Request) { + t.Helper() + // Original request should not be modified + assert.Empty(t, originalReq.Header.Get("Authorization")) + // Captured request should have auth header + assert.Equal(t, "Bearer test-token", capturedReq.Header.Get("Authorization")) + }, + checkBaseTransport: func(t *testing.T, baseTransport *mockRoundTripper) { + t.Helper() + // Base transport should have been called + assert.NotNil(t, baseTransport.capturedReq) + }, + }, + { + name: "unauthenticated strategy skips authentication", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "unauthenticated", + AuthMetadata: nil, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("unauthenticated"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + gomock.Nil(), + ). + DoAndReturn(func(_ context.Context, _ *http.Request, _ map[string]any) error { + // UnauthenticatedStrategy does nothing + return nil + }) + return mockStrategy + }, + baseTransportResp: &http.Response{StatusCode: http.StatusOK}, + expectError: false, + checkRequest: func(t *testing.T, originalReq, capturedReq *http.Request) { + t.Helper() + // Neither request should have auth headers + assert.Empty(t, originalReq.Header.Get("Authorization")) + assert.Empty(t, capturedReq.Header.Get("Authorization")) + }, + checkBaseTransport: func(t *testing.T, baseTransport *mockRoundTripper) { + t.Helper() + // Base transport should have been called + assert.NotNil(t, baseTransport.capturedReq) + }, + }, + { + name: "authentication failure returns error without calling base transport", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "pass_through", + AuthMetadata: map[string]any{"key": "value"}, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("pass_through"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + map[string]any{"key": "value"}, + ). + Return(errors.New("auth failed")) + return mockStrategy + }, + baseTransportResp: &http.Response{StatusCode: http.StatusOK}, + expectError: true, + errorContains: "authentication failed for backend backend-1", + checkBaseTransport: func(t *testing.T, baseTransport *mockRoundTripper) { + t.Helper() + // Base transport should NOT have been called + assert.Nil(t, baseTransport.capturedReq) + }, + }, + { + name: "base transport error propagates after successful auth", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "pass_through", + AuthMetadata: map[string]any{"key": "value"}, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("pass_through"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + map[string]any{"key": "value"}, + ). + Return(nil) + return mockStrategy + }, + baseTransportErr: errors.New("connection refused"), + expectError: true, + errorContains: "connection refused", + checkBaseTransport: func(t *testing.T, baseTransport *mockRoundTripper) { + t.Helper() + // Base transport should have been called + assert.NotNil(t, baseTransport.capturedReq) + }, + }, + { + name: "request immutability - original request unchanged", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "pass_through", + AuthMetadata: map[string]any{"key": "value"}, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("pass_through"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + map[string]any{"key": "value"}, + ). + DoAndReturn(func(_ context.Context, req *http.Request, _ map[string]any) error { + // Modify the cloned request + req.Header.Set("Authorization", "Bearer modified-token") + req.Header.Set("X-Custom-Header", "custom-value") + return nil + }) + return mockStrategy + }, + baseTransportResp: &http.Response{StatusCode: http.StatusOK}, + expectError: false, + checkRequest: func(t *testing.T, originalReq, capturedReq *http.Request) { + t.Helper() + // Original request should be completely unmodified + assert.Empty(t, originalReq.Header.Get("Authorization")) + assert.Empty(t, originalReq.Header.Get("X-Custom-Header")) + + // Captured (cloned) request should have modifications + assert.Equal(t, "Bearer modified-token", capturedReq.Header.Get("Authorization")) + assert.Equal(t, "custom-value", capturedReq.Header.Get("X-Custom-Header")) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + // Setup mock strategy + var mockStrategy auth.Strategy + if tt.setupStrategy != nil { + mockStrategy = tt.setupStrategy(ctrl) + } + + // Setup mock base transport + baseTransport := &mockRoundTripper{ + response: tt.baseTransportResp, + err: tt.baseTransportErr, + } + + // Create authRoundTripper with pre-resolved strategy + authRT := &authRoundTripper{ + base: baseTransport, + authStrategy: mockStrategy, + authMetadata: tt.target.AuthMetadata, + target: tt.target, + } + + // Create test request + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + ctx := context.Background() + req = req.WithContext(ctx) + + // Execute RoundTrip + resp, err := authRT.RoundTrip(req) + + // Check error expectations + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + assert.NotNil(t, resp) + } + + // Check request modifications if specified + if tt.checkRequest != nil { + tt.checkRequest(t, req, baseTransport.capturedReq) + } + + // Check base transport calls if specified + if tt.checkBaseTransport != nil { + tt.checkBaseTransport(t, baseTransport) + } + }) + } +} + +func TestNewHTTPBackendClient_NilRegistry(t *testing.T) { + t.Parallel() + + t.Run("returns error when registry is nil", func(t *testing.T) { + t.Parallel() + + client, err := NewHTTPBackendClient(nil) + + require.Error(t, err) + assert.Nil(t, client) + assert.Contains(t, err.Error(), "registry cannot be nil") + assert.Contains(t, err.Error(), "UnauthenticatedStrategy") + }) + + t.Run("succeeds with valid registry", func(t *testing.T) { + t.Parallel() + + mockRegistry := auth.NewDefaultOutgoingAuthRegistry() + client, err := NewHTTPBackendClient(mockRegistry) + + require.NoError(t, err) + assert.NotNil(t, client) + }) +} + +func TestResolveAuthStrategy(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + target *vmcp.BackendTarget + setupRegistry func() auth.OutgoingAuthRegistry + expectError bool + errorContains string + checkStrategy func(t *testing.T, strategy auth.Strategy) + }{ + { + name: "defaults to unauthenticated when strategy is empty", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "", + AuthMetadata: nil, + }, + setupRegistry: func() auth.OutgoingAuthRegistry { + registry := auth.NewDefaultOutgoingAuthRegistry() + err := registry.RegisterStrategy("unauthenticated", &strategies.UnauthenticatedStrategy{}) + require.NoError(t, err) + return registry + }, + expectError: false, + checkStrategy: func(t *testing.T, strategy auth.Strategy) { + t.Helper() + assert.Equal(t, "unauthenticated", strategy.Name()) + }, + }, + { + name: "resolves explicitly configured strategy", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "header_injection", + AuthMetadata: map[string]any{"header_name": "X-API-Key", "api_key": "test-key"}, + }, + setupRegistry: func() auth.OutgoingAuthRegistry { + registry := auth.NewDefaultOutgoingAuthRegistry() + err := registry.RegisterStrategy("header_injection", strategies.NewHeaderInjectionStrategy()) + require.NoError(t, err) + return registry + }, + expectError: false, + checkStrategy: func(t *testing.T, strategy auth.Strategy) { + t.Helper() + assert.Equal(t, "header_injection", strategy.Name()) + }, + }, + { + name: "returns error for unknown strategy", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "unknown_strategy", + AuthMetadata: nil, + }, + setupRegistry: func() auth.OutgoingAuthRegistry { + registry := auth.NewDefaultOutgoingAuthRegistry() + err := registry.RegisterStrategy("unauthenticated", &strategies.UnauthenticatedStrategy{}) + require.NoError(t, err) + return registry + }, + expectError: true, + errorContains: "authentication strategy \"unknown_strategy\" not found", + }, + { + name: "returns error when unauthenticated strategy not registered", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "", // Empty strategy defaults to unauthenticated + AuthMetadata: nil, + }, + setupRegistry: func() auth.OutgoingAuthRegistry { + // Don't register unauthenticated strategy + return auth.NewDefaultOutgoingAuthRegistry() + }, + expectError: true, + errorContains: "authentication strategy \"unauthenticated\" not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + registry := tt.setupRegistry() + backendClient, err := NewHTTPBackendClient(registry) + require.NoError(t, err) + + httpClient := backendClient.(*httpBackendClient) + + // Call resolveAuthStrategy + strategy, err := httpClient.resolveAuthStrategy(tt.target) + + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, strategy) + } else { + require.NoError(t, err) + assert.NotNil(t, strategy) + if tt.checkStrategy != nil { + tt.checkStrategy(t, strategy) + } + } + }) + } +} diff --git a/pkg/vmcp/client/mocks/mock_outgoing_registry.go b/pkg/vmcp/client/mocks/mock_outgoing_registry.go new file mode 100644 index 000000000..e18e65e05 --- /dev/null +++ b/pkg/vmcp/client/mocks/mock_outgoing_registry.go @@ -0,0 +1,70 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/stacklok/toolhive/pkg/vmcp/auth (interfaces: OutgoingAuthRegistry) +// +// Generated by this command: +// +// mockgen -destination=mocks/mock_outgoing_registry.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/auth OutgoingAuthRegistry +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + auth "github.com/stacklok/toolhive/pkg/vmcp/auth" + gomock "go.uber.org/mock/gomock" +) + +// MockOutgoingAuthRegistry is a mock of OutgoingAuthRegistry interface. +type MockOutgoingAuthRegistry struct { + ctrl *gomock.Controller + recorder *MockOutgoingAuthRegistryMockRecorder + isgomock struct{} +} + +// MockOutgoingAuthRegistryMockRecorder is the mock recorder for MockOutgoingAuthRegistry. +type MockOutgoingAuthRegistryMockRecorder struct { + mock *MockOutgoingAuthRegistry +} + +// NewMockOutgoingAuthRegistry creates a new mock instance. +func NewMockOutgoingAuthRegistry(ctrl *gomock.Controller) *MockOutgoingAuthRegistry { + mock := &MockOutgoingAuthRegistry{ctrl: ctrl} + mock.recorder = &MockOutgoingAuthRegistryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOutgoingAuthRegistry) EXPECT() *MockOutgoingAuthRegistryMockRecorder { + return m.recorder +} + +// GetStrategy mocks base method. +func (m *MockOutgoingAuthRegistry) GetStrategy(name string) (auth.Strategy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStrategy", name) + ret0, _ := ret[0].(auth.Strategy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetStrategy indicates an expected call of GetStrategy. +func (mr *MockOutgoingAuthRegistryMockRecorder) GetStrategy(name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategy", reflect.TypeOf((*MockOutgoingAuthRegistry)(nil).GetStrategy), name) +} + +// RegisterStrategy mocks base method. +func (m *MockOutgoingAuthRegistry) RegisterStrategy(name string, strategy auth.Strategy) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterStrategy", name, strategy) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterStrategy indicates an expected call of RegisterStrategy. +func (mr *MockOutgoingAuthRegistryMockRecorder) RegisterStrategy(name, strategy any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterStrategy", reflect.TypeOf((*MockOutgoingAuthRegistry)(nil).RegisterStrategy), name, strategy) +} diff --git a/pkg/vmcp/doc.go b/pkg/vmcp/doc.go index 246b03d2c..f81f8561b 100644 --- a/pkg/vmcp/doc.go +++ b/pkg/vmcp/doc.go @@ -83,12 +83,11 @@ // Middleware() func(http.Handler) http.Handler // } // -// OutgoingAuthenticator (pkg/vmcp/auth): +// OutgoingAuthRegistry (pkg/vmcp/auth): // -// type OutgoingAuthenticator interface { -// AuthenticateRequest(ctx context.Context, req *http.Request, strategy string, metadata map[string]any) error -// GetStrategy(name string) (AuthStrategy, error) -// RegisterStrategy(name string, strategy AuthStrategy) error +// type OutgoingAuthRegistry interface { +// GetStrategy(name string) (Strategy, error) +// RegisterStrategy(name string, strategy Strategy) error // } // // # Design Principles @@ -137,9 +136,10 @@ // // Route to backend // target, err := rtr.RouteTool(ctx, toolName) // -// // Authenticate to backend +// // Authenticate to backend (resolve strategy and call it) // backendReq := createBackendRequest(...) -// err = outAuth.AuthenticateRequest(ctx, backendReq, target.AuthStrategy, target.AuthMetadata) +// strategy, err := outAuth.GetStrategy(target.AuthStrategy) +// err = strategy.Authenticate(ctx, backendReq, target.AuthMetadata) // // // Forward request and return response // // ... diff --git a/pkg/vmcp/types.go b/pkg/vmcp/types.go index 118e2082a..f46383fdf 100644 --- a/pkg/vmcp/types.go +++ b/pkg/vmcp/types.go @@ -24,7 +24,7 @@ type BackendTarget struct { TransportType string // AuthStrategy identifies the authentication strategy for this backend. - // The actual authentication is handled by OutgoingAuthenticator interface. + // The actual authentication is handled by OutgoingAuthRegistry interface. // Examples: "pass_through", "token_exchange", "client_credentials", "oauth_proxy" AuthStrategy string