From 619c7ca34b1fb005276ee4f120efd9f12d03d2ee Mon Sep 17 00:00:00 2001 From: Derek Ray Date: Wed, 12 Nov 2025 16:50:07 +0800 Subject: [PATCH 1/4] feat(transport): add resumable transport for remote resources --- cmd/crane/cmd/pull.go | 6 + pkg/v1/remote/image_test.go | 44 ++++ pkg/v1/remote/options.go | 13 + pkg/v1/remote/transport/resumable.go | 242 ++++++++++++++++++ pkg/v1/remote/transport/resumable_test.go | 298 ++++++++++++++++++++++ 5 files changed, 603 insertions(+) create mode 100644 pkg/v1/remote/transport/resumable.go create mode 100644 pkg/v1/remote/transport/resumable_test.go diff --git a/cmd/crane/cmd/pull.go b/cmd/crane/cmd/pull.go index 41c6e95cd..4dd3a3360 100644 --- a/cmd/crane/cmd/pull.go +++ b/cmd/crane/cmd/pull.go @@ -32,6 +32,7 @@ func NewCmdPull(options *[]crane.Option) *cobra.Command { var ( cachePath, format string annotateRef bool + resumable bool ) cmd := &cobra.Command{ @@ -49,6 +50,10 @@ func NewCmdPull(options *[]crane.Option) *cobra.Command { return fmt.Errorf("parsing reference %q: %w", src, err) } + if resumable { + o.Remote = append(o.Remote, remote.WithResumable()) + } + rmt, err := remote.Get(ref, o.Remote...) 
if err != nil { return err @@ -133,6 +138,7 @@ func NewCmdPull(options *[]crane.Option) *cobra.Command { cmd.Flags().StringVarP(&cachePath, "cache_path", "c", "", "Path to cache image layers") cmd.Flags().StringVar(&format, "format", "tarball", fmt.Sprintf("Format in which to save images (%q, %q, or %q)", "tarball", "legacy", "oci")) cmd.Flags().BoolVar(&annotateRef, "annotate-ref", false, "Preserves image reference used to pull as an annotation when used with --format=oci") + cmd.Flags().BoolVar(&resumable, "resumable", false, "Enable resumable transport for pulling images") return cmd } diff --git a/pkg/v1/remote/image_test.go b/pkg/v1/remote/image_test.go index f15e96a6d..302076a15 100644 --- a/pkg/v1/remote/image_test.go +++ b/pkg/v1/remote/image_test.go @@ -17,6 +17,8 @@ package remote import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" @@ -747,3 +749,45 @@ func TestData(t *testing.T) { t.Fatal(err) } } + +func TestImageResumable(t *testing.T) { + ref, err := name.ParseReference("ghcr.io/labring/fastgpt:v4.9.0") + if err != nil { + t.Fatal(err) + } + + image, err := Image(ref, WithResumable()) + if err != nil { + t.Fatal(err) + } + + layers, err := image.Layers() + if err != nil { + t.Fatal(err) + } + + for _, layer := range layers { + digest, err := layer.Digest() + if err != nil { + t.Fatal(err) + } + + rc, err := layer.Compressed() + if err != nil { + t.Fatal(err) + } + + hash := sha256.New() + _, err = io.Copy(hash, rc) + rc.Close() + if err != nil { + t.Fatal(err) + } + + if digest.Hex == hex.EncodeToString(hash.Sum(nil)) { + t.Logf("digest matches: %s", digest) + } else { + t.Errorf("digest mismatch: %s != %s", digest, hex.EncodeToString(hash.Sum(nil))) + } + } +} diff --git a/pkg/v1/remote/options.go b/pkg/v1/remote/options.go index 15b7da1e4..5f408dd7f 100644 --- a/pkg/v1/remote/options.go +++ b/pkg/v1/remote/options.go @@ -45,6 +45,7 @@ type options struct { retryBackoff Backoff retryPredicate retry.Predicate 
retryStatusCodes []int + resumable bool // Only these options can overwrite Reuse()d options. platform v1.Platform @@ -170,6 +171,11 @@ func makeOptions(opts ...Option) (*options, error) { // Wrap the transport in something that can retry network flakes. o.transport = transport.NewRetry(o.transport, transport.WithRetryBackoff(o.retryBackoff), transport.WithRetryPredicate(predicate), transport.WithRetryStatusCodes(o.retryStatusCodes...)) + + if o.resumable { + o.transport = transport.NewResumable(o.transport) + } + // Wrap this last to prevent transport.New from double-wrapping. if o.userAgent != "" { o.transport = transport.NewUserAgent(o.transport, o.userAgent) @@ -192,6 +198,13 @@ func WithTransport(t http.RoundTripper) Option { } } +func WithResumable() Option { + return func(o *options) error { + o.resumable = true + return nil + } +} + // WithAuth is a functional option for overriding the default authenticator // for remote operations. // It is an error to use both WithAuth and WithAuthFromKeychain in the same Option set. diff --git a/pkg/v1/remote/transport/resumable.go b/pkg/v1/remote/transport/resumable.go new file mode 100644 index 000000000..3819eba28 --- /dev/null +++ b/pkg/v1/remote/transport/resumable.go @@ -0,0 +1,242 @@ +package transport + +import ( + "errors" + "fmt" + "io" + "net/http" + "regexp" + "strconv" + "strings" + "sync/atomic" + + "github.com/google/go-containerregistry/pkg/logs" +) + +// NewResumable creates a http.RoundTripper that resumes http GET from error, +// and the inner should be wrapped with retry transport, otherwise, the +// transport will abort if resume() returns error. 
+func NewResumable(inner http.RoundTripper) http.RoundTripper { + return &resumableTransport{inner: inner} +} + +var ( + contentRangeRe = regexp.MustCompile(`^bytes (\d+)-(\d+)/(\d+|\*)$`) +) + +type resumableTransport struct { + inner http.RoundTripper +} + +func (rt *resumableTransport) RoundTrip(in *http.Request) (*http.Response, error) { + if in.Method != http.MethodGet { + return rt.inner.RoundTrip(in) + } + + req := in.Clone(in.Context()) + req.Header.Set("Range", "bytes=0-") + resp, err := rt.inner.RoundTrip(req) + if err != nil { + return resp, err + } + + switch resp.StatusCode { + case http.StatusPartialContent: + case http.StatusRequestedRangeNotSatisfiable: + // fallback to previous behavior + resp.Body.Close() + return rt.inner.RoundTrip(in) + default: + return resp, nil + } + + var contentLength int64 + if _, _, contentLength, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil || contentLength <= 0 { + // fallback to previous behavior + resp.Body.Close() + return rt.inner.RoundTrip(in) + } + + // modify response status to 200, ensure caller error checking works + resp.StatusCode = http.StatusOK + resp.Status = "200 OK" + resp.ContentLength = contentLength + resp.Body = &resumableBody{ + rc: resp.Body, + inner: rt.inner, + req: req, + total: contentLength, + transferred: 0, + } + + return resp, nil +} + +type resumableBody struct { + rc io.ReadCloser + + inner http.RoundTripper + req *http.Request + + transferred int64 + total int64 + + closed uint32 +} + +func (rb *resumableBody) Read(p []byte) (n int, err error) { + if atomic.LoadUint32(&rb.closed) == 1 { + // response body already closed + return 0, http.ErrBodyReadAfterClose + } else if rb.total >= 0 && rb.transferred >= rb.total { + return 0, io.EOF + } + +resume: + if n, err = rb.rc.Read(p); n > 0 { + rb.transferred += int64(n) + } + + if err == nil { + return + } + + if errors.Is(err, io.EOF) && rb.total >= 0 && rb.transferred == rb.total { + return + } + + if err = 
rb.resume(err); err == nil { + if n == 0 { + // zero bytes read, try reading again with new response.Body + goto resume + } + + // already read some bytes from previous response.Body, returns and waits for next Read operation + } + + return n, err +} + +func (rb *resumableBody) Close() (err error) { + if !atomic.CompareAndSwapUint32(&rb.closed, 0, 1) { + return nil + } + + return rb.rc.Close() +} + +func (rb *resumableBody) resume(reason error) error { + if reason != nil { + logs.Debug.Printf("Resume http transporting from error: %v", reason) + } + + ctx := rb.req.Context() + select { + case <-ctx.Done(): + // context already done, stop resuming from error + return ctx.Err() + default: + } + + req := rb.req.Clone(ctx) + req.Header.Set("Range", "bytes="+strconv.FormatInt(rb.transferred, 10)+"-") + resp, err := rb.inner.RoundTrip(req) + if err != nil { + return err + } + + if err = rb.validate(resp); err != nil { + resp.Body.Close() + return err + } + + if atomic.LoadUint32(&rb.closed) == 1 { + resp.Body.Close() + return http.ErrBodyReadAfterClose + } + + rb.rc.Close() + rb.rc = resp.Body + + return nil +} + +func (rb *resumableBody) validate(resp *http.Response) (err error) { + var start, total int64 + switch resp.StatusCode { + case http.StatusPartialContent: + if start, _, total, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil { + return err + } + + if total > rb.total { + rb.total = total + } + + if start == rb.transferred { + break + } else if start < rb.transferred { + if _, err := io.CopyN(io.Discard, resp.Body, rb.transferred-start); err != nil { + return fmt.Errorf("discard overlapped data failed, %v", err) + } + } else { + return fmt.Errorf("unexpected resume start %d, wanted: %d", start, rb.transferred) + } + case http.StatusOK: + if rb.transferred > 0 { + if _, err = io.CopyN(io.Discard, resp.Body, rb.transferred); err != nil { + return err + } + } + case http.StatusRequestedRangeNotSatisfiable: + if contentRange := 
resp.Header.Get("Content-Range"); contentRange != "" && strings.HasPrefix(contentRange, "bytes */") { + if total, err = strconv.ParseInt(strings.TrimPrefix(contentRange, "bytes */"), 10, 64); err == nil && total >= 0 && rb.transferred >= total { + return io.EOF + } + } + + fallthrough + default: + return fmt.Errorf("unexpected status code %d", resp.StatusCode) + } + + return nil +} + +func parseContentRange(contentRange string) (start, end, size int64, err error) { + if contentRange == "" { + return -1, -1, -1, errors.New("unexpected empty content range") + } + + matches := contentRangeRe.FindStringSubmatch(contentRange) + if len(matches) != 4 { + return -1, -1, -1, fmt.Errorf("invalid content range: %s", contentRange) + } + + if start, err = strconv.ParseInt(matches[1], 10, 64); err != nil { + return -1, -1, -1, fmt.Errorf("unexpected start from content range '%s', %v", contentRange, err) + } + + if end, err = strconv.ParseInt(matches[2], 10, 64); err != nil { + return -1, -1, -1, fmt.Errorf("unexpected end from content range '%s', %v", contentRange, err) + } + + if start > end { + return -1, -1, -1, fmt.Errorf("invalid content range: %s", contentRange) + } + + if matches[3] == "*" { + size = -1 + } else { + size, err = strconv.ParseInt(matches[3], 10, 64) + if err != nil { + return -1, -1, -1, fmt.Errorf("unexpected total from content range '%s', %v", contentRange, err) + } + + if end >= size { + return -1, -1, -1, fmt.Errorf("invalid content range: %s", contentRange) + } + } + + return +} diff --git a/pkg/v1/remote/transport/resumable_test.go b/pkg/v1/remote/transport/resumable_test.go new file mode 100644 index 000000000..3ab286078 --- /dev/null +++ b/pkg/v1/remote/transport/resumable_test.go @@ -0,0 +1,298 @@ +package transport + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "regexp" + "strconv" + "testing" + "time" + + stdrand "math/rand" + + 
"github.com/google/go-containerregistry/pkg/logs" + "github.com/google/go-containerregistry/pkg/v1/random" + "github.com/google/go-containerregistry/pkg/v1/types" +) + +var rangeRe = regexp.MustCompile(`bytes=(\d+)-(\d+)?`) + +func handleResumableLayer(data []byte, w http.ResponseWriter, r *http.Request, t *testing.T) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + contentRange := r.Header.Get("Range") + if contentRange == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + matches := rangeRe.FindStringSubmatch(contentRange) + if len(matches) != 3 { + w.WriteHeader(http.StatusBadRequest) + return + } + + contentLength := int64(len(data)) + start, err := strconv.ParseInt(matches[1], 10, 64) + if err != nil || start < 0 { + w.WriteHeader(http.StatusBadRequest) + return + } + + if start >= int64(contentLength) { + w.WriteHeader(http.StatusRequestedRangeNotSatisfiable) + return + } + + var end = int64(contentLength) - 1 + if matches[2] != "" { + end, err = strconv.ParseInt(matches[2], 10, 64) + if err != nil || end < 0 { + w.WriteHeader(http.StatusBadRequest) + return + } + + if end >= int64(contentLength) { + w.WriteHeader(http.StatusRequestedRangeNotSatisfiable) + return + } + } + + var currentContentLength = end - start + 1 + if currentContentLength <= 0 { + w.WriteHeader(http.StatusInternalServerError) + return + } + + if currentContentLength > 4096 { + if currentContentLength = stdrand.Int63n(currentContentLength); currentContentLength < 1024 { + currentContentLength = 1024 + } + + if r.Header.Get("X-Overlap") == "true" { + overlapSize := int64(stdrand.Int31n(64)) + if start > overlapSize { + start -= overlapSize + // t.Logf("Overlap data size: %d", overlapSize) + } + } + } + + end = start + currentContentLength - 1 + + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, contentLength)) + w.Header().Set("Content-Length", strconv.FormatInt(currentContentLength, 10)) + 
w.WriteHeader(http.StatusPartialContent) + w.Write(data[start : end+1]) + time.Sleep(time.Second) +} + +func resumableRequest(client *http.Client, url string, size int64, digest string, overlap bool, t *testing.T) { + req, err := http.NewRequest(http.MethodGet, url, http.NoBody) + if err != nil { + t.Fatalf("http.NewRequest(): %v", err) + } + + if overlap { + req.Header.Set("X-Overlap", "true") + } + + resp, err := client.Do(req.WithContext(t.Context())) + if err != nil { + t.Fatalf("client.Do(): %v", err) + } + defer resp.Body.Close() + + if _, ok := resp.Body.(*resumableBody); !ok { + t.Error("expected resumable body") + return + } + + hash := sha256.New() + + if _, err = io.Copy(hash, resp.Body); err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + actualDigest := "sha256:" + hex.EncodeToString(hash.Sum(nil)) + + if actualDigest != digest { + t.Errorf("unexpected digest: %s, actually: %s", digest, actualDigest) + } +} + +func nonResumableRequest(client *http.Client, url string, t *testing.T) { + req, err := http.NewRequest(http.MethodGet, url, http.NoBody) + if err != nil { + t.Fatalf("http.NewRequest(): %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("client.Do(): %v", err) + } + + _, ok := resp.Body.(*resumableBody) + if ok { + t.Error("expected non-resumable body") + } +} + +func resumableStopByTimeoutRequest(client *http.Client, url string, t *testing.T) { + req, err := http.NewRequest(http.MethodGet, url, http.NoBody) + if err != nil { + t.Fatalf("http.NewRequest(): %v", err) + } + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*3) + defer cancel() + + resp, err := client.Do(req.WithContext(ctx)) + if err != nil { + t.Fatalf("client.Do(): %v", err) + } + defer resp.Body.Close() + + if _, ok := resp.Body.(*resumableBody); !ok { + t.Error("expected resumable body") + return + } + + if _, err = io.Copy(io.Discard, resp.Body); err != nil && !errors.Is(err, context.DeadlineExceeded) { + t.Error("expected 
context deadline exceeded error") + } +} + +func resumableStopByCancelRequest(client *http.Client, url string, t *testing.T) { + req, err := http.NewRequest(http.MethodGet, url, http.NoBody) + if err != nil { + t.Fatalf("http.NewRequest(): %v", err) + } + + ctx, cancel := context.WithCancel(t.Context()) + time.AfterFunc(time.Second*3, cancel) + + resp, err := client.Do(req.WithContext(ctx)) + if err != nil { + t.Fatalf("client.Do(): %v", err) + } + defer resp.Body.Close() + + if _, ok := resp.Body.(*resumableBody); !ok { + t.Error("expected resumable body") + return + } + + if _, err = io.Copy(io.Discard, resp.Body); err != nil && !errors.Is(err, context.Canceled) { + t.Error("expected context cancel error") + } +} + +func TestResumableTransport(t *testing.T) { + logs.Debug.SetOutput(os.Stdout) + layer, err := random.Layer(2<<20, types.DockerLayer) + if err != nil { + t.Fatalf("random.Layer(): %v", err) + } + + digest, err := layer.Digest() + if err != nil { + t.Fatalf("layer.Digest(): %v", err) + } + + size, err := layer.Size() + if err != nil { + t.Fatalf("layer.Size(): %v", err) + } + + rc, err := layer.Compressed() + if err != nil { + t.Fatalf("layer.Compressed(): %v", err) + } + + data, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("io.ReadAll(): %v", err) + } + + layerPath := fmt.Sprintf("/v2/foo/bar/blobs/%s", digest.String()) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case layerPath: + handleResumableLayer(data, w, r, t) + default: + http.Error(w, "not found", http.StatusNotFound) + } + })) + defer server.Close() + + address, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("url.Parse(%v) = %v", server.URL, err) + } + + client := &http.Client{ + Transport: NewResumable(http.DefaultTransport.(*http.Transport).Clone()), + } + + tests := []struct { + name string + digest string + size int64 + timeout bool + cancel bool + nonResumable bool + overlap bool + }{ + { + name: 
"resumable", + digest: digest.String(), + size: size, + }, + { + name: "resumable-overlap", + digest: digest.String(), + size: size, + overlap: true, + }, + { + name: "non-resumable", + nonResumable: true, + }, + { + name: "resumable stop by timeout", + cancel: true, + }, + { + name: "resumable stop by cancel", + cancel: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := address.String() + layerPath + if tt.nonResumable { + nonResumableRequest(client, address.String(), t) + } else if tt.cancel { + resumableStopByCancelRequest(client, url, t) + } else if tt.timeout { + resumableStopByTimeoutRequest(client, url, t) + } else if tt.digest != "" && tt.size > 0 { + resumableRequest(client, url, tt.size, tt.digest, tt.overlap, t) + } + }) + } +} From 26008119687ad0614605b8605a5ad27c79fa6aad Mon Sep 17 00:00:00 2001 From: Derek Ray Date: Thu, 13 Nov 2025 15:32:53 +0800 Subject: [PATCH 2/4] feat(transport): using loops instead of labels for resumableBoyd.Read() --- pkg/v1/remote/transport/resumable.go | 35 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/pkg/v1/remote/transport/resumable.go b/pkg/v1/remote/transport/resumable.go index 3819eba28..32749e0a9 100644 --- a/pkg/v1/remote/transport/resumable.go +++ b/pkg/v1/remote/transport/resumable.go @@ -92,29 +92,30 @@ func (rb *resumableBody) Read(p []byte) (n int, err error) { return 0, io.EOF } -resume: - if n, err = rb.rc.Read(p); n > 0 { - rb.transferred += int64(n) - } + for { + if n, err = rb.rc.Read(p); n > 0 { + rb.transferred += int64(n) + } - if err == nil { - return - } + if err == nil { + return + } - if errors.Is(err, io.EOF) && rb.total >= 0 && rb.transferred == rb.total { - return - } + if errors.Is(err, io.EOF) && rb.total >= 0 && rb.transferred == rb.total { + return + } + + if err = rb.resume(err); err == nil { + if n == 0 { + // zero bytes read, try reading again with new response.Body + continue + } - if err = 
rb.resume(err); err == nil { - if n == 0 { - // zero bytes read, try reading again with new response.Body - goto resume + // already read some bytes from previous response.Body, returns and waits for next Read operation } - // already read some bytes from previous response.Body, returns and waits for next Read operation + return n, err } - - return n, err } func (rb *resumableBody) Close() (err error) { From 6e297fbc1e0480fe53a68eae10bcfe1cde8d1dcf Mon Sep 17 00:00:00 2001 From: Derek Ray Date: Thu, 13 Nov 2025 22:42:09 +0800 Subject: [PATCH 3/4] optimize(transport): optimize resumable transport implementation --- pkg/v1/remote/transport/resumable.go | 100 ++++++++++------ pkg/v1/remote/transport/resumable_test.go | 135 ++++++++++++++-------- 2 files changed, 153 insertions(+), 82 deletions(-) diff --git a/pkg/v1/remote/transport/resumable.go b/pkg/v1/remote/transport/resumable.go index 32749e0a9..092bf456e 100644 --- a/pkg/v1/remote/transport/resumable.go +++ b/pkg/v1/remote/transport/resumable.go @@ -22,51 +22,70 @@ func NewResumable(inner http.RoundTripper) http.RoundTripper { var ( contentRangeRe = regexp.MustCompile(`^bytes (\d+)-(\d+)/(\d+|\*)$`) + rangeRe = regexp.MustCompile(`bytes=(\d+)-(\d+)?`) ) type resumableTransport struct { inner http.RoundTripper } -func (rt *resumableTransport) RoundTrip(in *http.Request) (*http.Response, error) { - if in.Method != http.MethodGet { - return rt.inner.RoundTrip(in) +func (rt *resumableTransport) RoundTrip(in *http.Request) (resp *http.Response, err error) { + var total, start, end int64 + // check initial request, maybe resumable transport is already enabled + if contentRange := in.Header.Get("Range"); contentRange != "" { + if matches := rangeRe.FindStringSubmatch(contentRange); len(matches) == 3 { + if start, err = strconv.ParseInt(matches[1], 10, 64); err != nil { + return nil, fmt.Errorf("invalid content range %q: %w", contentRange, err) + } + + if len(matches[2]) == 0 { + // request whole file + end = -1 + } else
if end, err = strconv.ParseInt(matches[2], 10, 64); err == nil { + if start > end { + return nil, fmt.Errorf("invalid content range %q", contentRange) + } + } else { + return nil, fmt.Errorf("invalid content range %q: %w", contentRange, err) + } + } } - req := in.Clone(in.Context()) - req.Header.Set("Range", "bytes=0-") - resp, err := rt.inner.RoundTrip(req) - if err != nil { + if resp, err = rt.inner.RoundTrip(in); err != nil { return resp, err } + if in.Method != http.MethodGet { + return resp, nil + } + switch resp.StatusCode { + case http.StatusOK: + if end != 0 { + // range content was requested but the status code is unexpected, so this request cannot be resumed + return resp, nil + } + + total = resp.ContentLength case http.StatusPartialContent: - case http.StatusRequestedRangeNotSatisfiable: - // fallback to previous behavior - resp.Body.Close() - return rt.inner.RoundTrip(in) + // keep original response status code, which should be processed by original transport or operation + if start, _, total, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil || total <= 0 { + return resp, nil + } else if end > 0 { + total = end + 1 + } default: return resp, nil } - var contentLength int64 - if _, _, contentLength, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil || contentLength <= 0 { - // fallback to previous behavior - resp.Body.Close() - return rt.inner.RoundTrip(in) - } - - // modify response status to 200, ensure caller error checking works - resp.StatusCode = http.StatusOK - resp.Status = "200 OK" - resp.ContentLength = contentLength - resp.Body = &resumableBody{ - rc: resp.Body, - inner: rt.inner, - req: req, - total: contentLength, - transferred: 0, + if total > 0 { + resp.Body = &resumableBody{ + rc: resp.Body, + inner: rt.inner, + req: in, + total: total, + transferred: start, + } } return resp, nil @@ -94,6 +113,10 @@ func (rb *resumableBody) Read(p []byte) (n int, err error) { for { if n, err = rb.rc.Read(p); n > 0 { + if
rb.transferred+int64(n) >= rb.total { + n = int(rb.total - rb.transferred) + err = io.EOF + } rb.transferred += int64(n) } @@ -101,7 +124,7 @@ func (rb *resumableBody) Read(p []byte) (n int, err error) { return } - if errors.Is(err, io.EOF) && rb.total >= 0 && rb.transferred == rb.total { + if errors.Is(err, io.EOF) && rb.total >= 0 && rb.transferred >= rb.total { return } @@ -148,7 +171,8 @@ func (rb *resumableBody) resume(reason error) error { if err = rb.validate(resp); err != nil { resp.Body.Close() - return err + // wraps original error + return fmt.Errorf("%w, %v", reason, err) } if atomic.LoadUint32(&rb.closed) == 1 { @@ -162,21 +186,21 @@ func (rb *resumableBody) resume(reason error) error { return nil } +const size100m = 100 << 20 + func (rb *resumableBody) validate(resp *http.Response) (err error) { var start, total int64 switch resp.StatusCode { case http.StatusPartialContent: - if start, _, total, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil { + // do not use the total size from the Content-Range header; keep rb.total unchanged + if start, _, _, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil { return err } - if total > rb.total { - rb.total = total - } - if start == rb.transferred { break } else if start < rb.transferred { + // incoming data overlaps what was already transferred, so discard the overlap if _, err := io.CopyN(io.Discard, resp.Body, rb.transferred-start); err != nil { return fmt.Errorf("discard overlapped data failed, %v", err) } @@ -185,6 +209,12 @@ func (rb *resumableBody) validate(resp *http.Response) (err error) { } case http.StatusOK: if rb.transferred > 0 { + // range is not supported, and transferred data is too large, stop resuming + if rb.transferred > size100m { + return fmt.Errorf("too large data transferred: %d", rb.transferred) + } + + // try resume from unsupported range request if _, err = io.CopyN(io.Discard, resp.Body, rb.transferred); err != nil { return err } diff --git
a/pkg/v1/remote/transport/resumable_test.go b/pkg/v1/remote/transport/resumable_test.go index 3ab286078..cef0533e6 100644 --- a/pkg/v1/remote/transport/resumable_test.go +++ b/pkg/v1/remote/transport/resumable_test.go @@ -1,6 +1,7 @@ package transport import ( + "bytes" "context" "crypto/sha256" "encoding/hex" @@ -11,7 +12,6 @@ import ( "net/http/httptest" "net/url" "os" - "regexp" "strconv" "testing" "time" @@ -23,50 +23,52 @@ import ( "github.com/google/go-containerregistry/pkg/v1/types" ) -var rangeRe = regexp.MustCompile(`bytes=(\d+)-(\d+)?`) - func handleResumableLayer(data []byte, w http.ResponseWriter, r *http.Request, t *testing.T) { if r.Method != http.MethodGet { w.WriteHeader(http.StatusMethodNotAllowed) return } - contentRange := r.Header.Get("Range") - if contentRange == "" { - w.WriteHeader(http.StatusBadRequest) - return - } - - matches := rangeRe.FindStringSubmatch(contentRange) - if len(matches) != 3 { - w.WriteHeader(http.StatusBadRequest) - return - } - - contentLength := int64(len(data)) - start, err := strconv.ParseInt(matches[1], 10, 64) - if err != nil || start < 0 { - w.WriteHeader(http.StatusBadRequest) - return - } + var ( + contentLength, start, end int64 + statusCode = http.StatusOK + err error + ) - if start >= int64(contentLength) { - w.WriteHeader(http.StatusRequestedRangeNotSatisfiable) - return - } + contentLength = int64(len(data)) + end = contentLength - 1 + contentRange := r.Header.Get("Range") + if contentRange != "" { + matches := rangeRe.FindStringSubmatch(contentRange) + if len(matches) != 3 { + w.WriteHeader(http.StatusBadRequest) + return + } - var end = int64(contentLength) - 1 - if matches[2] != "" { - end, err = strconv.ParseInt(matches[2], 10, 64) - if err != nil || end < 0 { + if start, err = strconv.ParseInt(matches[1], 10, 64); err != nil || start < 0 { w.WriteHeader(http.StatusBadRequest) return } - if end >= int64(contentLength) { + if start >= int64(contentLength) { 
w.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return } + + if matches[2] != "" { + end, err = strconv.ParseInt(matches[2], 10, 64) + if err != nil || end < 0 { + w.WriteHeader(http.StatusBadRequest) + return + } + + if end >= int64(contentLength) { + w.WriteHeader(http.StatusRequestedRangeNotSatisfiable) + return + } + } + + statusCode = http.StatusPartialContent } var currentContentLength = end - start + 1 @@ -91,14 +93,19 @@ func handleResumableLayer(data []byte, w http.ResponseWriter, r *http.Request, t end = start + currentContentLength - 1 - w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, contentLength)) - w.Header().Set("Content-Length", strconv.FormatInt(currentContentLength, 10)) - w.WriteHeader(http.StatusPartialContent) + if statusCode == http.StatusPartialContent { + w.Header().Set("Content-Length", strconv.FormatInt(currentContentLength, 10)) + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, contentLength)) + } else { + w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10)) + } + + w.WriteHeader(statusCode) w.Write(data[start : end+1]) time.Sleep(time.Second) } -func resumableRequest(client *http.Client, url string, size int64, digest string, overlap bool, t *testing.T) { +func resumableRequest(client *http.Client, url string, leading, trailing []byte, size int64, digest string, overlap bool, t *testing.T) { req, err := http.NewRequest(http.MethodGet, url, http.NoBody) if err != nil { t.Fatalf("http.NewRequest(): %v", err) @@ -108,6 +115,16 @@ func resumableRequest(client *http.Client, url string, size int64, digest string req.Header.Set("X-Overlap", "true") } + if len(leading) > 0 || len(trailing) > 0 { + var buf bytes.Buffer + buf.WriteString("bytes=") + buf.WriteString(fmt.Sprintf("%d-", len(leading))) + if len(trailing) > 0 { + buf.WriteString(fmt.Sprintf("%d", size-int64(len(trailing))-1)) + } + req.Header.Set("Range", buf.String()) + } + resp, err := 
client.Do(req.WithContext(t.Context())) if err != nil { t.Fatalf("client.Do(): %v", err) @@ -120,12 +137,19 @@ func resumableRequest(client *http.Client, url string, size int64, digest string } hash := sha256.New() + if len(leading) > 0 { + io.Copy(hash, bytes.NewReader(leading)) + } if _, err = io.Copy(hash, resp.Body); err != nil { t.Errorf("unexpected error: %v", err) return } + if len(trailing) > 0 { + io.Copy(hash, bytes.NewReader(trailing)) + } + actualDigest := "sha256:" + hex.EncodeToString(hash.Sum(nil)) if actualDigest != digest { @@ -248,23 +272,40 @@ func TestResumableTransport(t *testing.T) { } tests := []struct { - name string - digest string - size int64 - timeout bool - cancel bool - nonResumable bool - overlap bool + name string + digest string + leading, trailing int64 + timeout bool + cancel bool + nonResumable bool + overlap bool + ranged bool }{ { - name: "resumable", - digest: digest.String(), - size: size, + name: "resumable", + digest: digest.String(), + leading: 0, + }, + { + name: "resumable-range-leading", + digest: digest.String(), + leading: 3, + }, + { + name: "resumable-range-trailing", + digest: digest.String(), + leading: 0, + }, + { + name: "resumable-range-leading-trailing", + digest: digest.String(), + leading: 3, + trailing: 6, }, { name: "resumable-overlap", digest: digest.String(), - size: size, + leading: 0, overlap: true, }, { @@ -290,8 +331,8 @@ func TestResumableTransport(t *testing.T) { resumableStopByCancelRequest(client, url, t) } else if tt.timeout { resumableStopByTimeoutRequest(client, url, t) - } else if tt.digest != "" && tt.size > 0 { - resumableRequest(client, url, tt.size, tt.digest, tt.overlap, t) + } else if tt.digest != "" { + resumableRequest(client, url, data[:tt.leading], data[size-tt.trailing:], size, tt.digest, tt.overlap, t) } }) } From 17f518aaa778aee47143414bd3779091fa990b49 Mon Sep 17 00:00:00 2001 From: Derek Ray Date: Mon, 17 Nov 2025 09:45:33 +0800 Subject: [PATCH 4/4] feat(transport): add 
resumable backoff option support --- pkg/v1/remote/options.go | 16 +++- pkg/v1/remote/transport/resumable.go | 95 +++++++++++++++-------- pkg/v1/remote/transport/resumable_test.go | 7 +- 3 files changed, 83 insertions(+), 35 deletions(-) diff --git a/pkg/v1/remote/options.go b/pkg/v1/remote/options.go index 5f408dd7f..e1224c84a 100644 --- a/pkg/v1/remote/options.go +++ b/pkg/v1/remote/options.go @@ -46,6 +46,7 @@ type options struct { retryPredicate retry.Predicate retryStatusCodes []int resumable bool + resumableBackoff Backoff // Only these options can overwrite Reuse()d options. platform v1.Platform @@ -136,6 +137,7 @@ func makeOptions(opts ...Option) (*options, error) { retryPredicate: defaultRetryPredicate, retryBackoff: defaultRetryBackoff, retryStatusCodes: defaultRetryStatusCodes, + resumableBackoff: defaultRetryBackoff, } for _, option := range opts { @@ -173,7 +175,7 @@ func makeOptions(opts ...Option) (*options, error) { o.transport = transport.NewRetry(o.transport, transport.WithRetryBackoff(o.retryBackoff), transport.WithRetryPredicate(predicate), transport.WithRetryStatusCodes(o.retryStatusCodes...)) if o.resumable { - o.transport = transport.NewResumable(o.transport) + o.transport = transport.NewResumable(o.transport, o.resumableBackoff) } // Wrap this last to prevent transport.New from double-wrapping. @@ -198,6 +200,8 @@ func WithTransport(t http.RoundTripper) Option { } } +// WithResumable is a functional option for enabling resumable downloads. The resumable transport wraps the retry transport by default. +// If both retry and resumable backoffs are configured, be aware that all of them will be applied. func WithResumable() Option { return func(o *options) error { o.resumable = true @@ -205,6 +209,16 @@ func WithResumable() Option { } } +// WithResumableBackoff is a functional option for overriding the default resumable backoff for remote operations.
+// Resumable backoff resumes failed requests after a delay. Unlike retry actions, resumable backoff ignores +// transport.RoundTripper.RoundTrip errors. +func WithResumableBackoff(backoff Backoff) Option { + return func(o *options) error { + o.resumableBackoff = backoff + return nil + } +} + // WithAuth is a functional option for overriding the default authenticator // for remote operations. // It is an error to use both WithAuth and WithAuthFromKeychain in the same Option set. diff --git a/pkg/v1/remote/transport/resumable.go b/pkg/v1/remote/transport/resumable.go index 092bf456e..98084aa2f 100644 --- a/pkg/v1/remote/transport/resumable.go +++ b/pkg/v1/remote/transport/resumable.go @@ -9,15 +9,24 @@ import ( "strconv" "strings" "sync/atomic" + "time" "github.com/google/go-containerregistry/pkg/logs" ) -// NewResumable creates a http.RoundTripper that resumes http GET from error, -// and the inner should be wrapped with retry transport, otherwise, the -// transport will abort if resume() returns error. -func NewResumable(inner http.RoundTripper) http.RoundTripper { - return &resumableTransport{inner: inner} +// NewResumable creates an http.RoundTripper that resumes an HTTP GET after an error and continues +// transferring data from the last successfully transferred offset.
+func NewResumable(inner http.RoundTripper, backoff Backoff) http.RoundTripper { + if backoff.Steps <= 0 { + // resume once + backoff.Steps = 1 + } + + if backoff.Duration <= 0 { + backoff.Duration = 100 * time.Millisecond + } + + return &resumableTransport{inner: inner, backoff: backoff} } var ( @@ -26,7 +35,8 @@ var ( ) type resumableTransport struct { - inner http.RoundTripper + inner http.RoundTripper + backoff Backoff } func (rt *resumableTransport) RoundTrip(in *http.Request) (resp *http.Response, err error) { @@ -85,6 +95,7 @@ func (rt *resumableTransport) RoundTrip(in *http.Request) (resp *http.Response, req: in, total: total, transferred: start, + backoff: rt.backoff, } } @@ -97,6 +108,8 @@ type resumableBody struct { inner http.RoundTripper req *http.Request + backoff Backoff + transferred int64 total int64 @@ -128,7 +141,7 @@ func (rb *resumableBody) Read(p []byte) (n int, err error) { return } - if err = rb.resume(err); err == nil { + if err = rb.resume(rb.backoff, err); err == nil { if n == 0 { // zero bytes read, try reading again with new response.Body continue @@ -149,41 +162,57 @@ func (rb *resumableBody) Close() (err error) { return rb.rc.Close() } -func (rb *resumableBody) resume(reason error) error { +func (rb *resumableBody) resume(backoff Backoff, reason error) error { + if backoff.Steps <= 0 { + // resumable transport is disabled + return reason + } + if reason != nil { logs.Debug.Printf("Resume http transporting from error: %v", reason) } - ctx := rb.req.Context() - select { - case <-ctx.Done(): - // context already done, stop resuming from error - return ctx.Err() - default: - } + var ( + resp *http.Response + err error + ) - req := rb.req.Clone(ctx) - req.Header.Set("Range", "bytes="+strconv.FormatInt(rb.transferred, 10)+"-") - resp, err := rb.inner.RoundTrip(req) - if err != nil { - return err - } + for backoff.Steps > 0 { + time.Sleep(backoff.Step()) - if err = rb.validate(resp); err != nil { - resp.Body.Close() - // wraps original error 
- return fmt.Errorf("%w, %v", reason, err) - } + ctx := rb.req.Context() + select { + case <-ctx.Done(): + // context already done, stop resuming from error + return ctx.Err() + default: + } - if atomic.LoadUint32(&rb.closed) == 1 { - resp.Body.Close() - return http.ErrBodyReadAfterClose - } + req := rb.req.Clone(ctx) + req.Header.Set("Range", "bytes="+strconv.FormatInt(rb.transferred, 10)+"-") + if resp, err = rb.inner.RoundTrip(req); err != nil { + err = fmt.Errorf("unable to resume from '%v', %w", reason, err) + continue + } - rb.rc.Close() - rb.rc = resp.Body + if err = rb.validate(resp); err != nil { + resp.Body.Close() + // wraps original error + return fmt.Errorf("%w, %v", reason, err) + } - return nil + if atomic.LoadUint32(&rb.closed) == 1 { + resp.Body.Close() + return http.ErrBodyReadAfterClose + } + + rb.rc.Close() + rb.rc = resp.Body + + break + } + + return err } const size100m = 100 << 20 diff --git a/pkg/v1/remote/transport/resumable_test.go b/pkg/v1/remote/transport/resumable_test.go index cef0533e6..2e109a8d6 100644 --- a/pkg/v1/remote/transport/resumable_test.go +++ b/pkg/v1/remote/transport/resumable_test.go @@ -268,7 +268,12 @@ func TestResumableTransport(t *testing.T) { } client := &http.Client{ - Transport: NewResumable(http.DefaultTransport.(*http.Transport).Clone()), + Transport: NewResumable(http.DefaultTransport.(*http.Transport).Clone(), Backoff{ + Duration: 1.0 * time.Second, + Factor: 3.0, + Jitter: 0.1, + Steps: 3, + }), } tests := []struct {