@@ -11,6 +11,7 @@ import (
1111 "net"
1212 "net/http"
1313 "net/http/httptest"
14+ "slices"
1415 "sync"
1516 "sync/atomic"
1617 "testing"
@@ -174,6 +175,8 @@ func TestAnthropicMessages(t *testing.T) {
174175
175176 require .Len (t , recorderClient .userPrompts , 1 )
176177 assert .Equal (t , "read the foo file" , recorderClient .userPrompts [0 ].Prompt )
178+
179+ recorderClient .verifyAllInterceptionsEnded (t )
177180 })
178181 }
179182 })
@@ -273,6 +276,8 @@ func TestOpenAIChatCompletions(t *testing.T) {
273276
274277 require .Len (t , recorderClient .userPrompts , 1 )
275278 assert .Equal (t , "how large is the README.md file in my current path" , recorderClient .userPrompts [0 ].Prompt )
279+
280+ recorderClient .verifyAllInterceptionsEnded (t )
276281 })
277282 }
278283 })
@@ -437,6 +442,8 @@ func TestSimple(t *testing.T) {
437442
438443 require .GreaterOrEqual (t , len (recorderClient .tokenUsages ), 1 )
439444 require .Equal (t , recorderClient .tokenUsages [0 ].MsgID , tc .expectedMsgID )
445+
446+ recorderClient .verifyAllInterceptionsEnded (t )
440447 })
441448 }
442449 })
@@ -574,8 +581,10 @@ func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier {
574581 return map [string ]mcp.ServerProxier {proxy .Name (): proxy }
575582}
576583
577- type configureFunc func (string , aibridge.Recorder , * mcp.ServerProxyManager ) (* aibridge.RequestBridge , error )
578- type createRequestFunc func (* testing.T , string , []byte ) * http.Request
584+ type (
585+ configureFunc func (string , aibridge.Recorder , * mcp.ServerProxyManager ) (* aibridge.RequestBridge , error )
586+ createRequestFunc func (* testing.T , string , []byte ) * http.Request
587+ )
579588
580589func TestAnthropicInjectedTools (t * testing.T ) {
581590 t .Parallel ()
@@ -953,6 +962,7 @@ func TestErrorHandling(t *testing.T) {
953962 require .NoError (t , err )
954963
955964 tc .responseHandlerFn (streaming , resp )
965+ recorderClient .verifyAllInterceptionsEnded (t )
956966 })
957967 }
958968 })
@@ -1097,10 +1107,11 @@ var _ aibridge.Recorder = &mockRecorderClient{}
10971107type mockRecorderClient struct {
10981108 mu sync.Mutex
10991109
1100- interceptions []* aibridge.InterceptionRecord
1101- tokenUsages []* aibridge.TokenUsageRecord
1102- userPrompts []* aibridge.PromptUsageRecord
1103- toolUsages []* aibridge.ToolUsageRecord
1110+ interceptions []* aibridge.InterceptionRecord
1111+ tokenUsages []* aibridge.TokenUsageRecord
1112+ userPrompts []* aibridge.PromptUsageRecord
1113+ toolUsages []* aibridge.ToolUsageRecord
1114+ interceptionsEnd map [string ]time.Time
11041115}
11051116
11061117func (m * mockRecorderClient ) RecordInterception (ctx context.Context , req * aibridge.InterceptionRecord ) error {
@@ -1110,6 +1121,19 @@ func (m *mockRecorderClient) RecordInterception(ctx context.Context, req *aibrid
11101121 return nil
11111122}
11121123
1124+ func (m * mockRecorderClient ) RecordInterceptionEnded (ctx context.Context , req * aibridge.InterceptionRecordEnded ) error {
1125+ m .mu .Lock ()
1126+ defer m .mu .Unlock ()
1127+ if m .interceptionsEnd == nil {
1128+ m .interceptionsEnd = make (map [string ]time.Time )
1129+ }
1130+ if ! slices .ContainsFunc (m .interceptions , func (intc * aibridge.InterceptionRecord ) bool { return intc .ID == req .ID }) {
1131+ return fmt .Errorf ("id not found" )
1132+ }
1133+ m .interceptionsEnd [req .ID ] = req .EndedAt
1134+ return nil
1135+ }
1136+
11131137func (m * mockRecorderClient ) RecordPromptUsage (ctx context.Context , req * aibridge.PromptUsageRecord ) error {
11141138 m .mu .Lock ()
11151139 defer m .mu .Unlock ()
@@ -1131,6 +1155,18 @@ func (m *mockRecorderClient) RecordToolUsage(ctx context.Context, req *aibridge.
11311155 return nil
11321156}
11331157
1158+ // verify all recorded interceptions has been marked as completed
1159+ func (m * mockRecorderClient ) verifyAllInterceptionsEnded (t * testing.T ) {
1160+ t .Helper ()
1161+
1162+ m .mu .Lock ()
1163+ defer m .mu .Unlock ()
1164+ require .Equalf (t , len (m .interceptions ), len (m .interceptionsEnd ), "got %v interception ended calls, want: %v" , len (m .interceptionsEnd ), len (m .interceptions ))
1165+ for _ , intc := range m .interceptions {
1166+ require .Containsf (t , m .interceptionsEnd , intc .ID , "interception with id: %v has not been ended" , intc .ID )
1167+ }
1168+ }
1169+
11341170const mockToolName = "coder_list_workspaces"
11351171
11361172func createMockMCPSrv (t * testing.T ) http.Handler {
0 commit comments