diff --git a/go.mod b/go.mod index 3a74120d..fe12ed45 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/go-jose/go-jose/v4 v4.1.3 github.com/google/jsonschema-go v0.3.0 github.com/mark3labs/mcp-go v0.43.0 + github.com/modelcontextprotocol/go-sdk v1.1.0 github.com/pkg/errors v0.9.1 github.com/spf13/afero v1.15.0 github.com/spf13/cobra v1.10.1 diff --git a/go.sum b/go.sum index 7a5186b8..f6d2a864 100644 --- a/go.sum +++ b/go.sum @@ -209,6 +209,8 @@ github.com/moby/spdystream v0.5.0 h1:7r0J1Si3QO/kjRitvSLVVFUjxMEb/YLj6S9FF62JBCU github.com/moby/spdystream v0.5.0/go.mod h1:xBAYlnt/ay+11ShkdFKNAG7LsyK/tmNBVvVOwrfMgdI= github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ= github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc= +github.com/modelcontextprotocol/go-sdk v1.1.0 h1:Qjayg53dnKC4UZ+792W21e4BpwEZBzwgRW6LrjLWSwA= +github.com/modelcontextprotocol/go-sdk v1.1.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/pkg/http/http.go b/pkg/http/http.go index 8001462c..84103acd 100644 --- a/pkg/http/http.go +++ b/pkg/http/http.go @@ -36,8 +36,8 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat Handler: wrappedMux, } - sseServer := mcpServer.ServeSse(staticConfig.SSEBaseURL, httpServer) - streamableHttpServer := mcpServer.ServeHTTP(httpServer) + sseServer := mcpServer.ServeSse() + streamableHttpServer := mcpServer.ServeHTTP() mux.Handle(sseEndpoint, sseServer) mux.Handle(sseMessageEndpoint, sseServer) mux.Handle(mcpEndpoint, streamableHttpServer) diff --git a/pkg/kubernetes-mcp-server/cmd/root.go b/pkg/kubernetes-mcp-server/cmd/root.go index db1782ab..13d4a14e 100644 --- a/pkg/kubernetes-mcp-server/cmd/root.go +++ b/pkg/kubernetes-mcp-server/cmd/root.go @@ -345,7 +345,8 @@ func (m *MCPServerOptions) Run() error { return internalhttp.Serve(ctx, mcpServer, m.StaticConfig, oidcProvider, httpClient) } - if err := mcpServer.ServeStdio(); err != nil && !errors.Is(err, context.Canceled) { + ctx := context.Background() + if err := mcpServer.ServeStdio(ctx); err != nil && !errors.Is(err, context.Canceled) { return err } diff --git a/pkg/mcp/common_test.go b/pkg/mcp/common_test.go index b91df691..8a0158d5 100644 --- a/pkg/mcp/common_test.go +++ b/pkg/mcp/common_test.go @@ -190,7 +190,7 @@ func (s *BaseMcpSuite) InitMcpClient(options ...transport.StreamableHTTPCOption) var err error s.mcpServer, err = NewServer(Configuration{StaticConfig: s.Cfg}) s.Require().NoError(err, "Expected no error creating MCP server") - s.McpClient = test.NewMcpClient(s.T(), s.mcpServer.ServeHTTP(nil), options...) + s.McpClient = test.NewMcpClient(s.T(), s.mcpServer.ServeHTTP(), options...) } // EnvTestInOpenShift sets up the kubernetes environment to seem to be running OpenShift diff --git a/pkg/mcp/gosdk.go b/pkg/mcp/gosdk.go new file mode 100644 index 00000000..2437b085 --- /dev/null +++ b/pkg/mcp/gosdk.go @@ -0,0 +1,109 @@ +package mcp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/containers/kubernetes-mcp-server/pkg/api" + "github.com/modelcontextprotocol/go-sdk/mcp" + "k8s.io/utils/ptr" +) + +func ServerToolToGoSdkTool(s *Server, tool api.ServerTool) (*mcp.Tool, mcp.ToolHandler, error) { + goSdkTool := &mcp.Tool{ + Name: tool.Tool.Name, + Description: tool.Tool.Description, + Title: tool.Tool.Annotations.Title, + Annotations: &mcp.ToolAnnotations{ + Title: tool.Tool.Annotations.Title, + ReadOnlyHint: ptr.Deref(tool.Tool.Annotations.ReadOnlyHint, false), + DestructiveHint: tool.Tool.Annotations.DestructiveHint, + IdempotentHint: ptr.Deref(tool.Tool.Annotations.IdempotentHint, false), + OpenWorldHint: tool.Tool.Annotations.OpenWorldHint, + }, + } + if tool.Tool.InputSchema != nil { + schema, err := json.Marshal(tool.Tool.InputSchema) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal tool input schema for tool %s: %v", tool.Tool.Name, err) + } + // TODO: temporary fix to append an empty properties object (some client have trouble parsing a schema without properties) + // As opposed, Gemini had trouble for a while when properties was present but empty. + // https://github.com/containers/kubernetes-mcp-server/issues/340 + if string(schema) == `{"type":"object"}` { + schema = []byte(`{"type":"object","properties":{}}`) + } + + var fixedSchema map[string]interface{} + if err := json.Unmarshal(schema, &fixedSchema); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal tool input schema for tool %s: %v", tool.Tool.Name, err) + } + + goSdkTool.InputSchema = fixedSchema + } + goSdkHandler := func(ctx context.Context, request *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + toolCallRequest, err := GoSdkToolCallRequestToToolCallRequest(request) + if err != nil { + return nil, fmt.Errorf("%v for tool %s", err, tool.Tool.Name) + } + // get the correct derived Kubernetes client for the target specified in the request + cluster := toolCallRequest.GetString(s.p.GetTargetParameterName(), s.p.GetDefaultTarget()) + k, err := s.p.GetDerivedKubernetes(ctx, cluster) + if err != nil { + return nil, err + } + + result, err := tool.Handler(api.ToolHandlerParams{ + Context: ctx, + Kubernetes: k, + ToolCallRequest: toolCallRequest, + ListOutput: s.configuration.ListOutput(), + }) + if err != nil { + return nil, err + } + return NewTextResult(result.Content, result.Error), nil + } + return goSdkTool, goSdkHandler, nil +} + +type ToolCallRequest struct { + Name string + arguments map[string]any +} + +var _ api.ToolCallRequest = (*ToolCallRequest)(nil) + +func GoSdkToolCallRequestToToolCallRequest(request *mcp.CallToolRequest) (*ToolCallRequest, error) { + toolCallParams, ok := request.GetParams().(*mcp.CallToolParamsRaw) + if !ok { + return nil, errors.New("invalid tool call parameters for tool call request") + } + return GoSdkToolCallParamsToToolCallRequest(toolCallParams) +} + +func GoSdkToolCallParamsToToolCallRequest(toolCallParams *mcp.CallToolParamsRaw) (*ToolCallRequest, error) { + var arguments map[string]any + if err := json.Unmarshal(toolCallParams.Arguments, &arguments); err != nil { + return nil, fmt.Errorf("failed to unmarshal tool call arguments: %v", err) + } + return &ToolCallRequest{ + Name: toolCallParams.Name, + arguments: arguments, + }, nil +} + +func (ToolCallRequest *ToolCallRequest) GetArguments() map[string]any { + return ToolCallRequest.arguments +} + +func (ToolCallRequest *ToolCallRequest) GetString(key, defaultValue string) string { + if value, ok := ToolCallRequest.arguments[key]; ok { + if strValue, ok := value.(string); ok { + return strValue + } + } + return defaultValue +} diff --git a/pkg/mcp/m3labs.go b/pkg/mcp/m3labs.go deleted file mode 100644 index ade0f56b..00000000 --- a/pkg/mcp/m3labs.go +++ /dev/null @@ -1,63 +0,0 @@ -package mcp - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" - - "github.com/containers/kubernetes-mcp-server/pkg/api" -) - -func ServerToolToM3LabsServerTool(s *Server, tools []api.ServerTool) ([]server.ServerTool, error) { - m3labTools := make([]server.ServerTool, 0) - for _, tool := range tools { - m3labTool := mcp.Tool{ - Name: tool.Tool.Name, - Description: tool.Tool.Description, - Annotations: mcp.ToolAnnotation{ - Title: tool.Tool.Annotations.Title, - ReadOnlyHint: tool.Tool.Annotations.ReadOnlyHint, - DestructiveHint: tool.Tool.Annotations.DestructiveHint, - IdempotentHint: tool.Tool.Annotations.IdempotentHint, - OpenWorldHint: tool.Tool.Annotations.OpenWorldHint, - }, - } - if tool.Tool.InputSchema != nil { - schema, err := json.Marshal(tool.Tool.InputSchema) - if err != nil { - return nil, fmt.Errorf("failed to marshal tool input schema for tool %s: %v", tool.Tool.Name, err) - } - // TODO: temporary fix to append an empty properties object (some client have trouble parsing a schema without properties) - // As opposed, Gemini had trouble for a while when properties was present but empty. - // https://github.com/containers/kubernetes-mcp-server/issues/340 - if string(schema) == `{"type":"object"}` { - schema = []byte(`{"type":"object","properties":{}}`) - } - m3labTool.RawInputSchema = schema - } - m3labHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // get the correct derived Kubernetes client for the target specified in the request - cluster := request.GetString(s.p.GetTargetParameterName(), s.p.GetDefaultTarget()) - k, err := s.p.GetDerivedKubernetes(ctx, cluster) - if err != nil { - return nil, err - } - - result, err := tool.Handler(api.ToolHandlerParams{ - Context: ctx, - Kubernetes: k, - ToolCallRequest: request, - ListOutput: s.configuration.ListOutput(), - }) - if err != nil { - return nil, err - } - return NewTextResult(result.Content, result.Error), nil - } - m3labTools = append(m3labTools, server.ServerTool{Tool: m3labTool, Handler: m3labHandler}) - } - return m3labTools, nil -} diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 5f7511cc..aefc70a2 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -1,16 +1,14 @@ package mcp import ( - "bytes" "context" "fmt" "net/http" + "os" "slices" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/modelcontextprotocol/go-sdk/mcp" authenticationapiv1 "k8s.io/api/authentication/v1" - "k8s.io/klog/v2" "k8s.io/utils/ptr" "github.com/containers/kubernetes-mcp-server/pkg/api" @@ -65,30 +63,31 @@ func (c *Configuration) isToolApplicable(tool api.ServerTool) bool { type Server struct { configuration *Configuration - server *server.MCPServer + server *mcp.Server enabledTools []string p internalk8s.Provider } func NewServer(configuration Configuration) (*Server, error) { - var serverOptions []server.ServerOption - serverOptions = append(serverOptions, - server.WithToolCapabilities(true), - server.WithLogging(), - server.WithToolHandlerMiddleware(toolCallLoggingMiddleware), - ) - if configuration.RequireOAuth && false { // TODO: Disabled scope auth validation for now - serverOptions = append(serverOptions, server.WithToolHandlerMiddleware(toolScopedAuthorizationMiddleware)) - } - s := &Server{ configuration: &configuration, - server: server.NewMCPServer( - version.BinaryName, - version.Version, - serverOptions..., - ), + server: mcp.NewServer( + &mcp.Implementation{ + Name: version.BinaryName, Title: version.BinaryName, Version: version.Version, + }, + &mcp.ServerOptions{ + HasResources: false, + HasPrompts: false, + HasTools: true, + }), + } + + s.server.AddReceivingMiddleware(authHeaderPropagationMiddleware) + s.server.AddReceivingMiddleware(toolCallLoggingMiddleware) + if configuration.RequireOAuth && false { // TODO: Disabled scope auth validation for now + s.server.AddReceivingMiddleware(toolScopedAuthorizationMiddleware) } + if err := s.reloadKubernetesClusterProvider(); err != nil { return nil, err } @@ -139,38 +138,41 @@ func (s *Server) reloadKubernetesClusterProvider() error { s.enabledTools = append(s.enabledTools, tool.Tool.Name) } } - m3labsServerTools, err := ServerToolToM3LabsServerTool(s, applicableTools) - if err != nil { - return fmt.Errorf("failed to convert tools: %v", err) - } - s.server.SetTools(m3labsServerTools...) + // TODO: remove old tools that are no longer applicable + for _, tool := range applicableTools { + goSdkTool, goSdkToolHandler, err := ServerToolToGoSdkTool(s, tool) + if err != nil { + return fmt.Errorf("failed to convert tool %s: %v", tool.Tool.Name, err) + } + s.server.AddTool(goSdkTool, goSdkToolHandler) + } + // TODO: No option to perform a full replacement of tools. + // s.server.SetTools(m3labsServerTools...) // start new watch s.p.WatchTargets(s.reloadKubernetesClusterProvider) return nil } -func (s *Server) ServeStdio() error { - return server.ServeStdio(s.server) +func (s *Server) ServeStdio(ctx context.Context) error { + return s.server.Run(ctx, &mcp.LoggingTransport{Transport: &mcp.StdioTransport{}, Writer: os.Stderr}) } -func (s *Server) ServeSse(baseUrl string, httpServer *http.Server) *server.SSEServer { - options := make([]server.SSEOption, 0) - options = append(options, server.WithSSEContextFunc(contextFunc), server.WithHTTPServer(httpServer)) - if baseUrl != "" { - options = append(options, server.WithBaseURL(baseUrl)) - } - return server.NewSSEServer(s.server, options...) +func (s *Server) ServeSse() *mcp.SSEHandler { + return mcp.NewSSEHandler(func(request *http.Request) *mcp.Server { + return s.server + }, &mcp.SSEOptions{}) } -func (s *Server) ServeHTTP(httpServer *http.Server) *server.StreamableHTTPServer { - options := []server.StreamableHTTPOption{ - server.WithHTTPContextFunc(contextFunc), - server.WithStreamableHTTPServer(httpServer), - server.WithStateLess(true), - } - return server.NewStreamableHTTPServer(s.server, options...) +func (s *Server) ServeHTTP() *mcp.StreamableHTTPHandler { + return mcp.NewStreamableHTTPHandler(func(request *http.Request) *mcp.Server { + return s.server + }, &mcp.StreamableHTTPOptions{ + // For clients to be able to listen to tool changes, we need to set the server stateful + // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server + Stateless: false, + }) } // KubernetesApiVerifyToken verifies the given token with the audience by @@ -205,8 +207,7 @@ func NewTextResult(content string, err error) *mcp.CallToolResult { return &mcp.CallToolResult{ IsError: true, Content: []mcp.Content{ - mcp.TextContent{ - Type: "text", + &mcp.TextContent{ Text: err.Error(), }, }, @@ -214,52 +215,9 @@ func NewTextResult(content string, err error) *mcp.CallToolResult { } return &mcp.CallToolResult{ Content: []mcp.Content{ - mcp.TextContent{ - Type: "text", + &mcp.TextContent{ Text: content, }, }, } } - -func contextFunc(ctx context.Context, r *http.Request) context.Context { - // Get the standard Authorization header (OAuth compliant) - authHeader := r.Header.Get(string(internalk8s.OAuthAuthorizationHeader)) - if authHeader != "" { - return context.WithValue(ctx, internalk8s.OAuthAuthorizationHeader, authHeader) - } - - // Fallback to custom header for backward compatibility - customAuthHeader := r.Header.Get(string(internalk8s.CustomAuthorizationHeader)) - if customAuthHeader != "" { - return context.WithValue(ctx, internalk8s.OAuthAuthorizationHeader, customAuthHeader) - } - - return ctx -} - -func toolCallLoggingMiddleware(next server.ToolHandlerFunc) server.ToolHandlerFunc { - return func(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.CallToolResult, error) { - klog.V(5).Infof("mcp tool call: %s(%v)", ctr.Params.Name, ctr.Params.Arguments) - if ctr.Header != nil { - buffer := bytes.NewBuffer(make([]byte, 0)) - if err := ctr.Header.WriteSubset(buffer, map[string]bool{"Authorization": true, "authorization": true}); err == nil { - klog.V(7).Infof("mcp tool call headers: %s", buffer) - } - } - return next(ctx, ctr) - } -} - -func toolScopedAuthorizationMiddleware(next server.ToolHandlerFunc) server.ToolHandlerFunc { - return func(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.CallToolResult, error) { - scopes, ok := ctx.Value(TokenScopesContextKey).([]string) - if !ok { - return NewTextResult("", fmt.Errorf("authorization failed: Access denied: Tool '%s' requires scope 'mcp:%s' but no scope is available", ctr.Params.Name, ctr.Params.Name)), nil - } - if !slices.Contains(scopes, "mcp:"+ctr.Params.Name) && !slices.Contains(scopes, ctr.Params.Name) { - return NewTextResult("", fmt.Errorf("authorization failed: Access denied: Tool '%s' requires scope 'mcp:%s' but only scopes %s are available", ctr.Params.Name, ctr.Params.Name, scopes)), nil - } - return next(ctx, ctr) - } -} diff --git a/pkg/mcp/middleware.go b/pkg/mcp/middleware.go new file mode 100644 index 00000000..ec6f4d42 --- /dev/null +++ b/pkg/mcp/middleware.go @@ -0,0 +1,61 @@ +package mcp + +import ( + "bytes" + "context" + "fmt" + "slices" + + internalk8s "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" + "github.com/modelcontextprotocol/go-sdk/mcp" + "k8s.io/klog/v2" +) + +func authHeaderPropagationMiddleware(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + if req.GetExtra() != nil && req.GetExtra().Header != nil { + // Get the standard Authorization header (OAuth compliant) + authHeader := req.GetExtra().Header.Get(string(internalk8s.OAuthAuthorizationHeader)) + if authHeader != "" { + return next(context.WithValue(ctx, internalk8s.OAuthAuthorizationHeader, authHeader), method, req) + } + + // Fallback to custom header for backward compatibility + customAuthHeader := req.GetExtra().Header.Get(string(internalk8s.CustomAuthorizationHeader)) + if customAuthHeader != "" { + return next(context.WithValue(ctx, internalk8s.OAuthAuthorizationHeader, customAuthHeader), method, req) + } + } + return next(ctx, method, req) + } +} + +func toolCallLoggingMiddleware(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + switch params := req.GetParams().(type) { + case *mcp.CallToolParamsRaw: + toolCallRequest, _ := GoSdkToolCallParamsToToolCallRequest(params) + klog.V(5).Infof("mcp tool call: %s(%v)", toolCallRequest.Name, toolCallRequest.GetArguments()) + if req.GetExtra() != nil && req.GetExtra().Header != nil { + buffer := bytes.NewBuffer(make([]byte, 0)) + if err := req.GetExtra().Header.WriteSubset(buffer, map[string]bool{"Authorization": true, "authorization": true}); err == nil { + klog.V(7).Infof("mcp tool call headers: %s", buffer) + } + } + } + return next(ctx, method, req) + } +} + +func toolScopedAuthorizationMiddleware(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + scopes, ok := ctx.Value(TokenScopesContextKey).([]string) + if !ok { + return NewTextResult("", fmt.Errorf("authorization failed: Access denied: Tool '%s' requires scope 'mcp:%s' but no scope is available", method, method)), nil + } + if !slices.Contains(scopes, "mcp:"+method) && !slices.Contains(scopes, method) { + return NewTextResult("", fmt.Errorf("authorization failed: Access denied: Tool '%s' requires scope 'mcp:%s' but only scopes %s are available", method, method, scopes)), nil + } + return next(ctx, method, req) + } +} diff --git a/pkg/mcp/toolsets_test.go b/pkg/mcp/toolsets_test.go index d81392a5..f58cb913 100644 --- a/pkg/mcp/toolsets_test.go +++ b/pkg/mcp/toolsets_test.go @@ -208,7 +208,7 @@ func (s *ToolsetsSuite) InitMcpClient() { var err error s.mcpServer, err = NewServer(Configuration{StaticConfig: s.Cfg}) s.Require().NoError(err, "Expected no error creating MCP server") - s.McpClient = test.NewMcpClient(s.T(), s.mcpServer.ServeHTTP(nil)) + s.McpClient = test.NewMcpClient(s.T(), s.mcpServer.ServeHTTP()) } func TestToolsets(t *testing.T) {