Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
package mcp_test

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"regexp"
"slices"
"strings"
"testing"
"time"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"go.uber.org/goleak"

"github.com/coder/aibridge/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/require"

mcplib "github.com/mark3labs/mcp-go/mcp"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -282,3 +293,66 @@ func TestFilterAllowedTools(t *testing.T) {
})
}
}

func TestToolInjectionOrder(t *testing.T) {
t.Parallel()

// Setup Coder MCP integration
mcpSrv := httptest.NewServer(createMockMCPSrv(t))
t.Cleanup(mcpSrv.Close)

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)

proxy, err := mcp.NewStreamableHTTPServerProxy(logger, "coder", mcpSrv.URL, nil, nil, nil)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

require.NoError(t, proxy.Init(ctx))
validateToolOrder(t, proxy)

mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{
"mock": proxy,
})
require.NoError(t, mgr.Init(ctx))
validateToolOrder(t, proxy)
}

func validateToolOrder(t *testing.T, proxy mcp.ServerProxier) {
t.Helper()

tools := proxy.ListTools()
require.NotEmpty(t, tools)
require.Greater(t, len(tools), 1)

// Ensure tools are sorted by name; unstable order can bust the cache and lead to increased costs.
sorted := slices.Clone(tools)
slices.SortFunc(sorted, func(a, b *mcp.Tool) int {
return strings.Compare(a.Name, b.Name)
})
for i, tool := range tools {
require.Equal(t, tool.Name, sorted[i].Name, "tool order is not stable")
}
}

func createMockMCPSrv(t *testing.T) http.Handler {
t.Helper()

s := server.NewMCPServer(
"Mock coder MCP server",
"1.0.0",
server.WithToolCapabilities(true),
)

for _, name := range []string{"coder_list_workspaces", "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user"} {
tool := mcplib.NewTool(name,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
return mcplib.NewToolResultText("mock"), nil
})
}

return server.NewStreamableHTTPServer(s)
}
6 changes: 5 additions & 1 deletion mcp/proxy_streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ func (p *StreamableHTTPServerProxy) Init(ctx context.Context) error {
}

func (p *StreamableHTTPServerProxy) ListTools() []*Tool {
return maps.Values(p.tools)
tools := maps.Values(p.tools)
slices.SortStableFunc(tools, func(a, b *Tool) int {
return strings.Compare(a.ID, b.ID)
})
return tools
}

func (p *StreamableHTTPServerProxy) GetTool(name string) *Tool {
Expand Down
7 changes: 7 additions & 0 deletions mcp/server_proxy_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package mcp
import (
"context"
"fmt"
"slices"
"strings"
"sync"

"github.com/coder/aibridge/utils"
Expand Down Expand Up @@ -82,6 +84,11 @@ func (s *ServerProxyManager) ListTools() []*Tool {
for _, tool := range s.tools {
out = append(out, tool)
}

slices.SortStableFunc(out, func(a, b *Tool) int {
return strings.Compare(a.ID, b.ID)
})

return out
}

Expand Down