diff --git a/bridge_integration_test.go b/bridge_integration_test.go index c606340..5bf519e 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1131,6 +1131,122 @@ func TestErrorHandling(t *testing.T) { } } +// TestStableRequestEncoding validates that a given intercepted request and a +// given set of injected tools should result identical payloads. +// +// Should the payload vary, it may subvert any caching mechanisms the provider may have. +func TestStableRequestEncoding(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + + cases := []struct { + name string + fixture []byte + createRequestFunc createRequestFunc + configureFunc configureFunc + }{ + { + name: aibridge.ProviderAnthropic, + fixture: antSimple, + createRequestFunc: createAnthropicMessagesReq, + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr) + }, + }, + { + name: aibridge.ProviderOpenAI, + fixture: oaiSimple, + createRequestFunc: createOpenAIChatCompletionsReq, + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Setup MCP tools. + tools := setupMCPServerProxiesForTest(t) + + // Configure the bridge with injected tools. + mcpMgr := mcp.NewServerProxyManager(tools) + require.NoError(t, mcpMgr.Init(ctx)) + + arc := txtar.Parse(tc.fixture) + t.Logf("%s: %s", t.Name(), arc.Comment) + + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureNonStreamingResponse) + + var ( + reference []byte + reqCount atomic.Int32 + ) + + // Create a mock server that captures and compares request bodies. + mockSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqCount.Add(1) + + // Capture the raw request body. + raw, err := io.ReadAll(r.Body) + defer r.Body.Close() + require.NoError(t, err) + require.NotEmpty(t, raw) + + // Store the first instance as the reference value. + if reference == nil { + reference = raw + } else { + // Compare all subsequent requests to the reference. + assert.JSONEq(t, string(reference), string(raw)) + } + + // Return a valid API response. + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(files[fixtureNonStreamingResponse]) + })) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + mockSrv.Start() + t.Cleanup(mockSrv.Close) + + recorder := &mockRecorderClient{} + bridge, err := tc.configureFunc(mockSrv.URL, recorder, mcpMgr) + require.NoError(t, err) + + // Invoke request to mocked API via aibridge. + bridgeSrv := httptest.NewUnstartedServer(bridge) + bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, userID, nil) + } + bridgeSrv.Start() + t.Cleanup(bridgeSrv.Close) + + // Make multiple requests and verify they all have identical payloads. + count := 10 + for range count { + req := tc.createRequestFunc(t, bridgeSrv.URL, files[fixtureRequest]) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() + } + + require.EqualValues(t, count, reqCount.Load()) + }) + } +} + func calculateTotalInputTokens(in []*aibridge.TokenUsageRecord) int64 { var total int64 for _, el := range in { @@ -1340,12 +1456,14 @@ func createMockMCPSrv(t *testing.T) http.Handler { server.WithToolCapabilities(true), ) - tool := mcplib.NewTool(mockToolName, - mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", mockToolName)), - ) - s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { - return mcplib.NewToolResultText("mock"), nil - }) + for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} { + tool := mcplib.NewTool(name, + mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), + ) + s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + return mcplib.NewToolResultText("mock"), nil + }) + } return server.NewStreamableHTTPServer(s) } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index e1dcd8e..4d6e2d3 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -1,14 +1,25 @@ package mcp_test import ( + "context" + "fmt" + "net/http" + "net/http/httptest" "regexp" + "slices" + "strings" "testing" + "time" "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" "go.uber.org/goleak" "github.com/coder/aibridge/mcp" + "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/require" + + mcplib "github.com/mark3labs/mcp-go/mcp" ) func TestMain(m *testing.M) { @@ -282,3 +293,78 @@ func TestFilterAllowedTools(t *testing.T) { }) } } + +func TestToolInjectionOrder(t *testing.T) { + t.Parallel() + + // Setup. + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Given: a MCP mock server offering a set of tools. + mcpSrv := httptest.NewServer(createMockMCPSrv(t)) + t.Cleanup(mcpSrv.Close) + + // When: creating two MCP server proxies, both listing the same tools by name but under different server namespaces. + proxy, err := mcp.NewStreamableHTTPServerProxy(logger, "coder", mcpSrv.URL, nil, nil, nil) + require.NoError(t, err) + proxy2, err := mcp.NewStreamableHTTPServerProxy(logger, "shmoder", mcpSrv.URL, nil, nil, nil) + require.NoError(t, err) + + // Then: initialize both proxies. + require.NoError(t, proxy.Init(ctx)) + require.NoError(t, proxy2.Init(ctx)) + + // Then: validate that their tools are separately sorted stably. + validateToolOrder(t, proxy) + validateToolOrder(t, proxy2) + + // When: creating a manager which contains both MCP server proxies. + mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{ + "coder": proxy, + "shmoder": proxy2, + }) + require.NoError(t, mgr.Init(ctx)) + + // Then: the tools from both servers should be collectively sorted stably. + validateToolOrder(t, mgr) +} + +func validateToolOrder(t *testing.T, proxy mcp.ServerProxier) { + t.Helper() + + tools := proxy.ListTools() + require.NotEmpty(t, tools) + require.Greater(t, len(tools), 1) + + // Ensure tools are sorted by ID; unstable order can bust the cache and lead to increased costs. + sorted := slices.Clone(tools) + slices.SortFunc(sorted, func(a, b *mcp.Tool) int { + return strings.Compare(a.ID, b.ID) + }) + for i, tool := range tools { + require.Equal(t, tool.ID, sorted[i].ID, "tool order is not stable") + } +} + +func createMockMCPSrv(t *testing.T) http.Handler { + t.Helper() + + s := server.NewMCPServer( + "Mock coder MCP server", + "1.0.0", + server.WithToolCapabilities(true), + ) + + for _, name := range []string{"coder_list_workspaces", "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user"} { + tool := mcplib.NewTool(name, + mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), + ) + s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + return mcplib.NewToolResultText("mock"), nil + }) + } + + return server.NewStreamableHTTPServer(s) +} diff --git a/mcp/proxy_streamable_http.go b/mcp/proxy_streamable_http.go index eb2a7c7..5a5c092 100644 --- a/mcp/proxy_streamable_http.go +++ b/mcp/proxy_streamable_http.go @@ -89,7 +89,11 @@ func (p *StreamableHTTPServerProxy) Init(ctx context.Context) error { } func (p *StreamableHTTPServerProxy) ListTools() []*Tool { - return maps.Values(p.tools) + tools := maps.Values(p.tools) + slices.SortStableFunc(tools, func(a, b *Tool) int { + return strings.Compare(a.ID, b.ID) + }) + return tools } func (p *StreamableHTTPServerProxy) GetTool(name string) *Tool { diff --git a/mcp/server_proxy_manager.go b/mcp/server_proxy_manager.go index 2041433..732f1a0 100644 --- a/mcp/server_proxy_manager.go +++ b/mcp/server_proxy_manager.go @@ -3,6 +3,8 @@ package mcp import ( "context" "fmt" + "slices" + "strings" "sync" "github.com/coder/aibridge/utils" @@ -82,6 +84,11 @@ func (s *ServerProxyManager) ListTools() []*Tool { for _, tool := range s.tools { out = append(out, tool) } + + slices.SortStableFunc(out, func(a, b *Tool) int { + return strings.Compare(a.ID, b.ID) + }) + return out }