Skip to content

Commit 76365a6

Browse files
authored
fix: use io.LimitReader in key fetcher (#558)
## Motivation and Context Replace server-only `http.MaxBytesReader(nil, ...)` with `io.LimitReader` to cap the fetched body to 4KB in key fetcher. This avoids passing a `nil` ResponseWriter and aligns with Go idioms. If more than 4KB is received, return a clear error. Previously `MaxBytesReader` returned a misleading `http: request body too large` error, which is server-oriented. With this change the error becomes `HTTP auth key response too large`. ## How Has This Been Tested? <!-- Have you tested this in a real application? Which scenarios were tested? --> Add tests covering success, oversized bodies, connection failure, non-200 status, and read failures. This covers almost all lines of `FetchKey` method of `DefaultHTTPKeyFetcher`. For tests, I also added a new initialiser `NewDefaultHTTPKeyFetcherWithClient` so that the HTTP client could be modified for test use. Tests use `httptest.NewTLSServer` plus a helper to construct a client pinned to the test server. `auth` package test coverage increased from 56.8% to 59.3%. ## Breaking Changes <!-- Will users need to update their code or configurations? --> None. ## Types of changes <!-- What types of changes does your code introduce? Put an `x` in all the boxes that apply: --> - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update ## Checklist <!-- Go over all the following points, and put an `x` in all the boxes that apply. --> - [x] I have read the [MCP Documentation](https://modelcontextprotocol.io) - [x] My code follows the repository's style guidelines - [x] New and existing tests pass locally - [x] I have added appropriate error handling - [ ] I have added or updated documentation as needed ## Additional context <!-- Add any other context, implementation notes, or design decisions --> Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
1 parent f0cce85 commit 76365a6

File tree

2 files changed

+159
-4
lines changed

2 files changed

+159
-4
lines changed

internal/api/handlers/v0/auth/http.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ import (
1818
"github.com/modelcontextprotocol/registry/internal/config"
1919
)
2020

21+
// MaxKeyResponseSize is the maximum size of the response body from the HTTP endpoint.
22+
const MaxKeyResponseSize = 4096
23+
2124
// HTTPTokenExchangeInput represents the input for HTTP-based authentication
2225
type HTTPTokenExchangeInput struct {
2326
Body struct {
@@ -51,6 +54,12 @@ func NewDefaultHTTPKeyFetcher() *DefaultHTTPKeyFetcher {
5154
}
5255
}
5356

57+
// NewDefaultHTTPKeyFetcherWithClient creates a new HTTP key fetcher with a custom HTTP client.
58+
// This is primarily useful in tests to inject transports or TLS settings.
59+
func NewDefaultHTTPKeyFetcherWithClient(client *http.Client) *DefaultHTTPKeyFetcher {
60+
return &DefaultHTTPKeyFetcher{client: client}
61+
}
62+
5463
// FetchKey fetches the public key from the well-known HTTP endpoint
5564
func (f *DefaultHTTPKeyFetcher) FetchKey(ctx context.Context, domain string) (string, error) {
5665
url := fmt.Sprintf("https://%s/.well-known/mcp-registry-auth", domain)
@@ -73,13 +82,16 @@ func (f *DefaultHTTPKeyFetcher) FetchKey(ctx context.Context, domain string) (st
7382
return "", fmt.Errorf("HTTP %d: failed to fetch key from %s", resp.StatusCode, url)
7483
}
7584

76-
// Limit response size to prevent DoS attacks
77-
resp.Body = http.MaxBytesReader(nil, resp.Body, 4096)
78-
79-
body, err := io.ReadAll(resp.Body)
85+
// Limit response size to prevent DoS attacks.
86+
// Read up to MaxKeyResponseSize+1 and error if exceeded.
87+
limited := io.LimitReader(resp.Body, MaxKeyResponseSize+1)
88+
body, err := io.ReadAll(limited)
8089
if err != nil {
8190
return "", fmt.Errorf("failed to read response body: %w", err)
8291
}
92+
if len(body) > MaxKeyResponseSize {
93+
return "", fmt.Errorf("HTTP auth key response too large")
94+
}
8395

8496
return strings.TrimSpace(string(body)), nil
8597
}

internal/api/handlers/v0/auth/http_test.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
package auth_test
22

33
import (
4+
"bytes"
45
"context"
56
"crypto/ed25519"
7+
"crypto/tls"
68
"encoding/base64"
79
"encoding/hex"
810
"fmt"
11+
"net"
12+
"net/http"
13+
"net/http/httptest"
914
"strings"
1015
"testing"
1116
"time"
@@ -18,6 +23,26 @@ import (
1823
"github.com/modelcontextprotocol/registry/internal/config"
1924
)
2025

26+
const wellKnownPath = "/.well-known/mcp-registry-auth"
27+
28+
func newClientForTLSServer(t *testing.T, srv *httptest.Server) *http.Client {
29+
t.Helper()
30+
31+
dialAddr := srv.Listener.Addr().String()
32+
transport := &http.Transport{
33+
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // testing only
34+
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
35+
d := &net.Dialer{}
36+
return d.DialContext(ctx, network, dialAddr)
37+
},
38+
ForceAttemptHTTP2: false,
39+
MaxIdleConns: 10,
40+
IdleConnTimeout: 30 * time.Second,
41+
TLSHandshakeTimeout: 5 * time.Second,
42+
}
43+
return &http.Client{Transport: transport, Timeout: 10 * time.Second}
44+
}
45+
2146
// MockHTTPKeyFetcher for testing
2247
type MockHTTPKeyFetcher struct {
2348
keyResponses map[string]string
@@ -241,3 +266,121 @@ func TestDefaultHTTPKeyFetcher_FetchKey(t *testing.T) {
241266
_, err := fetcher.FetchKey(context.Background(), "nonexistent-test-domain-12345.com")
242267
assert.Error(t, err)
243268
}
269+
270+
func TestDefaultHTTPKeyFetcher(t *testing.T) {
271+
tests := []struct {
272+
name string
273+
handler http.HandlerFunc
274+
wantErrSub string
275+
customClient *http.Client
276+
expectOK bool
277+
wantBody string
278+
}{
279+
{
280+
name: "oversized body",
281+
handler: func(w http.ResponseWriter, r *http.Request) {
282+
if r.URL.Path != wellKnownPath {
283+
w.WriteHeader(http.StatusNotFound)
284+
return
285+
}
286+
w.WriteHeader(http.StatusOK)
287+
w.Header().Set("Content-Type", "text/plain")
288+
_, _ = w.Write(bytes.Repeat([]byte("A"), 6000))
289+
},
290+
wantErrSub: "too large",
291+
},
292+
{
293+
name: "non-OK status",
294+
handler: func(w http.ResponseWriter, r *http.Request) {
295+
if r.URL.Path == wellKnownPath {
296+
w.WriteHeader(http.StatusInternalServerError)
297+
return
298+
}
299+
w.WriteHeader(http.StatusNotFound)
300+
},
301+
wantErrSub: "HTTP 500",
302+
},
303+
{
304+
name: "connection failure",
305+
handler: nil,
306+
customClient: &http.Client{
307+
Transport: &http.Transport{
308+
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // testing only
309+
DialContext: func(context.Context, string, string) (net.Conn, error) {
310+
return nil, fmt.Errorf("dial blocked")
311+
},
312+
},
313+
Timeout: 5 * time.Second,
314+
},
315+
wantErrSub: "failed to fetch key",
316+
},
317+
{
318+
name: "response body failure",
319+
handler: func(w http.ResponseWriter, r *http.Request) {
320+
if r.URL.Path != wellKnownPath {
321+
w.WriteHeader(http.StatusNotFound)
322+
return
323+
}
324+
w.Header().Set("Content-Type", "text/plain")
325+
w.Header().Set("Content-Length", "100")
326+
w.WriteHeader(http.StatusOK)
327+
_, _ = w.Write([]byte("PARTIAL"))
328+
},
329+
wantErrSub: "failed to read response body",
330+
},
331+
{
332+
name: "success",
333+
handler: func(w http.ResponseWriter, r *http.Request) {
334+
if r.URL.Path != wellKnownPath {
335+
w.WriteHeader(http.StatusNotFound)
336+
return
337+
}
338+
w.Header().Set("Content-Type", "text/plain")
339+
w.WriteHeader(http.StatusOK)
340+
_, _ = w.Write([]byte("response"))
341+
},
342+
expectOK: true,
343+
wantBody: "response",
344+
},
345+
}
346+
347+
for _, tt := range tests {
348+
t.Run(tt.name, func(t *testing.T) {
349+
if tt.handler != nil {
350+
srv := httptest.NewTLSServer(tt.handler)
351+
defer srv.Close()
352+
c := newClientForTLSServer(t, srv)
353+
f := auth.NewDefaultHTTPKeyFetcherWithClient(c)
354+
got, err := f.FetchKey(context.Background(), "example.com")
355+
if tt.expectOK {
356+
if err != nil {
357+
t.Fatalf("unexpected error: %v", err)
358+
}
359+
if got != tt.wantBody {
360+
t.Fatalf("unexpected body: got %q want %q", got, tt.wantBody)
361+
}
362+
return
363+
}
364+
if err == nil || !strings.Contains(err.Error(), tt.wantErrSub) {
365+
t.Fatalf("got err=%v, want substring %q", err, tt.wantErrSub)
366+
}
367+
return
368+
}
369+
370+
f := auth.NewDefaultHTTPKeyFetcherWithClient(tt.customClient)
371+
got, err := f.FetchKey(context.Background(), "example.com")
372+
if tt.expectOK {
373+
if err != nil {
374+
t.Fatalf("unexpected error: %v", err)
375+
}
376+
if got != tt.wantBody {
377+
t.Fatalf("unexpected body: got %q want %q", got, tt.wantBody)
378+
}
379+
return
380+
}
381+
if err == nil || !strings.Contains(err.Error(), tt.wantErrSub) {
382+
t.Fatalf("got err=%v, want substring %q", err, tt.wantErrSub)
383+
}
384+
})
385+
}
386+
}

0 commit comments

Comments
 (0)