Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
package mcp_test

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"regexp"
"slices"
"strings"
"testing"
"time"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"go.uber.org/goleak"

"github.com/coder/aibridge/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/require"

mcplib "github.com/mark3labs/mcp-go/mcp"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -282,3 +293,78 @@ func TestFilterAllowedTools(t *testing.T) {
})
}
}

func TestToolInjectionOrder(t *testing.T) {
t.Parallel()

// Setup.
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

// Given: a MCP mock server offering a set of tools.
mcpSrv := httptest.NewServer(createMockMCPSrv(t))
t.Cleanup(mcpSrv.Close)

// When: creating two MCP server proxies, both listing the same tools by name but under different server namespaces.
proxy, err := mcp.NewStreamableHTTPServerProxy(logger, "coder", mcpSrv.URL, nil, nil, nil)
require.NoError(t, err)
proxy2, err := mcp.NewStreamableHTTPServerProxy(logger, "shmoder", mcpSrv.URL, nil, nil, nil)
require.NoError(t, err)

// Then: initialize both proxies.
require.NoError(t, proxy.Init(ctx))
require.NoError(t, proxy2.Init(ctx))

// Then: validate that their tools are separately sorted stably.
validateToolOrder(t, proxy)
validateToolOrder(t, proxy2)

// When: creating a manager which contains both MCP server proxies.
mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{
"coder": proxy,
"shmoder": proxy2,
})
require.NoError(t, mgr.Init(ctx))

// Then: the tools from both servers should be collectively sorted stably.
validateToolOrder(t, mgr)
}

func validateToolOrder(t *testing.T, proxy mcp.ServerProxier) {
t.Helper()

tools := proxy.ListTools()
require.NotEmpty(t, tools)
require.Greater(t, len(tools), 1)

// Ensure tools are sorted by ID; unstable order can bust the cache and lead to increased costs.
sorted := slices.Clone(tools)
slices.SortFunc(sorted, func(a, b *mcp.Tool) int {
return strings.Compare(a.ID, b.ID)
})
for i, tool := range tools {
require.Equal(t, tool.ID, sorted[i].ID, "tool order is not stable")
}
}

func createMockMCPSrv(t *testing.T) http.Handler {
t.Helper()

s := server.NewMCPServer(
"Mock coder MCP server",
"1.0.0",
server.WithToolCapabilities(true),
)

for _, name := range []string{"coder_list_workspaces", "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user"} {
tool := mcplib.NewTool(name,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
return mcplib.NewToolResultText("mock"), nil
})
}

return server.NewStreamableHTTPServer(s)
}
6 changes: 5 additions & 1 deletion mcp/proxy_streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ func (p *StreamableHTTPServerProxy) Init(ctx context.Context) error {
}

func (p *StreamableHTTPServerProxy) ListTools() []*Tool {
return maps.Values(p.tools)
tools := maps.Values(p.tools)
slices.SortStableFunc(tools, func(a, b *Tool) int {
return strings.Compare(a.ID, b.ID)
})
return tools
}

func (p *StreamableHTTPServerProxy) GetTool(name string) *Tool {
Expand Down
7 changes: 7 additions & 0 deletions mcp/server_proxy_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package mcp
import (
"context"
"fmt"
"slices"
"strings"
"sync"

"github.com/coder/aibridge/utils"
Expand Down Expand Up @@ -82,6 +84,11 @@ func (s *ServerProxyManager) ListTools() []*Tool {
for _, tool := range s.tools {
out = append(out, tool)
}

slices.SortStableFunc(out, func(a, b *Tool) int {
return strings.Compare(a.ID, b.ID)
})

return out
}

Expand Down