Skip to content

Commit cc9fa29

Browse files
authored
fix: inject MCP tools with stable order (#49)
1 parent 3c485cc commit cc9fa29

File tree

4 files changed

+222
-7
lines changed

4 files changed

+222
-7
lines changed

bridge_integration_test.go

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,122 @@ func TestErrorHandling(t *testing.T) {
11311131
}
11321132
}
11331133

1134+
// TestStableRequestEncoding validates that a given intercepted request and a
1135+
// given set of injected tools should result identical payloads.
1136+
//
1137+
// Should the payload vary, it may subvert any caching mechanisms the provider may have.
1138+
func TestStableRequestEncoding(t *testing.T) {
1139+
t.Parallel()
1140+
1141+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
1142+
1143+
cases := []struct {
1144+
name string
1145+
fixture []byte
1146+
createRequestFunc createRequestFunc
1147+
configureFunc configureFunc
1148+
}{
1149+
{
1150+
name: aibridge.ProviderAnthropic,
1151+
fixture: antSimple,
1152+
createRequestFunc: createAnthropicMessagesReq,
1153+
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
1154+
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
1155+
},
1156+
},
1157+
{
1158+
name: aibridge.ProviderOpenAI,
1159+
fixture: oaiSimple,
1160+
createRequestFunc: createOpenAIChatCompletionsReq,
1161+
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
1162+
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
1163+
},
1164+
},
1165+
}
1166+
1167+
for _, tc := range cases {
1168+
t.Run(tc.name, func(t *testing.T) {
1169+
t.Parallel()
1170+
1171+
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
1172+
t.Cleanup(cancel)
1173+
1174+
// Setup MCP tools.
1175+
tools := setupMCPServerProxiesForTest(t)
1176+
1177+
// Configure the bridge with injected tools.
1178+
mcpMgr := mcp.NewServerProxyManager(tools)
1179+
require.NoError(t, mcpMgr.Init(ctx))
1180+
1181+
arc := txtar.Parse(tc.fixture)
1182+
t.Logf("%s: %s", t.Name(), arc.Comment)
1183+
1184+
files := filesMap(arc)
1185+
require.Contains(t, files, fixtureRequest)
1186+
require.Contains(t, files, fixtureNonStreamingResponse)
1187+
1188+
var (
1189+
reference []byte
1190+
reqCount atomic.Int32
1191+
)
1192+
1193+
// Create a mock server that captures and compares request bodies.
1194+
mockSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1195+
reqCount.Add(1)
1196+
1197+
// Capture the raw request body.
1198+
raw, err := io.ReadAll(r.Body)
1199+
defer r.Body.Close()
1200+
require.NoError(t, err)
1201+
require.NotEmpty(t, raw)
1202+
1203+
// Store the first instance as the reference value.
1204+
if reference == nil {
1205+
reference = raw
1206+
} else {
1207+
// Compare all subsequent requests to the reference.
1208+
assert.JSONEq(t, string(reference), string(raw))
1209+
}
1210+
1211+
// Return a valid API response.
1212+
w.Header().Set("Content-Type", "application/json")
1213+
w.WriteHeader(http.StatusOK)
1214+
_, _ = w.Write(files[fixtureNonStreamingResponse])
1215+
}))
1216+
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
1217+
return ctx
1218+
}
1219+
mockSrv.Start()
1220+
t.Cleanup(mockSrv.Close)
1221+
1222+
recorder := &mockRecorderClient{}
1223+
bridge, err := tc.configureFunc(mockSrv.URL, recorder, mcpMgr)
1224+
require.NoError(t, err)
1225+
1226+
// Invoke request to mocked API via aibridge.
1227+
bridgeSrv := httptest.NewUnstartedServer(bridge)
1228+
bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
1229+
return aibridge.AsActor(ctx, userID, nil)
1230+
}
1231+
bridgeSrv.Start()
1232+
t.Cleanup(bridgeSrv.Close)
1233+
1234+
// Make multiple requests and verify they all have identical payloads.
1235+
count := 10
1236+
for range count {
1237+
req := tc.createRequestFunc(t, bridgeSrv.URL, files[fixtureRequest])
1238+
client := &http.Client{}
1239+
resp, err := client.Do(req)
1240+
require.NoError(t, err)
1241+
require.Equal(t, http.StatusOK, resp.StatusCode)
1242+
_ = resp.Body.Close()
1243+
}
1244+
1245+
require.EqualValues(t, count, reqCount.Load())
1246+
})
1247+
}
1248+
}
1249+
11341250
func calculateTotalInputTokens(in []*aibridge.TokenUsageRecord) int64 {
11351251
var total int64
11361252
for _, el := range in {
@@ -1340,12 +1456,14 @@ func createMockMCPSrv(t *testing.T) http.Handler {
13401456
server.WithToolCapabilities(true),
13411457
)
13421458

1343-
tool := mcplib.NewTool(mockToolName,
1344-
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", mockToolName)),
1345-
)
1346-
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
1347-
return mcplib.NewToolResultText("mock"), nil
1348-
})
1459+
for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} {
1460+
tool := mcplib.NewTool(name,
1461+
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
1462+
)
1463+
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
1464+
return mcplib.NewToolResultText("mock"), nil
1465+
})
1466+
}
13491467

13501468
return server.NewStreamableHTTPServer(s)
13511469
}

mcp/mcp_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
package mcp_test
22

33
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
"net/http/httptest"
48
"regexp"
9+
"slices"
10+
"strings"
511
"testing"
12+
"time"
613

714
"cdr.dev/slog"
15+
"cdr.dev/slog/sloggers/slogtest"
816
"go.uber.org/goleak"
917

1018
"github.com/coder/aibridge/mcp"
19+
"github.com/mark3labs/mcp-go/server"
1120
"github.com/stretchr/testify/require"
21+
22+
mcplib "github.com/mark3labs/mcp-go/mcp"
1223
)
1324

1425
func TestMain(m *testing.M) {
@@ -282,3 +293,78 @@ func TestFilterAllowedTools(t *testing.T) {
282293
})
283294
}
284295
}
296+
297+
func TestToolInjectionOrder(t *testing.T) {
298+
t.Parallel()
299+
300+
// Setup.
301+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
302+
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
303+
t.Cleanup(cancel)
304+
305+
// Given: a MCP mock server offering a set of tools.
306+
mcpSrv := httptest.NewServer(createMockMCPSrv(t))
307+
t.Cleanup(mcpSrv.Close)
308+
309+
// When: creating two MCP server proxies, both listing the same tools by name but under different server namespaces.
310+
proxy, err := mcp.NewStreamableHTTPServerProxy(logger, "coder", mcpSrv.URL, nil, nil, nil)
311+
require.NoError(t, err)
312+
proxy2, err := mcp.NewStreamableHTTPServerProxy(logger, "shmoder", mcpSrv.URL, nil, nil, nil)
313+
require.NoError(t, err)
314+
315+
// Then: initialize both proxies.
316+
require.NoError(t, proxy.Init(ctx))
317+
require.NoError(t, proxy2.Init(ctx))
318+
319+
// Then: validate that their tools are separately sorted stably.
320+
validateToolOrder(t, proxy)
321+
validateToolOrder(t, proxy2)
322+
323+
// When: creating a manager which contains both MCP server proxies.
324+
mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{
325+
"coder": proxy,
326+
"shmoder": proxy2,
327+
})
328+
require.NoError(t, mgr.Init(ctx))
329+
330+
// Then: the tools from both servers should be collectively sorted stably.
331+
validateToolOrder(t, mgr)
332+
}
333+
334+
func validateToolOrder(t *testing.T, proxy mcp.ServerProxier) {
335+
t.Helper()
336+
337+
tools := proxy.ListTools()
338+
require.NotEmpty(t, tools)
339+
require.Greater(t, len(tools), 1)
340+
341+
// Ensure tools are sorted by ID; unstable order can bust the cache and lead to increased costs.
342+
sorted := slices.Clone(tools)
343+
slices.SortFunc(sorted, func(a, b *mcp.Tool) int {
344+
return strings.Compare(a.ID, b.ID)
345+
})
346+
for i, tool := range tools {
347+
require.Equal(t, tool.ID, sorted[i].ID, "tool order is not stable")
348+
}
349+
}
350+
351+
func createMockMCPSrv(t *testing.T) http.Handler {
352+
t.Helper()
353+
354+
s := server.NewMCPServer(
355+
"Mock coder MCP server",
356+
"1.0.0",
357+
server.WithToolCapabilities(true),
358+
)
359+
360+
for _, name := range []string{"coder_list_workspaces", "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user"} {
361+
tool := mcplib.NewTool(name,
362+
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
363+
)
364+
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
365+
return mcplib.NewToolResultText("mock"), nil
366+
})
367+
}
368+
369+
return server.NewStreamableHTTPServer(s)
370+
}

mcp/proxy_streamable_http.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ func (p *StreamableHTTPServerProxy) Init(ctx context.Context) error {
8989
}
9090

9191
func (p *StreamableHTTPServerProxy) ListTools() []*Tool {
92-
return maps.Values(p.tools)
92+
tools := maps.Values(p.tools)
93+
slices.SortStableFunc(tools, func(a, b *Tool) int {
94+
return strings.Compare(a.ID, b.ID)
95+
})
96+
return tools
9397
}
9498

9599
func (p *StreamableHTTPServerProxy) GetTool(name string) *Tool {

mcp/server_proxy_manager.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package mcp
33
import (
44
"context"
55
"fmt"
6+
"slices"
7+
"strings"
68
"sync"
79

810
"github.com/coder/aibridge/utils"
@@ -82,6 +84,11 @@ func (s *ServerProxyManager) ListTools() []*Tool {
8284
for _, tool := range s.tools {
8385
out = append(out, tool)
8486
}
87+
88+
slices.SortStableFunc(out, func(a, b *Tool) int {
89+
return strings.Compare(a.ID, b.ID)
90+
})
91+
8592
return out
8693
}
8794

0 commit comments

Comments
 (0)