Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 124 additions & 6 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,122 @@ func TestErrorHandling(t *testing.T) {
}
}

// TestStableRequestEncoding validates that a given intercepted request and a
// given set of injected tools should result identical payloads.
//
// Should the payload vary, it may subvert any caching mechanisms the provider may have.
func TestStableRequestEncoding(t *testing.T) {
t.Parallel()

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)

cases := []struct {
name string
fixture []byte
createRequestFunc createRequestFunc
configureFunc configureFunc
}{
{
name: aibridge.ProviderAnthropic,
fixture: antSimple,
createRequestFunc: createAnthropicMessagesReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
},
},
{
name: aibridge.ProviderOpenAI,
fixture: oaiSimple,
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
},
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

// Setup MCP tools.
tools := setupMCPServerProxiesForTest(t)

// Configure the bridge with injected tools.
mcpMgr := mcp.NewServerProxyManager(tools)
require.NoError(t, mcpMgr.Init(ctx))

arc := txtar.Parse(tc.fixture)
t.Logf("%s: %s", t.Name(), arc.Comment)

files := filesMap(arc)
require.Contains(t, files, fixtureRequest)
require.Contains(t, files, fixtureNonStreamingResponse)

var (
reference []byte
reqCount atomic.Int32
)

// Create a mock server that captures and compares request bodies.
mockSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqCount.Add(1)

// Capture the raw request body.
raw, err := io.ReadAll(r.Body)
defer r.Body.Close()
require.NoError(t, err)
require.NotEmpty(t, raw)

// Store the first instance as the reference value.
if reference == nil {
reference = raw
} else {
// Compare all subsequent requests to the reference.
assert.JSONEq(t, string(reference), string(raw))
}

// Return a valid API response.
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(files[fixtureNonStreamingResponse])
}))
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
return ctx
}
mockSrv.Start()
t.Cleanup(mockSrv.Close)

recorder := &mockRecorderClient{}
bridge, err := tc.configureFunc(mockSrv.URL, recorder, mcpMgr)
require.NoError(t, err)

// Invoke request to mocked API via aibridge.
bridgeSrv := httptest.NewUnstartedServer(bridge)
bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
return aibridge.AsActor(ctx, userID, nil)
}
bridgeSrv.Start()
t.Cleanup(bridgeSrv.Close)

// Make multiple requests and verify they all have identical payloads.
count := 10
for range count {
req := tc.createRequestFunc(t, bridgeSrv.URL, files[fixtureRequest])
client := &http.Client{}
resp, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
_ = resp.Body.Close()
}

require.EqualValues(t, count, reqCount.Load())
})
}
}

func calculateTotalInputTokens(in []*aibridge.TokenUsageRecord) int64 {
var total int64
for _, el := range in {
Expand Down Expand Up @@ -1340,12 +1456,14 @@ func createMockMCPSrv(t *testing.T) http.Handler {
server.WithToolCapabilities(true),
)

tool := mcplib.NewTool(mockToolName,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", mockToolName)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
return mcplib.NewToolResultText("mock"), nil
})
for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} {
tool := mcplib.NewTool(name,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
return mcplib.NewToolResultText("mock"), nil
})
}

return server.NewStreamableHTTPServer(s)
}
Expand Down
86 changes: 86 additions & 0 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
package mcp_test

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"regexp"
"slices"
"strings"
"testing"
"time"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"go.uber.org/goleak"

"github.com/coder/aibridge/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/require"

mcplib "github.com/mark3labs/mcp-go/mcp"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -282,3 +293,78 @@ func TestFilterAllowedTools(t *testing.T) {
})
}
}

func TestToolInjectionOrder(t *testing.T) {
t.Parallel()

// Setup.
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

// Given: a MCP mock server offering a set of tools.
mcpSrv := httptest.NewServer(createMockMCPSrv(t))
t.Cleanup(mcpSrv.Close)

// When: creating two MCP server proxies, both listing the same tools by name but under different server namespaces.
proxy, err := mcp.NewStreamableHTTPServerProxy(logger, "coder", mcpSrv.URL, nil, nil, nil)
require.NoError(t, err)
proxy2, err := mcp.NewStreamableHTTPServerProxy(logger, "shmoder", mcpSrv.URL, nil, nil, nil)
require.NoError(t, err)

// Then: initialize both proxies.
require.NoError(t, proxy.Init(ctx))
require.NoError(t, proxy2.Init(ctx))

// Then: validate that their tools are separately sorted stably.
validateToolOrder(t, proxy)
validateToolOrder(t, proxy2)

// When: creating a manager which contains both MCP server proxies.
mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{
"coder": proxy,
"shmoder": proxy2,
})
require.NoError(t, mgr.Init(ctx))

// Then: the tools from both servers should be collectively sorted stably.
validateToolOrder(t, mgr)
}

func validateToolOrder(t *testing.T, proxy mcp.ServerProxier) {
t.Helper()

tools := proxy.ListTools()
require.NotEmpty(t, tools)
require.Greater(t, len(tools), 1)

// Ensure tools are sorted by ID; unstable order can bust the cache and lead to increased costs.
sorted := slices.Clone(tools)
slices.SortFunc(sorted, func(a, b *mcp.Tool) int {
return strings.Compare(a.ID, b.ID)
})
for i, tool := range tools {
require.Equal(t, tool.ID, sorted[i].ID, "tool order is not stable")
}
}

func createMockMCPSrv(t *testing.T) http.Handler {
t.Helper()

s := server.NewMCPServer(
"Mock coder MCP server",
"1.0.0",
server.WithToolCapabilities(true),
)

for _, name := range []string{"coder_list_workspaces", "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user"} {
tool := mcplib.NewTool(name,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
return mcplib.NewToolResultText("mock"), nil
})
}

return server.NewStreamableHTTPServer(s)
}
6 changes: 5 additions & 1 deletion mcp/proxy_streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ func (p *StreamableHTTPServerProxy) Init(ctx context.Context) error {
}

func (p *StreamableHTTPServerProxy) ListTools() []*Tool {
return maps.Values(p.tools)
tools := maps.Values(p.tools)
slices.SortStableFunc(tools, func(a, b *Tool) int {
return strings.Compare(a.ID, b.ID)
})
return tools
}

func (p *StreamableHTTPServerProxy) GetTool(name string) *Tool {
Expand Down
7 changes: 7 additions & 0 deletions mcp/server_proxy_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package mcp
import (
"context"
"fmt"
"slices"
"strings"
"sync"

"github.com/coder/aibridge/utils"
Expand Down Expand Up @@ -82,6 +84,11 @@ func (s *ServerProxyManager) ListTools() []*Tool {
for _, tool := range s.tools {
out = append(out, tool)
}

slices.SortStableFunc(out, func(a, b *Tool) int {
return strings.Compare(a.ID, b.ID)
})

return out
}

Expand Down