Skip to content

Commit 70248cc

Browse files
authored
feat: mark interceptions as completed (#43)
Adds logic that marks interceptions as completed.
1 parent 4754b1e commit 70248cc

File tree

6 files changed

+86
-10
lines changed

6 files changed

+86
-10
lines changed

api.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ type InterceptionRecord struct {
1616
StartedAt time.Time
1717
}
1818

19+
type InterceptionRecordEnded struct {
20+
ID string
21+
EndedAt time.Time
22+
}
23+
1924
type TokenUsageRecord struct {
2025
InterceptionID string
2126
MsgID string
@@ -48,6 +53,8 @@ type ToolUsageRecord struct {
4853
type Recorder interface {
4954
// RecordInterception records metadata about an interception with an upstream AI provider.
5055
RecordInterception(ctx context.Context, req *InterceptionRecord) error
56+
// RecordInterceptionEnded records that given interception has completed.
57+
RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error
5158
// RecordTokenUsage records the tokens used in an interception with an upstream AI provider.
5259
RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) error
5360
// RecordPromptUsage records the prompts used in an interception with an upstream AI provider.

bridge_integration_test.go

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"net"
1212
"net/http"
1313
"net/http/httptest"
14+
"slices"
1415
"sync"
1516
"sync/atomic"
1617
"testing"
@@ -174,6 +175,8 @@ func TestAnthropicMessages(t *testing.T) {
174175

175176
require.Len(t, recorderClient.userPrompts, 1)
176177
assert.Equal(t, "read the foo file", recorderClient.userPrompts[0].Prompt)
178+
179+
recorderClient.verifyAllInterceptionsEnded(t)
177180
})
178181
}
179182
})
@@ -273,6 +276,8 @@ func TestOpenAIChatCompletions(t *testing.T) {
273276

274277
require.Len(t, recorderClient.userPrompts, 1)
275278
assert.Equal(t, "how large is the README.md file in my current path", recorderClient.userPrompts[0].Prompt)
279+
280+
recorderClient.verifyAllInterceptionsEnded(t)
276281
})
277282
}
278283
})
@@ -437,6 +442,8 @@ func TestSimple(t *testing.T) {
437442

438443
require.GreaterOrEqual(t, len(recorderClient.tokenUsages), 1)
439444
require.Equal(t, recorderClient.tokenUsages[0].MsgID, tc.expectedMsgID)
445+
446+
recorderClient.verifyAllInterceptionsEnded(t)
440447
})
441448
}
442449
})
@@ -574,8 +581,10 @@ func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier {
574581
return map[string]mcp.ServerProxier{proxy.Name(): proxy}
575582
}
576583

577-
type configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error)
578-
type createRequestFunc func(*testing.T, string, []byte) *http.Request
584+
type (
585+
configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error)
586+
createRequestFunc func(*testing.T, string, []byte) *http.Request
587+
)
579588

580589
func TestAnthropicInjectedTools(t *testing.T) {
581590
t.Parallel()
@@ -953,6 +962,7 @@ func TestErrorHandling(t *testing.T) {
953962
require.NoError(t, err)
954963

955964
tc.responseHandlerFn(streaming, resp)
965+
recorderClient.verifyAllInterceptionsEnded(t)
956966
})
957967
}
958968
})
@@ -1097,10 +1107,11 @@ var _ aibridge.Recorder = &mockRecorderClient{}
10971107
type mockRecorderClient struct {
10981108
mu sync.Mutex
10991109

1100-
interceptions []*aibridge.InterceptionRecord
1101-
tokenUsages []*aibridge.TokenUsageRecord
1102-
userPrompts []*aibridge.PromptUsageRecord
1103-
toolUsages []*aibridge.ToolUsageRecord
1110+
interceptions []*aibridge.InterceptionRecord
1111+
tokenUsages []*aibridge.TokenUsageRecord
1112+
userPrompts []*aibridge.PromptUsageRecord
1113+
toolUsages []*aibridge.ToolUsageRecord
1114+
interceptionsEnd map[string]time.Time
11041115
}
11051116

11061117
func (m *mockRecorderClient) RecordInterception(ctx context.Context, req *aibridge.InterceptionRecord) error {
@@ -1110,6 +1121,19 @@ func (m *mockRecorderClient) RecordInterception(ctx context.Context, req *aibrid
11101121
return nil
11111122
}
11121123

1124+
func (m *mockRecorderClient) RecordInterceptionEnded(ctx context.Context, req *aibridge.InterceptionRecordEnded) error {
1125+
m.mu.Lock()
1126+
defer m.mu.Unlock()
1127+
if m.interceptionsEnd == nil {
1128+
m.interceptionsEnd = make(map[string]time.Time)
1129+
}
1130+
if !slices.ContainsFunc(m.interceptions, func(intc *aibridge.InterceptionRecord) bool { return intc.ID == req.ID }) {
1131+
return fmt.Errorf("id not found")
1132+
}
1133+
m.interceptionsEnd[req.ID] = req.EndedAt
1134+
return nil
1135+
}
1136+
11131137
func (m *mockRecorderClient) RecordPromptUsage(ctx context.Context, req *aibridge.PromptUsageRecord) error {
11141138
m.mu.Lock()
11151139
defer m.mu.Unlock()
@@ -1131,6 +1155,18 @@ func (m *mockRecorderClient) RecordToolUsage(ctx context.Context, req *aibridge.
11311155
return nil
11321156
}
11331157

1158+
// verify all recorded interceptions has been marked as completed
1159+
func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) {
1160+
t.Helper()
1161+
1162+
m.mu.Lock()
1163+
defer m.mu.Unlock()
1164+
require.Equalf(t, len(m.interceptions), len(m.interceptionsEnd), "got %v interception ended calls, want: %v", len(m.interceptionsEnd), len(m.interceptions))
1165+
for _, intc := range m.interceptions {
1166+
require.Containsf(t, m.interceptionsEnd, intc.ID, "interception with id: %v has not been ended", intc.ID)
1167+
}
1168+
}
1169+
11341170
const mockToolName = "coder_list_workspaces"
11351171

11361172
func createMockMCPSrv(t *testing.T) http.Handler {

interception.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package aibridge
22

33
import (
4-
"context"
54
"errors"
65
"fmt"
76
"net/http"
@@ -67,10 +66,13 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder,
6766

6867
log := logger.With(slog.F("route", r.URL.Path), slog.F("provider", p.Name()), slog.F("interception_id", interceptor.ID()))
6968

70-
log.Debug(context.Background(), "started interception")
69+
log.Debug(r.Context(), "interception started")
7170
if err := interceptor.ProcessRequest(w, r); err != nil {
7271
log.Warn(r.Context(), "interception failed", slog.Error(err))
72+
} else {
73+
log.Debug(r.Context(), "interception ended")
7374
}
75+
asyncRecorder.RecordInterceptionEnded(r.Context(), &InterceptionRecordEnded{ID: interceptor.ID().String()})
7476

7577
// Ensure all recording have completed before completing request.
7678
asyncRecorder.Wait()

mcp/client_info.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ func GetClientInfo() mcp.Implementation {
1212
Name: "coder/aibridge",
1313
Version: buildinfo.Version(),
1414
}
15-
}
15+
}

mcp/client_info_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ func TestGetClientInfo(t *testing.T) {
1414
assert.NotEmpty(t, info.Version)
1515
// Version will either be a git revision, a semantic version, or a combination
1616
assert.NotEqual(t, "", info.Version)
17-
}
17+
}

recorder.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,21 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept
3333
return err
3434
}
3535

36+
func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error {
37+
client, err := r.clientFn()
38+
if err != nil {
39+
return fmt.Errorf("acquire client: %w", err)
40+
}
41+
42+
req.EndedAt = time.Now().UTC()
43+
if err = client.RecordInterceptionEnded(ctx, req); err == nil {
44+
return nil
45+
}
46+
47+
r.logger.Warn(ctx, "failed to record that interception ended", slog.Error(err), slog.F("interception_id", req.ID))
48+
return err
49+
}
50+
3651
func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) error {
3752
client, err := r.clientFn()
3853
if err != nil {
@@ -103,6 +118,22 @@ func (a *AsyncRecorder) RecordInterception(ctx context.Context, req *Interceptio
103118
panic("RecordInterception must not be called asynchronously")
104119
}
105120

121+
func (a *AsyncRecorder) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error {
122+
a.wg.Add(1)
123+
go func() {
124+
defer a.wg.Done()
125+
timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout)
126+
defer cancel()
127+
128+
err := a.wrapped.RecordInterceptionEnded(timedCtx, req)
129+
if err != nil {
130+
a.logger.Warn(timedCtx, "failed to record interception end", slog.F("type", "prompt"), slog.Error(err), slog.F("payload", req))
131+
}
132+
}()
133+
134+
return nil // Caller is not interested in error.
135+
}
136+
106137
func (a *AsyncRecorder) RecordPromptUsage(_ context.Context, req *PromptUsageRecord) error {
107138
a.wg.Add(1)
108139
go func() {

0 commit comments

Comments
 (0)