Skip to content

Commit 6e297fb

Browse files
committed
optimize(transport): optimize resumable transport implementation
1 parent 2600811 commit 6e297fb

File tree

2 files changed

+153
-82
lines changed

2 files changed

+153
-82
lines changed

pkg/v1/remote/transport/resumable.go

Lines changed: 65 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,51 +22,70 @@ func NewResumable(inner http.RoundTripper) http.RoundTripper {
2222

2323
var (
2424
contentRangeRe = regexp.MustCompile(`^bytes (\d+)-(\d+)/(\d+|\*)$`)
25+
rangeRe = regexp.MustCompile(`bytes=(\d+)-(\d+)?`)
2526
)
2627

2728
type resumableTransport struct {
2829
inner http.RoundTripper
2930
}
3031

31-
func (rt *resumableTransport) RoundTrip(in *http.Request) (*http.Response, error) {
32-
if in.Method != http.MethodGet {
33-
return rt.inner.RoundTrip(in)
32+
func (rt *resumableTransport) RoundTrip(in *http.Request) (resp *http.Response, err error) {
33+
var total, start, end int64
34+
// check initial request, maybe resumable transport is already enabled
35+
if contentRange := in.Header.Get("Range"); contentRange != "" {
36+
if matches := rangeRe.FindStringSubmatch(contentRange); len(matches) == 3 {
37+
if start, err = strconv.ParseInt(matches[1], 10, 64); err != nil {
38+
return nil, fmt.Errorf("invalid content range %q: %w", contentRange, err)
39+
}
40+
41+
if len(matches[2]) == 0 {
42+
// request whole file
43+
end = -1
44+
} else if end, err = strconv.ParseInt(matches[2], 10, 64); err == nil {
45+
if start > end {
46+
return nil, fmt.Errorf("invalid content range %q", contentRange)
47+
}
48+
} else {
49+
return nil, fmt.Errorf("invalid content range %q: %w", contentRange, err)
50+
}
51+
}
3452
}
3553

36-
req := in.Clone(in.Context())
37-
req.Header.Set("Range", "bytes=0-")
38-
resp, err := rt.inner.RoundTrip(req)
39-
if err != nil {
54+
if resp, err = rt.inner.RoundTrip(in); err != nil {
4055
return resp, err
4156
}
4257

58+
if in.Method != http.MethodGet {
59+
return resp, nil
60+
}
61+
4362
switch resp.StatusCode {
63+
case http.StatusOK:
64+
if end != 0 {
65+
// a range was requested, but the status code is unexpected; this request cannot be resumed
66+
return resp, nil
67+
}
68+
69+
total = resp.ContentLength
4470
case http.StatusPartialContent:
45-
case http.StatusRequestedRangeNotSatisfiable:
46-
// fallback to previous behavior
47-
resp.Body.Close()
48-
return rt.inner.RoundTrip(in)
71+
// keep original response status code, which should be processed by original transport or operation
72+
if start, _, total, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil || total <= 0 {
73+
return resp, nil
74+
} else if end > 0 {
75+
total = end + 1
76+
}
4977
default:
5078
return resp, nil
5179
}
5280

53-
var contentLength int64
54-
if _, _, contentLength, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil || contentLength <= 0 {
55-
// fallback to previous behavior
56-
resp.Body.Close()
57-
return rt.inner.RoundTrip(in)
58-
}
59-
60-
// modify response status to 200, ensure caller error checking works
61-
resp.StatusCode = http.StatusOK
62-
resp.Status = "200 OK"
63-
resp.ContentLength = contentLength
64-
resp.Body = &resumableBody{
65-
rc: resp.Body,
66-
inner: rt.inner,
67-
req: req,
68-
total: contentLength,
69-
transferred: 0,
81+
if total > 0 {
82+
resp.Body = &resumableBody{
83+
rc: resp.Body,
84+
inner: rt.inner,
85+
req: in,
86+
total: total,
87+
transferred: start,
88+
}
7089
}
7190

7291
return resp, nil
@@ -94,14 +113,18 @@ func (rb *resumableBody) Read(p []byte) (n int, err error) {
94113

95114
for {
96115
if n, err = rb.rc.Read(p); n > 0 {
116+
if rb.transferred+int64(n) >= rb.total {
117+
n = int(rb.total - rb.transferred)
118+
err = io.EOF
119+
}
97120
rb.transferred += int64(n)
98121
}
99122

100123
if err == nil {
101124
return
102125
}
103126

104-
if errors.Is(err, io.EOF) && rb.total >= 0 && rb.transferred == rb.total {
127+
if errors.Is(err, io.EOF) && rb.total >= 0 && rb.transferred >= rb.total {
105128
return
106129
}
107130

@@ -148,7 +171,8 @@ func (rb *resumableBody) resume(reason error) error {
148171

149172
if err = rb.validate(resp); err != nil {
150173
resp.Body.Close()
151-
return err
174+
// wraps original error
175+
return fmt.Errorf("%w, %v", reason, err)
152176
}
153177

154178
if atomic.LoadUint32(&rb.closed) == 1 {
@@ -162,21 +186,21 @@ func (rb *resumableBody) resume(reason error) error {
162186
return nil
163187
}
164188

189+
const size100m = 100 << 20
190+
165191
func (rb *resumableBody) validate(resp *http.Response) (err error) {
166192
var start, total int64
167193
switch resp.StatusCode {
168194
case http.StatusPartialContent:
169-
if start, _, total, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil {
195+
// do not use the total size from the Content-Range header; keep rb.total unchanged
196+
if start, _, _, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil {
170197
return err
171198
}
172199

173-
if total > rb.total {
174-
rb.total = total
175-
}
176-
177200
if start == rb.transferred {
178201
break
179202
} else if start < rb.transferred {
203+
// incoming data overlaps what was already transferred; just discard the overlap
180204
if _, err := io.CopyN(io.Discard, resp.Body, rb.transferred-start); err != nil {
181205
return fmt.Errorf("discard overlapped data failed, %v", err)
182206
}
@@ -185,6 +209,12 @@ func (rb *resumableBody) validate(resp *http.Response) (err error) {
185209
}
186210
case http.StatusOK:
187211
if rb.transferred > 0 {
212+
// range is not supported, and transferred data is too large, stop resuming
213+
if rb.transferred > size100m {
214+
return fmt.Errorf("too large data transferred: %d", rb.transferred)
215+
}
216+
217+
// try to resume even though the range request is unsupported
188218
if _, err = io.CopyN(io.Discard, resp.Body, rb.transferred); err != nil {
189219
return err
190220
}

pkg/v1/remote/transport/resumable_test.go

Lines changed: 88 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package transport
22

33
import (
4+
"bytes"
45
"context"
56
"crypto/sha256"
67
"encoding/hex"
@@ -11,7 +12,6 @@ import (
1112
"net/http/httptest"
1213
"net/url"
1314
"os"
14-
"regexp"
1515
"strconv"
1616
"testing"
1717
"time"
@@ -23,50 +23,52 @@ import (
2323
"github.com/google/go-containerregistry/pkg/v1/types"
2424
)
2525

26-
var rangeRe = regexp.MustCompile(`bytes=(\d+)-(\d+)?`)
27-
2826
func handleResumableLayer(data []byte, w http.ResponseWriter, r *http.Request, t *testing.T) {
2927
if r.Method != http.MethodGet {
3028
w.WriteHeader(http.StatusMethodNotAllowed)
3129
return
3230
}
3331

34-
contentRange := r.Header.Get("Range")
35-
if contentRange == "" {
36-
w.WriteHeader(http.StatusBadRequest)
37-
return
38-
}
39-
40-
matches := rangeRe.FindStringSubmatch(contentRange)
41-
if len(matches) != 3 {
42-
w.WriteHeader(http.StatusBadRequest)
43-
return
44-
}
45-
46-
contentLength := int64(len(data))
47-
start, err := strconv.ParseInt(matches[1], 10, 64)
48-
if err != nil || start < 0 {
49-
w.WriteHeader(http.StatusBadRequest)
50-
return
51-
}
32+
var (
33+
contentLength, start, end int64
34+
statusCode = http.StatusOK
35+
err error
36+
)
5237

53-
if start >= int64(contentLength) {
54-
w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
55-
return
56-
}
38+
contentLength = int64(len(data))
39+
end = contentLength - 1
40+
contentRange := r.Header.Get("Range")
41+
if contentRange != "" {
42+
matches := rangeRe.FindStringSubmatch(contentRange)
43+
if len(matches) != 3 {
44+
w.WriteHeader(http.StatusBadRequest)
45+
return
46+
}
5747

58-
var end = int64(contentLength) - 1
59-
if matches[2] != "" {
60-
end, err = strconv.ParseInt(matches[2], 10, 64)
61-
if err != nil || end < 0 {
48+
if start, err = strconv.ParseInt(matches[1], 10, 64); err != nil || start < 0 {
6249
w.WriteHeader(http.StatusBadRequest)
6350
return
6451
}
6552

66-
if end >= int64(contentLength) {
53+
if start >= int64(contentLength) {
6754
w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
6855
return
6956
}
57+
58+
if matches[2] != "" {
59+
end, err = strconv.ParseInt(matches[2], 10, 64)
60+
if err != nil || end < 0 {
61+
w.WriteHeader(http.StatusBadRequest)
62+
return
63+
}
64+
65+
if end >= int64(contentLength) {
66+
w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
67+
return
68+
}
69+
}
70+
71+
statusCode = http.StatusPartialContent
7072
}
7173

7274
var currentContentLength = end - start + 1
@@ -91,14 +93,19 @@ func handleResumableLayer(data []byte, w http.ResponseWriter, r *http.Request, t
9193

9294
end = start + currentContentLength - 1
9395

94-
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, contentLength))
95-
w.Header().Set("Content-Length", strconv.FormatInt(currentContentLength, 10))
96-
w.WriteHeader(http.StatusPartialContent)
96+
if statusCode == http.StatusPartialContent {
97+
w.Header().Set("Content-Length", strconv.FormatInt(currentContentLength, 10))
98+
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, contentLength))
99+
} else {
100+
w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10))
101+
}
102+
103+
w.WriteHeader(statusCode)
97104
w.Write(data[start : end+1])
98105
time.Sleep(time.Second)
99106
}
100107

101-
func resumableRequest(client *http.Client, url string, size int64, digest string, overlap bool, t *testing.T) {
108+
func resumableRequest(client *http.Client, url string, leading, trailing []byte, size int64, digest string, overlap bool, t *testing.T) {
102109
req, err := http.NewRequest(http.MethodGet, url, http.NoBody)
103110
if err != nil {
104111
t.Fatalf("http.NewRequest(): %v", err)
@@ -108,6 +115,16 @@ func resumableRequest(client *http.Client, url string, size int64, digest string
108115
req.Header.Set("X-Overlap", "true")
109116
}
110117

118+
if len(leading) > 0 || len(trailing) > 0 {
119+
var buf bytes.Buffer
120+
buf.WriteString("bytes=")
121+
buf.WriteString(fmt.Sprintf("%d-", len(leading)))
122+
if len(trailing) > 0 {
123+
buf.WriteString(fmt.Sprintf("%d", size-int64(len(trailing))-1))
124+
}
125+
req.Header.Set("Range", buf.String())
126+
}
127+
111128
resp, err := client.Do(req.WithContext(t.Context()))
112129
if err != nil {
113130
t.Fatalf("client.Do(): %v", err)
@@ -120,12 +137,19 @@ func resumableRequest(client *http.Client, url string, size int64, digest string
120137
}
121138

122139
hash := sha256.New()
140+
if len(leading) > 0 {
141+
io.Copy(hash, bytes.NewReader(leading))
142+
}
123143

124144
if _, err = io.Copy(hash, resp.Body); err != nil {
125145
t.Errorf("unexpected error: %v", err)
126146
return
127147
}
128148

149+
if len(trailing) > 0 {
150+
io.Copy(hash, bytes.NewReader(trailing))
151+
}
152+
129153
actualDigest := "sha256:" + hex.EncodeToString(hash.Sum(nil))
130154

131155
if actualDigest != digest {
@@ -248,23 +272,40 @@ func TestResumableTransport(t *testing.T) {
248272
}
249273

250274
tests := []struct {
251-
name string
252-
digest string
253-
size int64
254-
timeout bool
255-
cancel bool
256-
nonResumable bool
257-
overlap bool
275+
name string
276+
digest string
277+
leading, trailing int64
278+
timeout bool
279+
cancel bool
280+
nonResumable bool
281+
overlap bool
282+
ranged bool
258283
}{
259284
{
260-
name: "resumable",
261-
digest: digest.String(),
262-
size: size,
285+
name: "resumable",
286+
digest: digest.String(),
287+
leading: 0,
288+
},
289+
{
290+
name: "resumable-range-leading",
291+
digest: digest.String(),
292+
leading: 3,
293+
},
294+
{
295+
name: "resumable-range-trailing",
296+
digest: digest.String(),
297+
leading: 0,
298+
},
299+
{
300+
name: "resumable-range-leading-trailing",
301+
digest: digest.String(),
302+
leading: 3,
303+
trailing: 6,
263304
},
264305
{
265306
name: "resumable-overlap",
266307
digest: digest.String(),
267-
size: size,
308+
leading: 0,
268309
overlap: true,
269310
},
270311
{
@@ -290,8 +331,8 @@ func TestResumableTransport(t *testing.T) {
290331
resumableStopByCancelRequest(client, url, t)
291332
} else if tt.timeout {
292333
resumableStopByTimeoutRequest(client, url, t)
293-
} else if tt.digest != "" && tt.size > 0 {
294-
resumableRequest(client, url, tt.size, tt.digest, tt.overlap, t)
334+
} else if tt.digest != "" {
335+
resumableRequest(client, url, data[:tt.leading], data[size-tt.trailing:], size, tt.digest, tt.overlap, t)
295336
}
296337
})
297338
}

0 commit comments

Comments
 (0)