Skip to content

Commit 5899d51

Browse files
committed
fix: inject MCP tools with stable order (#49)
Signed-off-by: Danny Kopping <danny@coder.com>
1 parent 8275872 commit 5899d51

File tree

4 files changed

+222
-7
lines changed

4 files changed

+222
-7
lines changed

bridge_integration_test.go

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -959,6 +959,122 @@ func TestErrorHandling(t *testing.T) {
959959
}
960960
}
961961

962+
// TestStableRequestEncoding validates that a given intercepted request and a
963+
// given set of injected tools should result in identical payloads.
964+
//
965+
// Should the payload vary, it may subvert any caching mechanisms the provider may have.
966+
func TestStableRequestEncoding(t *testing.T) {
967+
t.Parallel()
968+
969+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
970+
971+
cases := []struct {
972+
name string
973+
fixture []byte
974+
createRequestFunc createRequestFunc
975+
configureFunc configureFunc
976+
}{
977+
{
978+
name: aibridge.ProviderAnthropic,
979+
fixture: antSimple,
980+
createRequestFunc: createAnthropicMessagesReq,
981+
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
982+
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr)
983+
},
984+
},
985+
{
986+
name: aibridge.ProviderOpenAI,
987+
fixture: oaiSimple,
988+
createRequestFunc: createOpenAIChatCompletionsReq,
989+
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
990+
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr)
991+
},
992+
},
993+
}
994+
995+
for _, tc := range cases {
996+
t.Run(tc.name, func(t *testing.T) {
997+
t.Parallel()
998+
999+
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
1000+
t.Cleanup(cancel)
1001+
1002+
// Setup MCP tools.
1003+
tools := setupMCPServerProxiesForTest(t)
1004+
1005+
// Configure the bridge with injected tools.
1006+
mcpMgr := mcp.NewServerProxyManager(tools)
1007+
require.NoError(t, mcpMgr.Init(ctx))
1008+
1009+
arc := txtar.Parse(tc.fixture)
1010+
t.Logf("%s: %s", t.Name(), arc.Comment)
1011+
1012+
files := filesMap(arc)
1013+
require.Contains(t, files, fixtureRequest)
1014+
require.Contains(t, files, fixtureNonStreamingResponse)
1015+
1016+
var (
1017+
reference []byte
1018+
reqCount atomic.Int32
1019+
)
1020+
1021+
// Create a mock server that captures and compares request bodies.
1022+
mockSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1023+
reqCount.Add(1)
1024+
1025+
// Capture the raw request body.
1026+
raw, err := io.ReadAll(r.Body)
1027+
defer r.Body.Close()
1028+
require.NoError(t, err)
1029+
require.NotEmpty(t, raw)
1030+
1031+
// Store the first instance as the reference value.
1032+
if reference == nil {
1033+
reference = raw
1034+
} else {
1035+
// Compare all subsequent requests to the reference.
1036+
assert.JSONEq(t, string(reference), string(raw))
1037+
}
1038+
1039+
// Return a valid API response.
1040+
w.Header().Set("Content-Type", "application/json")
1041+
w.WriteHeader(http.StatusOK)
1042+
_, _ = w.Write(files[fixtureNonStreamingResponse])
1043+
}))
1044+
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
1045+
return ctx
1046+
}
1047+
mockSrv.Start()
1048+
t.Cleanup(mockSrv.Close)
1049+
1050+
recorder := &mockRecorderClient{}
1051+
bridge, err := tc.configureFunc(mockSrv.URL, recorder, mcpMgr)
1052+
require.NoError(t, err)
1053+
1054+
// Invoke request to mocked API via aibridge.
1055+
bridgeSrv := httptest.NewUnstartedServer(bridge)
1056+
bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
1057+
return aibridge.AsActor(ctx, userID, nil)
1058+
}
1059+
bridgeSrv.Start()
1060+
t.Cleanup(bridgeSrv.Close)
1061+
1062+
// Make multiple requests and verify they all have identical payloads.
1063+
count := 10
1064+
for range count {
1065+
req := tc.createRequestFunc(t, bridgeSrv.URL, files[fixtureRequest])
1066+
client := &http.Client{}
1067+
resp, err := client.Do(req)
1068+
require.NoError(t, err)
1069+
require.Equal(t, http.StatusOK, resp.StatusCode)
1070+
_ = resp.Body.Close()
1071+
}
1072+
1073+
require.EqualValues(t, count, reqCount.Load())
1074+
})
1075+
}
1076+
}
1077+
9621078
func calculateTotalInputTokens(in []*aibridge.TokenUsageRecord) int64 {
9631079
var total int64
9641080
for _, el := range in {
@@ -1142,12 +1258,14 @@ func createMockMCPSrv(t *testing.T) http.Handler {
11421258
server.WithToolCapabilities(true),
11431259
)
11441260

1145-
tool := mcplib.NewTool(mockToolName,
1146-
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", mockToolName)),
1147-
)
1148-
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
1149-
return mcplib.NewToolResultText("mock"), nil
1150-
})
1261+
for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} {
1262+
tool := mcplib.NewTool(name,
1263+
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
1264+
)
1265+
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
1266+
return mcplib.NewToolResultText("mock"), nil
1267+
})
1268+
}
11511269

11521270
return server.NewStreamableHTTPServer(s)
11531271
}

mcp/mcp_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
package mcp_test
22

33
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
"net/http/httptest"
48
"regexp"
9+
"slices"
10+
"strings"
511
"testing"
12+
"time"
613

714
"cdr.dev/slog"
15+
"cdr.dev/slog/sloggers/slogtest"
816
"go.uber.org/goleak"
917

1018
"github.com/coder/aibridge/mcp"
19+
"github.com/mark3labs/mcp-go/server"
1120
"github.com/stretchr/testify/require"
21+
22+
mcplib "github.com/mark3labs/mcp-go/mcp"
1223
)
1324

1425
func TestMain(m *testing.M) {
@@ -282,3 +293,78 @@ func TestFilterAllowedTools(t *testing.T) {
282293
})
283294
}
284295
}
296+
297+
func TestToolInjectionOrder(t *testing.T) {
298+
t.Parallel()
299+
300+
// Setup.
301+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
302+
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
303+
t.Cleanup(cancel)
304+
305+
// Given: an MCP mock server offering a set of tools.
306+
mcpSrv := httptest.NewServer(createMockMCPSrv(t))
307+
t.Cleanup(mcpSrv.Close)
308+
309+
// When: creating two MCP server proxies, both listing the same tools by name but under different server namespaces.
310+
proxy, err := mcp.NewStreamableHTTPServerProxy(logger, "coder", mcpSrv.URL, nil, nil, nil)
311+
require.NoError(t, err)
312+
proxy2, err := mcp.NewStreamableHTTPServerProxy(logger, "shmoder", mcpSrv.URL, nil, nil, nil)
313+
require.NoError(t, err)
314+
315+
// Then: initialize both proxies.
316+
require.NoError(t, proxy.Init(ctx))
317+
require.NoError(t, proxy2.Init(ctx))
318+
319+
// Then: validate that their tools are separately sorted stably.
320+
validateToolOrder(t, proxy)
321+
validateToolOrder(t, proxy2)
322+
323+
// When: creating a manager which contains both MCP server proxies.
324+
mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{
325+
"coder": proxy,
326+
"shmoder": proxy2,
327+
})
328+
require.NoError(t, mgr.Init(ctx))
329+
330+
// Then: the tools from both servers should be collectively sorted stably.
331+
validateToolOrder(t, mgr)
332+
}
333+
334+
func validateToolOrder(t *testing.T, proxy mcp.ServerProxier) {
335+
t.Helper()
336+
337+
tools := proxy.ListTools()
338+
require.NotEmpty(t, tools)
339+
require.Greater(t, len(tools), 1)
340+
341+
// Ensure tools are sorted by ID; unstable order can bust the cache and lead to increased costs.
342+
sorted := slices.Clone(tools)
343+
slices.SortFunc(sorted, func(a, b *mcp.Tool) int {
344+
return strings.Compare(a.ID, b.ID)
345+
})
346+
for i, tool := range tools {
347+
require.Equal(t, tool.ID, sorted[i].ID, "tool order is not stable")
348+
}
349+
}
350+
351+
func createMockMCPSrv(t *testing.T) http.Handler {
352+
t.Helper()
353+
354+
s := server.NewMCPServer(
355+
"Mock coder MCP server",
356+
"1.0.0",
357+
server.WithToolCapabilities(true),
358+
)
359+
360+
for _, name := range []string{"coder_list_workspaces", "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user"} {
361+
tool := mcplib.NewTool(name,
362+
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
363+
)
364+
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
365+
return mcplib.NewToolResultText("mock"), nil
366+
})
367+
}
368+
369+
return server.NewStreamableHTTPServer(s)
370+
}

mcp/proxy_streamable_http.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ func (p *StreamableHTTPServerProxy) Init(ctx context.Context) error {
8989
}
9090

9191
func (p *StreamableHTTPServerProxy) ListTools() []*Tool {
92-
return maps.Values(p.tools)
92+
tools := maps.Values(p.tools)
93+
slices.SortStableFunc(tools, func(a, b *Tool) int {
94+
return strings.Compare(a.ID, b.ID)
95+
})
96+
return tools
9397
}
9498

9599
func (p *StreamableHTTPServerProxy) GetTool(name string) *Tool {

mcp/server_proxy_manager.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package mcp
33
import (
44
"context"
55
"fmt"
6+
"slices"
7+
"strings"
68
"sync"
79

810
"github.com/coder/aibridge/utils"
@@ -82,6 +84,11 @@ func (s *ServerProxyManager) ListTools() []*Tool {
8284
for _, tool := range s.tools {
8385
out = append(out, tool)
8486
}
87+
88+
slices.SortStableFunc(out, func(a, b *Tool) int {
89+
return strings.Compare(a.ID, b.ID)
90+
})
91+
8592
return out
8693
}
8794

0 commit comments

Comments
 (0)