From 58dbb2336abea56849ed84502a8aaf080e9c364b Mon Sep 17 00:00:00 2001 From: Marc Nuri Date: Fri, 7 Nov 2025 10:44:17 +0100 Subject: [PATCH] test(authorization): improved tests to use real MCP clients Signed-off-by: Marc Nuri --- internal/test/mcp.go | 12 +- internal/test/mock_server.go | 30 ++ internal/test/test.go | 30 ++ pkg/http/http_authorization_test.go | 472 ++++++++++++++++++ pkg/http/http_mcp_test.go | 67 +++ pkg/http/http_test.go | 542 ++------------------- pkg/kubernetes/provider_kubeconfig_test.go | 22 +- pkg/kubernetes/provider_single_test.go | 23 +- 8 files changed, 660 insertions(+), 538 deletions(-) create mode 100644 pkg/http/http_authorization_test.go create mode 100644 pkg/http/http_mcp_test.go diff --git a/internal/test/mcp.go b/internal/test/mcp.go index 0b411b1d..4ddbe70b 100644 --- a/internal/test/mcp.go +++ b/internal/test/mcp.go @@ -12,6 +12,13 @@ import ( "golang.org/x/net/context" ) +func McpInitRequest() mcp.InitializeRequest { + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{Name: "test", Version: "1.33.7"} + return initRequest +} + type McpClient struct { ctx context.Context testServer *httptest.Server @@ -28,10 +35,7 @@ func NewMcpClient(t *testing.T, mcpHttpServer http.Handler, options ...transport require.NoError(t, err, "Expected no error creating MCP client") err = ret.Start(t.Context()) require.NoError(t, err, "Expected no error starting MCP client") - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{Name: "test", Version: "1.33.7"} - _, err = ret.Initialize(t.Context(), initRequest) + _, err = ret.Initialize(t.Context(), McpInitRequest()) require.NoError(t, err, "Expected no error initializing MCP client") return ret } diff --git a/internal/test/mock_server.go b/internal/test/mock_server.go index 58740ad6..36324a5e 100644 --- a/internal/test/mock_server.go +++ b/internal/test/mock_server.go @@ -216,3 +216,33 @@ func (h *InOpenShiftHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) return } } + +const tokenReviewSuccessful = ` + { + "kind": "TokenReview", + "apiVersion": "authentication.k8s.io/v1", + "spec": {"token": "valid-token"}, + "status": { + "authenticated": true, + "user": { + "username": "test-user", + "groups": ["system:authenticated"] + }, + "audiences": ["the-audience"] + } + }` + +type TokenReviewHandler struct { + TokenReviewed bool +} + +var _ http.Handler = (*TokenReviewHandler)(nil) + +func (h *TokenReviewHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(tokenReviewSuccessful)) + h.TokenReviewed = true + return + } +} diff --git a/internal/test/test.go b/internal/test/test.go index 03491422..c2ccec4e 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -1,9 +1,12 @@ package test import ( + "fmt" + "net" "os" "path/filepath" "runtime" + "time" ) func Must[T any](v T, err error) T { @@ -19,3 +22,30 @@ func ReadFile(path ...string) string { fileBytes := Must(os.ReadFile(filePath)) return string(fileBytes) } + +func RandomPortAddress() (*net.TCPAddr, error) { + ln, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + return nil, fmt.Errorf("failed to find random port for HTTP server: %v", err) + } + defer func() { _ = ln.Close() }() + tcpAddr, ok := ln.Addr().(*net.TCPAddr) + if !ok { + return nil, fmt.Errorf("failed to cast listener address to TCPAddr") + } + return tcpAddr, nil +} + +func WaitForServer(tcpAddr *net.TCPAddr) error { + var conn *net.TCPConn + var err error + for i := 0; i < 10; i++ { + conn, err = net.DialTCP("tcp", nil, tcpAddr) + if err == nil { + _ = conn.Close() + break + } + time.Sleep(50 * time.Millisecond) + } + return err +} diff --git a/pkg/http/http_authorization_test.go b/pkg/http/http_authorization_test.go new file mode 100644 index 00000000..a8995c45 --- /dev/null +++ b/pkg/http/http_authorization_test.go @@ -0,0 +1,472 @@ +package http + +import ( + "bytes" + "flag" + "fmt" + "net/http" + "strconv" + "strings" + "testing" + "time" + + "github.com/containers/kubernetes-mcp-server/internal/test" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/coreos/go-oidc/v3/oidc/oidctest" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/stretchr/testify/suite" + "k8s.io/klog/v2" + "k8s.io/klog/v2/textlogger" +) + +type AuthorizationSuite struct { + BaseHttpSuite + mcpClient *client.Client + klogState klog.State + logBuffer bytes.Buffer +} + +func (s *AuthorizationSuite) SetupTest() { + s.BaseHttpSuite.SetupTest() + + // Capture logs + s.klogState = klog.CaptureState() + flags := flag.NewFlagSet("test", flag.ContinueOnError) + klog.InitFlags(flags) + _ = flags.Set("v", "5") + klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(5), textlogger.Output(&s.logBuffer)))) + + // Default Auth settings (overridden in tests as needed) + s.OidcProvider = nil + s.StaticConfig.RequireOAuth = true + s.StaticConfig.ValidateToken = true + s.StaticConfig.OAuthAudience = "" + s.StaticConfig.StsClientId = "" + s.StaticConfig.StsClientSecret = "" + s.StaticConfig.StsAudience = "" + s.StaticConfig.StsScopes = []string{} +} + +func (s *AuthorizationSuite) TearDownTest() { + s.BaseHttpSuite.TearDownTest() + s.klogState.Restore() + + if s.mcpClient != nil { + _ = s.mcpClient.Close() + } +} + +func (s *AuthorizationSuite) StartClient(options ...transport.StreamableHTTPCOption) { + var err error + s.mcpClient, err = client.NewStreamableHttpClient(fmt.Sprintf("http://127.0.0.1:%d/mcp", s.TcpAddr.Port), options...) + s.Require().NoError(err, "Expected no error creating Streamable HTTP MCP client") + err = s.mcpClient.Start(s.T().Context()) + s.Require().NoError(err, "Expected no error starting Streamable HTTP MCP client") +} + +func (s *AuthorizationSuite) HttpGet(authHeader string) *http.Response { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/mcp", s.TcpAddr.Port), nil) + s.Require().NoError(err, "Failed to create request") + if authHeader != "" { + req.Header.Set("Authorization", authHeader) + } + resp, err := http.DefaultClient.Do(req) + s.Require().NoError(err, "Failed to get protected endpoint") + return resp +} + +func (s *AuthorizationSuite) TestAuthorizationUnauthorizedMissingHeader() { + // Missing Authorization header + s.StartServer() + s.StartClient() + + s.Run("Initialize returns error for MISSING Authorization header", func() { + _, err := s.mcpClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().Error(err, "Expected error creating initial request") + s.ErrorContains(err, "transport error: request failed with status 401: Unauthorized: Bearer token required") + }) + + s.Run("Protected resource with MISSING Authorization header", func() { + resp := s.HttpGet("") + s.T().Cleanup(func() { _ = resp.Body.Close }) + + s.Run("returns 401 - Unauthorized status", func() { + s.Equal(401, resp.StatusCode, "Expected HTTP 401 for MISSING Authorization header") + }) + s.Run("returns WWW-Authenticate header", func() { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", error="missing_token"` + s.Equal(expected, authHeader, "Expected WWW-Authenticate header to match") + }) + s.Run("logs error", func() { + s.Contains(s.logBuffer.String(), "Authentication failed - missing or invalid bearer token", "Expected log entry for missing or invalid bearer token") + }) + }) +} + +func (s *AuthorizationSuite) TestAuthorizationUnauthorizedHeaderIncompatible() { + // Authorization header without Bearer prefix + s.StartServer() + s.StartClient(transport.WithHTTPHeaders(map[string]string{ + "Authorization": "Basic YWxhZGRpbjpvcGVuc2VzYW1l", + })) + + s.Run("Initialize returns error for INCOMPATIBLE Authorization header", func() { + _, err := s.mcpClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().Error(err, "Expected error creating initial request") + s.ErrorContains(err, "transport error: request failed with status 401: Unauthorized: Bearer token required") + }) + + s.Run("Protected resource with INCOMPATIBLE Authorization header", func() { + resp := s.HttpGet("Basic YWxhZGRpbjpvcGVuc2VzYW1l") + s.T().Cleanup(func() { _ = resp.Body.Close }) + + s.Run("returns 401 - Unauthorized status", func() { + s.Equal(401, resp.StatusCode, "Expected HTTP 401 for INCOMPATIBLE Authorization header") + }) + s.Run("returns WWW-Authenticate header", func() { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", error="missing_token"` + s.Equal(expected, authHeader, "Expected WWW-Authenticate header to match") + }) + s.Run("logs error", func() { + s.Contains(s.logBuffer.String(), "Authentication failed - missing or invalid bearer token", "Expected log entry for missing or invalid bearer token") + }) + }) +} + +func (s *AuthorizationSuite) TestAuthorizationUnauthorizedHeaderInvalid() { + // Invalid Authorization header + s.StartServer() + s.StartClient(transport.WithHTTPHeaders(map[string]string{ + "Authorization": "Bearer " + strings.ReplaceAll(tokenBasicNotExpired, ".", ".invalid"), + })) + + s.Run("Initialize returns error for INVALID Authorization header", func() { + _, err := s.mcpClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().Error(err, "Expected error creating initial request") + s.ErrorContains(err, "transport error: request failed with status 401: Unauthorized: Invalid token") + }) + + s.Run("Protected resource with INVALID Authorization header", func() { + resp := s.HttpGet("Bearer " + strings.ReplaceAll(tokenBasicNotExpired, ".", ".invalid")) + s.T().Cleanup(func() { _ = resp.Body.Close }) + + s.Run("returns 401 - Unauthorized status", func() { + s.Equal(401, resp.StatusCode, "Expected HTTP 401 for INVALID Authorization header") + }) + s.Run("returns WWW-Authenticate header", func() { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", error="invalid_token"` + s.Equal(expected, authHeader, "Expected WWW-Authenticate header to match") + }) + s.Run("logs error", func() { + s.Contains(s.logBuffer.String(), "Authentication failed - JWT validation error", "Expected log entry for JWT validation error") + s.Contains(s.logBuffer.String(), "error: failed to parse JWT token: illegal base64 data", "Expected log entry for JWT validation error details") + }) + }) +} + +func (s *AuthorizationSuite) TestAuthorizationUnauthorizedHeaderExpired() { + // Expired Authorization Bearer token + s.StartServer() + s.StartClient(transport.WithHTTPHeaders(map[string]string{ + "Authorization": "Bearer " + tokenBasicExpired, + })) + + s.Run("Initialize returns error for EXPIRED Authorization header", func() { + _, err := s.mcpClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().Error(err, "Expected error creating initial request") + s.ErrorContains(err, "transport error: request failed with status 401: Unauthorized: Invalid token") + }) + + s.Run("Protected resource with EXPIRED Authorization header", func() { + resp := s.HttpGet("Bearer " + tokenBasicExpired) + s.T().Cleanup(func() { _ = resp.Body.Close }) + + s.Run("returns 401 - Unauthorized status", func() { + s.Equal(401, resp.StatusCode, "Expected HTTP 401 for EXPIRED Authorization header") + }) + s.Run("returns WWW-Authenticate header", func() { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", error="invalid_token"` + s.Equal(expected, authHeader, "Expected WWW-Authenticate header to match") + }) + s.Run("logs error", func() { + s.Contains(s.logBuffer.String(), "Authentication failed - JWT validation error", "Expected log entry for JWT validation error") + s.Contains(s.logBuffer.String(), "validation failed, token is expired (exp)", "Expected log entry for JWT validation error details") + }) + }) +} + +func (s *AuthorizationSuite) TestAuthorizationUnauthorizedHeaderInvalidAudience() { + // Invalid audience claim Bearer token + s.StaticConfig.OAuthAudience = "expected-audience" + s.StartServer() + s.StartClient(transport.WithHTTPHeaders(map[string]string{ + "Authorization": "Bearer " + tokenBasicNotExpired, + })) + + s.Run("Initialize returns error for INVALID AUDIENCE Authorization header", func() { + _, err := s.mcpClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().Error(err, "Expected error creating initial request") + s.ErrorContains(err, "transport error: request failed with status 401: Unauthorized: Invalid token") + }) + + s.Run("Protected resource with INVALID AUDIENCE Authorization header", func() { + resp := s.HttpGet("Bearer " + tokenBasicNotExpired) + s.T().Cleanup(func() { _ = resp.Body.Close }) + + s.Run("returns 401 - Unauthorized status", func() { + s.Equal(401, resp.StatusCode, "Expected HTTP 401 for INVALID AUDIENCE Authorization header") + }) + s.Run("returns WWW-Authenticate header", func() { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", audience="expected-audience", error="invalid_token"` + s.Equal(expected, authHeader, "Expected WWW-Authenticate header to match") + }) + s.Run("logs error", func() { + s.Contains(s.logBuffer.String(), "Authentication failed - JWT validation error", "Expected log entry for JWT validation error") + s.Contains(s.logBuffer.String(), "invalid audience claim (aud)", "Expected log entry for JWT validation error details") + }) + }) +} + +func (s *AuthorizationSuite) TestAuthorizationUnauthorizedOidcValidation() { + // Failed OIDC validation + s.StaticConfig.OAuthAudience = "mcp-server" + oidcTestServer := NewOidcTestServer(s.T()) + s.T().Cleanup(oidcTestServer.Close) + s.OidcProvider = oidcTestServer.Provider + s.StartServer() + s.StartClient(transport.WithHTTPHeaders(map[string]string{ + "Authorization": "Bearer " + tokenBasicNotExpired, + })) + + s.Run("Initialize returns error for INVALID OIDC Authorization header", func() { + _, err := s.mcpClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().Error(err, "Expected error creating initial request") + s.ErrorContains(err, "transport error: request failed with status 401: Unauthorized: Invalid token") + }) + + s.Run("Protected resource with INVALID OIDC Authorization header", func() { + resp := s.HttpGet("Bearer " + tokenBasicNotExpired) + s.T().Cleanup(func() { _ = resp.Body.Close }) + + s.Run("returns 401 - Unauthorized status", func() { + s.Equal(401, resp.StatusCode, "Expected HTTP 401 for INVALID OIDC Authorization header") + }) + s.Run("returns WWW-Authenticate header", func() { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", audience="mcp-server", error="invalid_token"` + s.Equal(expected, authHeader, "Expected WWW-Authenticate header to match") + }) + s.Run("logs error", func() { + s.Contains(s.logBuffer.String(), "Authentication failed - JWT validation error", "Expected log entry for JWT validation error") + s.Contains(s.logBuffer.String(), "OIDC token validation error: failed to verify signature", "Expected log entry for OIDC validation error details") + }) + }) +} + +func (s *AuthorizationSuite) TestAuthorizationUnauthorizedKubernetesValidation() { + // Failed Kubernetes TokenReview + s.StaticConfig.OAuthAudience = "mcp-server" + oidcTestServer := NewOidcTestServer(s.T()) + s.T().Cleanup(oidcTestServer.Close) + rawClaims := `{ + "iss": "` + oidcTestServer.URL + `", + "exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `, + "aud": "mcp-server" + }` + validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims) + s.OidcProvider = oidcTestServer.Provider + s.StartServer() + s.StartClient(transport.WithHTTPHeaders(map[string]string{ + "Authorization": "Bearer " + validOidcToken, + })) + + s.Run("Initialize returns error for INVALID KUBERNETES Authorization header", func() { + _, err := s.mcpClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().Error(err, "Expected error creating initial request") + s.ErrorContains(err, "transport error: request failed with status 401: Unauthorized: Invalid token") + }) + + s.Run("Protected resource with INVALID KUBERNETES Authorization header", func() { + resp := s.HttpGet("Bearer " + validOidcToken) + s.T().Cleanup(func() { _ = resp.Body.Close }) + + s.Run("returns 401 - Unauthorized status", func() { + s.Equal(401, resp.StatusCode, "Expected HTTP 401 for INVALID KUBERNETES Authorization header") + }) + s.Run("returns WWW-Authenticate header", func() { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", audience="mcp-server", error="invalid_token"` + s.Equal(expected, authHeader, "Expected WWW-Authenticate header to match") + }) + s.Run("logs error", func() { + s.Contains(s.logBuffer.String(), "Authentication failed - JWT validation error", "Expected log entry for JWT validation error") + s.Contains(s.logBuffer.String(), "kubernetes API token validation error: failed to create token review", "Expected log entry for Kubernetes TokenReview error details") + }) + }) +} + +func (s *AuthorizationSuite) TestAuthorizationRequireOAuthFalse() { + s.StaticConfig.RequireOAuth = false + s.StartServer() + s.StartClient() + + s.Run("Initialize returns OK for MISSING Authorization header", func() { + result, err := s.mcpClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().NoError(err, "Expected no error creating initial request") + s.Require().NotNil(result, "Expected initial request to not be nil") + }) +} + +func (s *AuthorizationSuite) TestAuthorizationRawToken() { + tokenReviewHandler := &test.TokenReviewHandler{} + s.MockServer.Handle(tokenReviewHandler) + + cases := []struct { + audience string + validateToken bool + }{ + {"", false}, // No audience, no validation + {"", true}, // No audience, validation enabled + {"mcp-server", false}, // Audience set, no validation + {"mcp-server", true}, // Audience set, validation enabled + } + for _, c := range cases { + s.StaticConfig.OAuthAudience = c.audience + s.StaticConfig.ValidateToken = c.validateToken + s.StartServer() + s.StartClient(transport.WithHTTPHeaders(map[string]string{ + "Authorization": "Bearer " + tokenBasicNotExpired, + })) + tokenReviewHandler.TokenReviewed = false + + s.Run(fmt.Sprintf("Protected resource with audience = '%s' and validate-token = '%t'", c.audience, c.validateToken), func() { + s.Run("Initialize returns OK for VALID Authorization header", func() { + result, err := s.mcpClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().NoError(err, "Expected no error creating initial request") + s.Require().NotNil(result, "Expected initial request to not be nil") + }) + + s.Run("Performs token validation accordingly", func() { + if tokenReviewHandler.TokenReviewed == true && !c.validateToken { + s.Fail("Expected token review to be skipped when validate-token is false, but it was performed") + } + if tokenReviewHandler.TokenReviewed == false && c.validateToken { + s.Fail("Expected token review to be performed when validate-token is true, but it was skipped") + } + }) + }) + _ = s.mcpClient.Close() + s.StopServer() + } +} + +func (s *AuthorizationSuite) TestAuthorizationOidcToken() { + tokenReviewHandler := &test.TokenReviewHandler{} + s.MockServer.Handle(tokenReviewHandler) + + oidcTestServer := NewOidcTestServer(s.T()) + s.T().Cleanup(oidcTestServer.Close) + rawClaims := `{ + "iss": "` + oidcTestServer.URL + `", + "exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `, + "aud": "mcp-server" + }` + validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims) + + cases := []bool{false, true} + for _, validateToken := range cases { + s.OidcProvider = oidcTestServer.Provider + s.StaticConfig.OAuthAudience = "mcp-server" + s.StaticConfig.ValidateToken = validateToken + s.StartServer() + s.StartClient(transport.WithHTTPHeaders(map[string]string{ + "Authorization": "Bearer " + validOidcToken, + })) + tokenReviewHandler.TokenReviewed = false + + s.Run(fmt.Sprintf("Protected resource with validate-token = '%t'", validateToken), func() { + s.Run("Initialize returns OK for VALID OIDC Authorization header", func() { + result, err := s.mcpClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().NoError(err, "Expected no error creating initial request") + s.Require().NotNil(result, "Expected initial request to not be nil") + }) + + s.Run("Performs token validation accordingly for VALID OIDC Authorization header", func() { + if tokenReviewHandler.TokenReviewed == true && !validateToken { + s.Fail("Expected token review to be skipped when validate-token is false, but it was performed") + } + if tokenReviewHandler.TokenReviewed == false && validateToken { + s.Fail("Expected token review to be performed when validate-token is true, but it was skipped") + } + }) + }) + _ = s.mcpClient.Close() + s.StopServer() + } +} + +func (s *AuthorizationSuite) TestAuthorizationOidcTokenExchange() { + tokenReviewHandler := &test.TokenReviewHandler{} + s.MockServer.Handle(tokenReviewHandler) + + oidcTestServer := NewOidcTestServer(s.T()) + s.T().Cleanup(oidcTestServer.Close) + rawClaims := `{ + "iss": "` + oidcTestServer.URL + `", + "exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `, + "aud": "%s" + }` + validOidcClientToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, + fmt.Sprintf(rawClaims, "mcp-server")) + validOidcBackendToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, + fmt.Sprintf(rawClaims, "backend-audience")) + oidcTestServer.TokenEndpointHandler = func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprintf(w, `{"access_token":"%s","token_type":"Bearer","expires_in":253402297199}`, validOidcBackendToken) + } + + cases := []bool{false, true} + for _, validateToken := range cases { + s.OidcProvider = oidcTestServer.Provider + s.StaticConfig.OAuthAudience = "mcp-server" + s.StaticConfig.ValidateToken = validateToken + s.StaticConfig.StsClientId = "test-sts-client-id" + s.StaticConfig.StsClientSecret = "test-sts-client-secret" + s.StaticConfig.StsAudience = "backend-audience" + s.StaticConfig.StsScopes = []string{"backend-scope"} + s.StartServer() + s.StartClient(transport.WithHTTPHeaders(map[string]string{ + "Authorization": "Bearer " + validOidcClientToken, + })) + tokenReviewHandler.TokenReviewed = false + + s.Run(fmt.Sprintf("Protected resource with validate-token='%t'", validateToken), func() { + s.Run("Initialize returns OK for VALID OIDC EXCHANGE Authorization header", func() { + result, err := s.mcpClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().NoError(err, "Expected no error creating initial request") + s.Require().NotNil(result, "Expected initial request to not be nil") + }) + + s.Run("Performs token validation accordingly for VALID OIDC EXCHANGE Authorization header", func() { + if tokenReviewHandler.TokenReviewed == true && !validateToken { + s.Fail("Expected token review to be skipped when validate-token is false, but it was performed") + } + if tokenReviewHandler.TokenReviewed == false && validateToken { + s.Fail("Expected token review to be performed when validate-token is true, but it was skipped") + } + }) + }) + _ = s.mcpClient.Close() + s.StopServer() + } +} + +func TestAuthorization(t *testing.T) { + suite.Run(t, new(AuthorizationSuite)) +} diff --git a/pkg/http/http_mcp_test.go b/pkg/http/http_mcp_test.go new file mode 100644 index 00000000..2a79b4be --- /dev/null +++ b/pkg/http/http_mcp_test.go @@ -0,0 +1,67 @@ +package http + +import ( + "fmt" + "testing" + + "github.com/containers/kubernetes-mcp-server/internal/test" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/suite" +) + +type McpTransportSuite struct { + BaseHttpSuite +} + +func (s *McpTransportSuite) SetupTest() { + s.BaseHttpSuite.SetupTest() + s.StartServer() +} + +func (s *McpTransportSuite) TearDownTest() { + s.BaseHttpSuite.TearDownTest() +} + +func (s *McpTransportSuite) TestSseTransport() { + sseClient, sseClientErr := client.NewSSEMCPClient(fmt.Sprintf("http://127.0.0.1:%d/sse", s.TcpAddr.Port)) + s.Require().NoError(sseClientErr, "Expected no error creating SSE MCP client") + startErr := sseClient.Start(s.T().Context()) + s.Require().NoError(startErr, "Expected no error starting SSE MCP client") + s.Run("Can Initialize Session", func() { + _, err := sseClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().NoError(err, "Expected no error initializing SSE MCP client") + }) + s.Run("Can List Tools", func() { + tools, err := sseClient.ListTools(s.T().Context(), mcp.ListToolsRequest{}) + s.Require().NoError(err, "Expected no error listing tools from SSE MCP client") + s.Greater(len(tools.Tools), 0, "Expected at least one tool from SSE MCP client") + }) + s.Run("Can close SSE client", func() { + s.Require().NoError(sseClient.Close(), "Expected no error closing SSE MCP client") + }) +} + +func (s *McpTransportSuite) TestStreamableHttpTransport() { + httpClient, httpClientErr := client.NewStreamableHttpClient(fmt.Sprintf("http://127.0.0.1:%d/mcp", s.TcpAddr.Port), transport.WithContinuousListening()) + s.Require().NoError(httpClientErr, "Expected no error creating Streamable HTTP MCP client") + startErr := httpClient.Start(s.T().Context()) + s.Require().NoError(startErr, "Expected no error starting Streamable HTTP MCP client") + s.Run("Can Initialize Session", func() { + _, err := httpClient.Initialize(s.T().Context(), test.McpInitRequest()) + s.Require().NoError(err, "Expected no error initializing Streamable HTTP MCP client") + }) + s.Run("Can List Tools", func() { + tools, err := httpClient.ListTools(s.T().Context(), mcp.ListToolsRequest{}) + s.Require().NoError(err, "Expected no error listing tools from Streamable HTTP MCP client") + s.Greater(len(tools.Tools), 0, "Expected at least one tool from Streamable HTTP MCP client") + }) + s.Run("Can close Streamable HTTP client", func() { + s.Require().NoError(httpClient.Close(), "Expected no error closing Streamable HTTP MCP client") + }) +} + +func TestMcpTransport(t *testing.T) { + suite.Run(t, new(McpTransportSuite)) +} diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go index ab531813..64c3355e 100644 --- a/pkg/http/http_test.go +++ b/pkg/http/http_test.go @@ -1,7 +1,6 @@ package http import ( - "bufio" "bytes" "context" "crypto/rand" @@ -22,6 +21,7 @@ import ( "github.com/containers/kubernetes-mcp-server/internal/test" "github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc/oidctest" + "github.com/stretchr/testify/suite" "golang.org/x/sync/errgroup" "k8s.io/klog/v2" "k8s.io/klog/v2/textlogger" @@ -30,6 +30,53 @@ import ( "github.com/containers/kubernetes-mcp-server/pkg/mcp" ) +type BaseHttpSuite struct { + suite.Suite + MockServer *test.MockServer + TcpAddr *net.TCPAddr + StaticConfig *config.StaticConfig + mcpServer *mcp.Server + OidcProvider *oidc.Provider + timeoutCancel context.CancelFunc + StopServer context.CancelFunc + WaitForShutdown func() error +} + +func (s *BaseHttpSuite) SetupTest() { + var err error + http.DefaultClient.Timeout = 10 * time.Second + s.MockServer = test.NewMockServer() + s.TcpAddr, err = test.RandomPortAddress() + s.Require().NoError(err, "Expected no error getting random port address") + s.StaticConfig = config.Default() + s.StaticConfig.KubeConfig = s.MockServer.KubeconfigFile(s.T()) + s.StaticConfig.Port = strconv.Itoa(s.TcpAddr.Port) +} + +func (s *BaseHttpSuite) StartServer() { + var err error + s.mcpServer, err = mcp.NewServer(mcp.Configuration{StaticConfig: s.StaticConfig}) + s.Require().NoError(err, "Expected no error creating MCP server") + s.Require().NotNil(s.mcpServer, "MCP server should not be nil") + var timeoutCtx, cancelCtx context.Context + timeoutCtx, s.timeoutCancel = context.WithTimeout(s.T().Context(), 10*time.Second) + group, gc := errgroup.WithContext(timeoutCtx) + cancelCtx, s.StopServer = context.WithCancel(gc) + group.Go(func() error { return Serve(cancelCtx, s.mcpServer, s.StaticConfig, s.OidcProvider, nil) }) + s.WaitForShutdown = group.Wait + s.Require().NoError(test.WaitForServer(s.TcpAddr), "HTTP server did not start in time") +} + +func (s *BaseHttpSuite) TearDownTest() { + s.MockServer.Close() + if s.mcpServer != nil { + s.mcpServer.Close() + } + s.StopServer() + s.Require().NoError(s.WaitForShutdown(), "HTTP server did not shut down gracefully") + s.timeoutCancel() +} + type httpContext struct { klogState klog.State mockServer *test.MockServer @@ -42,20 +89,6 @@ type httpContext struct { OidcProvider *oidc.Provider } -const tokenReviewSuccessful = ` - { - "kind": "TokenReview", - "apiVersion": "authentication.k8s.io/v1", - "spec": {"token": "valid-token"}, - "status": { - "authenticated": true, - "user": { - "username": "test-user", - "groups": ["system:authenticated"] - } - } - }` - func (c *httpContext) beforeEach(t *testing.T) { t.Helper() http.DefaultClient.Timeout = 10 * time.Second @@ -192,92 +225,6 @@ func TestGracefulShutdown(t *testing.T) { }) } -func TestSseTransport(t *testing.T) { - testCase(t, func(ctx *httpContext) { - sseResp, sseErr := http.Get(fmt.Sprintf("http://%s/sse", ctx.HttpAddress)) - t.Cleanup(func() { _ = sseResp.Body.Close() }) - t.Run("Exposes SSE endpoint at /sse", func(t *testing.T) { - if sseErr != nil { - t.Fatalf("Failed to get SSE endpoint: %v", sseErr) - } - if sseResp.StatusCode != http.StatusOK { - t.Errorf("Expected HTTP 200 OK, got %d", sseResp.StatusCode) - } - }) - t.Run("SSE endpoint returns text/event-stream content type", func(t *testing.T) { - if sseResp.Header.Get("Content-Type") != "text/event-stream" { - t.Errorf("Expected Content-Type text/event-stream, got %s", sseResp.Header.Get("Content-Type")) - } - }) - responseReader := bufio.NewReader(sseResp.Body) - event, eventErr := responseReader.ReadString('\n') - endpoint, endpointErr := responseReader.ReadString('\n') - t.Run("SSE endpoint returns stream with messages endpoint", func(t *testing.T) { - if eventErr != nil { - t.Fatalf("Failed to read SSE response body (event): %v", eventErr) - } - if event != "event: endpoint\n" { - t.Errorf("Expected SSE event 'endpoint', got %s", event) - } - if endpointErr != nil { - t.Fatalf("Failed to read SSE response body (endpoint): %v", endpointErr) - } - if !strings.HasPrefix(endpoint, "data: /message?sessionId=") { - t.Errorf("Expected SSE data: '/message', got %s", endpoint) - } - }) - messageResp, messageErr := http.Post( - fmt.Sprintf("http://%s/message?sessionId=%s", ctx.HttpAddress, strings.TrimSpace(endpoint[25:])), - "application/json", - bytes.NewBufferString("{}"), - ) - t.Cleanup(func() { _ = messageResp.Body.Close() }) - t.Run("Exposes message endpoint at /message", func(t *testing.T) { - if messageErr != nil { - t.Fatalf("Failed to get message endpoint: %v", messageErr) - } - if messageResp.StatusCode != http.StatusAccepted { - t.Errorf("Expected HTTP 202 OK, got %d", messageResp.StatusCode) - } - }) - }) -} - -func TestStreamableHttpTransport(t *testing.T) { - testCase(t, func(ctx *httpContext) { - mcpGetResp, mcpGetErr := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress)) - t.Cleanup(func() { _ = mcpGetResp.Body.Close() }) - t.Run("Exposes MCP GET endpoint at /mcp", func(t *testing.T) { - if mcpGetErr != nil { - t.Fatalf("Failed to get MCP endpoint: %v", mcpGetErr) - } - if mcpGetResp.StatusCode != http.StatusOK { - t.Errorf("Expected HTTP 200 OK, got %d", mcpGetResp.StatusCode) - } - }) - t.Run("MCP GET endpoint returns text/event-stream content type", func(t *testing.T) { - if mcpGetResp.Header.Get("Content-Type") != "text/event-stream" { - t.Errorf("Expected Content-Type text/event-stream (GET), got %s", mcpGetResp.Header.Get("Content-Type")) - } - }) - mcpPostResp, mcpPostErr := http.Post(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), "application/json", bytes.NewBufferString("{}")) - t.Cleanup(func() { _ = mcpPostResp.Body.Close() }) - t.Run("Exposes MCP POST endpoint at /mcp", func(t *testing.T) { - if mcpPostErr != nil { - t.Fatalf("Failed to post to MCP endpoint: %v", mcpPostErr) - } - if mcpPostResp.StatusCode != http.StatusOK { - t.Errorf("Expected HTTP 200 OK, got %d", mcpPostResp.StatusCode) - } - }) - t.Run("MCP POST endpoint returns application/json content type", func(t *testing.T) { - if mcpPostResp.Header.Get("Content-Type") != "application/json" { - t.Errorf("Expected Content-Type application/json (POST), got %s", mcpPostResp.Header.Get("Content-Type")) - } - }) - }) -} - func TestHealthCheck(t *testing.T) { testCase(t, func(ctx *httpContext) { t.Run("Exposes health check endpoint at /healthz", func(t *testing.T) { @@ -616,396 +563,3 @@ func TestMiddlewareLogging(t *testing.T) { }) }) } - -func TestAuthorizationUnauthorized(t *testing.T) { - // Missing Authorization header - testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) { - resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress)) - if err != nil { - t.Fatalf("Failed to get protected endpoint: %v", err) - } - t.Cleanup(func() { _ = resp.Body.Close }) - t.Run("Protected resource with MISSING Authorization header returns 401 - Unauthorized", func(t *testing.T) { - if resp.StatusCode != 401 { - t.Errorf("Expected HTTP 401, got %d", resp.StatusCode) - } - }) - t.Run("Protected resource with MISSING Authorization header returns WWW-Authenticate header", func(t *testing.T) { - authHeader := resp.Header.Get("WWW-Authenticate") - expected := `Bearer realm="Kubernetes MCP Server", error="missing_token"` - if authHeader != expected { - t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) - } - }) - t.Run("Protected resource with MISSING Authorization header logs error", func(t *testing.T) { - if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - missing or invalid bearer token") { - t.Errorf("Expected log entry for missing or invalid bearer token, got: %s", ctx.LogBuffer.String()) - } - }) - }) - // Authorization header without Bearer prefix - testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) { - req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Authorization", "Basic YWxhZGRpbjpvcGVuc2VzYW1l") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Failed to get protected endpoint: %v", err) - } - t.Cleanup(func() { _ = resp.Body.Close }) - t.Run("Protected resource with INCOMPATIBLE Authorization header returns WWW-Authenticate header", func(t *testing.T) { - authHeader := resp.Header.Get("WWW-Authenticate") - expected := `Bearer realm="Kubernetes MCP Server", error="missing_token"` - if authHeader != expected { - t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) - } - }) - t.Run("Protected resource with INCOMPATIBLE Authorization header logs error", func(t *testing.T) { - if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - missing or invalid bearer token") { - t.Errorf("Expected log entry for missing or invalid bearer token, got: %s", ctx.LogBuffer.String()) - } - }) - }) - // Invalid Authorization header - testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) { - req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Authorization", "Bearer "+strings.ReplaceAll(tokenBasicNotExpired, ".", ".invalid")) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Failed to get protected endpoint: %v", err) - } - t.Cleanup(func() { _ = resp.Body.Close }) - t.Run("Protected resource with INVALID Authorization header returns 401 - Unauthorized", func(t *testing.T) { - if resp.StatusCode != 401 { - t.Errorf("Expected HTTP 401, got %d", resp.StatusCode) - } - }) - t.Run("Protected resource with INVALID Authorization header returns WWW-Authenticate header", func(t *testing.T) { - authHeader := resp.Header.Get("WWW-Authenticate") - expected := `Bearer realm="Kubernetes MCP Server", error="invalid_token"` - if authHeader != expected { - t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) - } - }) - t.Run("Protected resource with INVALID Authorization header logs error", func(t *testing.T) { - if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") || - !strings.Contains(ctx.LogBuffer.String(), "error: failed to parse JWT token: illegal base64 data") { - t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String()) - } - }) - }) - // Expired Authorization Bearer token - testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) { - req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Authorization", "Bearer "+tokenBasicExpired) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Failed to get protected endpoint: %v", err) - } - t.Cleanup(func() { _ = resp.Body.Close }) - t.Run("Protected resource with EXPIRED Authorization header returns 401 - Unauthorized", func(t *testing.T) { - if resp.StatusCode != 401 { - t.Errorf("Expected HTTP 401, got %d", resp.StatusCode) - } - }) - t.Run("Protected resource with EXPIRED Authorization header returns WWW-Authenticate header", func(t *testing.T) { - authHeader := resp.Header.Get("WWW-Authenticate") - expected := `Bearer realm="Kubernetes MCP Server", error="invalid_token"` - if authHeader != expected { - t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) - } - }) - t.Run("Protected resource with EXPIRED Authorization header logs error", func(t *testing.T) { - if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") || - !strings.Contains(ctx.LogBuffer.String(), "validation failed, token is expired (exp)") { - t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String()) - } - }) - }) - // Invalid audience claim Bearer token - testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "expected-audience", ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) { - req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Authorization", "Bearer "+tokenBasicExpired) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Failed to get protected endpoint: %v", err) - } - t.Cleanup(func() { _ = resp.Body.Close }) - t.Run("Protected resource with INVALID AUDIENCE Authorization header returns 401 - Unauthorized", func(t *testing.T) { - if resp.StatusCode != 401 { - t.Errorf("Expected HTTP 401, got %d", resp.StatusCode) - } - }) - t.Run("Protected resource with INVALID AUDIENCE Authorization header returns WWW-Authenticate header", func(t *testing.T) { - authHeader := resp.Header.Get("WWW-Authenticate") - expected := `Bearer realm="Kubernetes MCP Server", audience="expected-audience", error="invalid_token"` - if authHeader != expected { - t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) - } - }) - t.Run("Protected resource with INVALID AUDIENCE Authorization header logs error", func(t *testing.T) { - if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") || - !strings.Contains(ctx.LogBuffer.String(), "invalid audience claim (aud)") { - t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String()) - } - }) - }) - // Failed OIDC validation - oidcTestServer := NewOidcTestServer(t) - t.Cleanup(oidcTestServer.Close) - testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) { - req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Authorization", "Bearer "+tokenBasicNotExpired) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Failed to get protected endpoint: %v", err) - } - t.Cleanup(func() { _ = resp.Body.Close }) - t.Run("Protected resource with INVALID OIDC Authorization header returns 401 - Unauthorized", func(t *testing.T) { - if resp.StatusCode != 401 { - t.Errorf("Expected HTTP 401, got %d", resp.StatusCode) - } - }) - t.Run("Protected resource with INVALID OIDC Authorization header returns WWW-Authenticate header", func(t *testing.T) { - authHeader := resp.Header.Get("WWW-Authenticate") - expected := `Bearer realm="Kubernetes MCP Server", audience="mcp-server", error="invalid_token"` - if authHeader != expected { - t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) - } - }) - t.Run("Protected resource with INVALID OIDC Authorization header logs error", func(t *testing.T) { - if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") || - !strings.Contains(ctx.LogBuffer.String(), "OIDC token validation error: failed to verify signature") { - t.Errorf("Expected log entry for OIDC validation error, got: %s", ctx.LogBuffer.String()) - } - }) - }) - // Failed Kubernetes TokenReview - rawClaims := `{ - "iss": "` + oidcTestServer.URL + `", - "exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `, - "aud": "mcp-server" - }` - validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims) - testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) { - req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Authorization", "Bearer "+validOidcToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Failed to get protected endpoint: %v", err) - } - t.Cleanup(func() { _ = resp.Body.Close }) - t.Run("Protected resource with INVALID KUBERNETES Authorization header returns 401 - Unauthorized", func(t *testing.T) { - if resp.StatusCode != 401 { - t.Errorf("Expected HTTP 401, got %d", resp.StatusCode) - } - }) - t.Run("Protected resource with INVALID KUBERNETES Authorization header returns WWW-Authenticate header", func(t *testing.T) { - authHeader := resp.Header.Get("WWW-Authenticate") - expected := `Bearer realm="Kubernetes MCP Server", audience="mcp-server", error="invalid_token"` - if authHeader != expected { - t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) - } - }) - t.Run("Protected resource with INVALID KUBERNETES Authorization header logs error", func(t *testing.T) { - if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") || - !strings.Contains(ctx.LogBuffer.String(), "kubernetes API token validation error: failed to create token review") { - t.Errorf("Expected log entry for Kubernetes TokenReview error, got: %s", ctx.LogBuffer.String()) - } - }) - }) -} - -func TestAuthorizationRequireOAuthFalse(t *testing.T) { - testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: false, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) { - resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress)) - if err != nil { - t.Fatalf("Failed to get protected endpoint: %v", err) - } - t.Cleanup(func() { _ = resp.Body.Close() }) - t.Run("Protected resource with MISSING Authorization header returns 200 - OK)", func(t *testing.T) { - if resp.StatusCode != http.StatusOK { - t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode) - } - }) - }) -} - -func TestAuthorizationRawToken(t *testing.T) { - cases := []struct { - audience string - validateToken bool - }{ - {"", false}, // No audience, no validation - {"", true}, // No audience, validation enabled - {"mcp-server", false}, // Audience set, no validation - {"mcp-server", true}, // Audience set, validation enabled - } - for _, c := range cases { - testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: c.audience, ValidateToken: c.validateToken, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) { - tokenReviewed := false - ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(tokenReviewSuccessful)) - tokenReviewed = true - return - } - })) - req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Authorization", "Bearer "+tokenBasicNotExpired) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Failed to get protected endpoint: %v", err) - } - t.Cleanup(func() { _ = resp.Body.Close() }) - t.Run(fmt.Sprintf("Protected resource with audience = '%s' and validate-token = '%t', with VALID Authorization header returns 200 - OK", c.audience, c.validateToken), func(t *testing.T) { - if resp.StatusCode != http.StatusOK { - t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode) - } - }) - t.Run(fmt.Sprintf("Protected resource with audience = '%s' and validate-token = '%t', with VALID Authorization header performs token validation accordingly", c.audience, c.validateToken), func(t *testing.T) { - if tokenReviewed == true && !c.validateToken { - t.Errorf("Expected token review to be skipped when validate-token is false, but it was performed") - } - if tokenReviewed == false && c.validateToken { - t.Errorf("Expected token review to be performed when validate-token is true, but it was skipped") - } - }) - }) - } - -} - -func TestAuthorizationOidcToken(t *testing.T) { - oidcTestServer := NewOidcTestServer(t) - t.Cleanup(oidcTestServer.Close) - rawClaims := `{ - "iss": "` + oidcTestServer.URL + `", - "exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `, - "aud": "mcp-server" - }` - validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims) - cases := []bool{false, true} - for _, validateToken := range cases { - testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken, ClusterProviderStrategy: config.ClusterProviderKubeConfig}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) { - tokenReviewed := false - ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(tokenReviewSuccessful)) - tokenReviewed = true - return - } - })) - req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Authorization", "Bearer "+validOidcToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Failed to get protected endpoint: %v", err) - } - t.Cleanup(func() { _ = resp.Body.Close() }) - t.Run(fmt.Sprintf("Protected resource with validate-token='%t' with VALID OIDC Authorization header returns 200 - OK", validateToken), func(t *testing.T) { - if resp.StatusCode != http.StatusOK { - t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode) - } - }) - t.Run(fmt.Sprintf("Protected resource with validate-token='%t' with VALID OIDC Authorization header performs token validation accordingly", validateToken), func(t *testing.T) { - if tokenReviewed == true && !validateToken { - t.Errorf("Expected token review to be skipped when validate-token is false, but it was performed") - } - if tokenReviewed == false && validateToken { - t.Errorf("Expected token review to be performed when validate-token is true, but it was skipped") - } - }) - }) - } -} - -func TestAuthorizationOidcTokenExchange(t *testing.T) { - oidcTestServer := NewOidcTestServer(t) - t.Cleanup(oidcTestServer.Close) - rawClaims := `{ - "iss": "` + oidcTestServer.URL + `", - "exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `, - "aud": "%s" - }` - validOidcClientToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, - fmt.Sprintf(rawClaims, "mcp-server")) - validOidcBackendToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, - fmt.Sprintf(rawClaims, "backend-audience")) - oidcTestServer.TokenEndpointHandler = func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _, _ = fmt.Fprintf(w, `{"access_token":"%s","token_type":"Bearer","expires_in":253402297199}`, validOidcBackendToken) - } - cases := []bool{false, true} - for _, validateToken := range cases { - staticConfig := &config.StaticConfig{ - RequireOAuth: true, - OAuthAudience: "mcp-server", - ValidateToken: validateToken, - StsClientId: "test-sts-client-id", - StsClientSecret: "test-sts-client-secret", - StsAudience: "backend-audience", - StsScopes: []string{"backend-scope"}, - ClusterProviderStrategy: config.ClusterProviderKubeConfig, - } - testCaseWithContext(t, &httpContext{StaticConfig: staticConfig, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) { - tokenReviewed := false - ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(tokenReviewSuccessful)) - tokenReviewed = true - return - } - })) - req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Authorization", "Bearer "+validOidcClientToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Failed to get protected endpoint: %v", err) - } - t.Cleanup(func() { _ = resp.Body.Close() }) - t.Run(fmt.Sprintf("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header returns 200 - OK", validateToken), func(t *testing.T) { - if resp.StatusCode != http.StatusOK { - t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode) - } - }) - t.Run(fmt.Sprintf("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header performs token validation accordingly", validateToken), func(t *testing.T) { - if tokenReviewed == true && !validateToken { - t.Errorf("Expected token review to be skipped when validate-token is false, but it was performed") - } - if tokenReviewed == false && validateToken { - t.Errorf("Expected token review to be performed when validate-token is true, but it was skipped") - } - }) - }) - } -} diff --git a/pkg/kubernetes/provider_kubeconfig_test.go b/pkg/kubernetes/provider_kubeconfig_test.go index 17984990..33ba60d6 100644 --- a/pkg/kubernetes/provider_kubeconfig_test.go +++ b/pkg/kubernetes/provider_kubeconfig_test.go @@ -2,7 +2,6 @@ package kubernetes import ( "fmt" - "net/http" "testing" "github.com/containers/kubernetes-mcp-server/internal/test" @@ -57,25 +56,8 @@ func (s *ProviderKubeconfigTestSuite) TestWithOpenShiftCluster() { } func (s *ProviderKubeconfigTestSuite) TestVerifyToken() { - s.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(` - { - "kind": "TokenReview", - "apiVersion": "authentication.k8s.io/v1", - "spec": {"token": "the-token"}, - "status": { - "authenticated": true, - "user": { - "username": "test-user", - "groups": ["system:authenticated"] - }, - "audiences": ["the-audience"] - } - }`)) - } - })) + s.mockServer.Handle(&test.TokenReviewHandler{}) + s.Run("VerifyToken returns UserInfo for non-empty context", func() { userInfo, audiences, err := s.provider.VerifyToken(s.T().Context(), "fake-context", "some-token", "the-audience") s.Require().NoError(err, "Expected no error from VerifyToken with empty target") diff --git a/pkg/kubernetes/provider_single_test.go b/pkg/kubernetes/provider_single_test.go index ff03e26c..150926b4 100644 --- a/pkg/kubernetes/provider_single_test.go +++ b/pkg/kubernetes/provider_single_test.go @@ -1,7 +1,6 @@ package kubernetes import ( - "net/http" "testing" "github.com/containers/kubernetes-mcp-server/internal/test" @@ -50,6 +49,7 @@ func (s *ProviderSingleTestSuite) TestWithNonOpenShiftCluster() { func (s *ProviderSingleTestSuite) TestWithOpenShiftCluster() { s.mockServer.Handle(&test.InOpenShiftHandler{}) + s.Run("IsOpenShift returns true", func() { inOpenShift := s.provider.IsOpenShift(s.T().Context()) s.True(inOpenShift, "Expected InOpenShift to return true") @@ -57,25 +57,8 @@ func (s *ProviderSingleTestSuite) TestWithOpenShiftCluster() { } func (s *ProviderSingleTestSuite) TestVerifyToken() { - s.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(` - { - "kind": "TokenReview", - "apiVersion": "authentication.k8s.io/v1", - "spec": {"token": "the-token"}, - "status": { - "authenticated": true, - "user": { - "username": "test-user", - "groups": ["system:authenticated"] - }, - "audiences": ["the-audience"] - } - }`)) - } - })) + s.mockServer.Handle(&test.TokenReviewHandler{}) + s.Run("VerifyToken returns UserInfo for empty target (default target)", func() { userInfo, audiences, err := s.provider.VerifyToken(s.T().Context(), "", "the-token", "the-audience") s.Require().NoError(err, "Expected no error from VerifyToken with empty target")