Skip to content

Commit 9b6b14c

Browse files
committed
io: add Err field to LimitedReader
Add an Err field to LimitedReader that allows callers to return a custom error when the read limit is exceeded, instead of always returning EOF. When Err is set to a non-nil, non-EOF value, and the limit is reached, LimitedReader.Read probes the underlying reader with a 1-byte read to distinguish two cases: stream had exactly N bytes (returns EOF), or stream has more data (returns the custom Err). The probe result is cached using negative N values to avoid repeated reads. When Err is nil or EOF, Read returns EOF, maintaining backward compatibility. Zero-length reads return (0, nil) without side effects. Fixes #51115
1 parent 388c41c commit 9b6b14c

File tree

4 files changed

+290
-13
lines changed

4 files changed

+290
-13
lines changed

api/next/51115.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pkg io, type LimitedReader struct, Err error #51115

src/io/example_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package io_test
66

77
import (
8+
"errors"
89
"fmt"
910
"io"
1011
"log"
@@ -121,6 +122,29 @@ func ExampleLimitReader() {
121122
// some
122123
}
123124

125+
func ExampleLimitedReader_Err() {
126+
r := strings.NewReader("some io.Reader stream to be read\n")
127+
sentinel := errors.New("read limit reached")
128+
lr := &io.LimitedReader{R: r, N: 4, Err: sentinel}
129+
130+
buf := make([]byte, 10)
131+
n, err := lr.Read(buf)
132+
if err != nil {
133+
log.Fatal(err)
134+
}
135+
fmt.Printf("read %d bytes: %q\n", n, buf[:n])
136+
137+
// try to read more and get the custom error
138+
n, err = lr.Read(buf)
139+
if errors.Is(err, sentinel) {
140+
fmt.Println("error:", err)
141+
}
142+
143+
// Output:
144+
// read 4 bytes: "some"
145+
// error: read limit reached
146+
}
147+
124148
func ExampleMultiReader() {
125149
r1 := strings.NewReader("first reader ")
126150
r2 := strings.NewReader("second reader ")

src/io/io.go

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -457,28 +457,66 @@ func copyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) {
457457

458458
// LimitReader returns a Reader that reads from r
459459
// but stops with EOF after n bytes.
460-
// The underlying implementation is a *LimitedReader.
461-
func LimitReader(r Reader, n int64) Reader { return &LimitedReader{r, n} }
460+
// To return a custom error when the limit is reached, construct
461+
// a *LimitedReader directly with the desired Err field.
462+
func LimitReader(r Reader, n int64) Reader { return &LimitedReader{R: r, N: n} }
462463

463464
// A LimitedReader reads from R but limits the amount of
464465
// data returned to just N bytes. Each call to Read
465466
// updates N to reflect the new amount remaining.
466-
// Read returns EOF when N <= 0 or when the underlying R returns EOF.
467+
//
468+
// Negative values of N mean that the limit has been exceeded.
469+
// Read returns Err when more than N bytes are read from R.
470+
// If Err is nil or EOF, Read returns EOF instead.
467471
type LimitedReader struct {
468-
R Reader // underlying reader
469-
N int64 // max bytes remaining
472+
R Reader // underlying reader
473+
N int64 // max bytes remaining
474+
Err error // error to return when limit is exceeded; defaults to EOF if nil
470475
}
471476

472477
func (l *LimitedReader) Read(p []byte) (n int, err error) {
473-
if l.N <= 0 {
478+
if len(p) == 0 {
479+
return 0, nil
480+
}
481+
// We use negative l.N values to signal that we've exceeded the limit and cached the result:
482+
// -1 means more data is available
483+
// -2 means hit EOF exactly
484+
485+
if l.N > 0 {
486+
if int64(len(p)) > l.N {
487+
p = p[0:l.N]
488+
}
489+
n, err = l.R.Read(p)
490+
l.N -= int64(n)
491+
return
492+
}
493+
494+
if l.N < 0 {
495+
if l.N == -1 && l.Err != nil && l.Err != EOF {
496+
return 0, l.Err // limit was exceeded
497+
}
498+
return 0, EOF // stream was exactly N bytes, or already past limit
499+
}
500+
501+
// At limit (N == 0) - need to determine if stream has more data
502+
503+
if l.Err == nil || l.Err == EOF {
474504
return 0, EOF
475505
}
476-
if int64(len(p)) > l.N {
477-
p = p[0:l.N]
506+
507+
// Probe with one byte to distinguish two cases:
508+
// - Stream had exactly N bytes -> return EOF
509+
// - Stream has more than N bytes -> return custom error
510+
// We can't tell without reading ahead. This probe permanently consumes
511+
// a byte from R, so we cache the result in N to avoid re-probing.
512+
var probe [1]byte
513+
probeN, probeErr := l.R.Read(probe[:])
514+
if probeN > 0 || (probeErr != nil && probeErr != EOF) {
515+
l.N = -1 // more data available, limit exceeded
516+
return 0, l.Err
478517
}
479-
n, err = l.R.Read(p)
480-
l.N -= int64(n)
481-
return
518+
l.N = -2 // hit EOF, stream was exactly N bytes
519+
return 0, EOF
482520
}
483521

484522
// NewSectionReader returns a [SectionReader] that reads from r
@@ -518,8 +556,10 @@ func (s *SectionReader) Read(p []byte) (n int, err error) {
518556
return
519557
}
520558

521-
var errWhence = errors.New("Seek: invalid whence")
522-
var errOffset = errors.New("Seek: invalid offset")
559+
var (
560+
errWhence = errors.New("Seek: invalid whence")
561+
errOffset = errors.New("Seek: invalid offset")
562+
)
523563

524564
func (s *SectionReader) Seek(offset int64, whence int) (int64, error) {
525565
switch whence {

src/io/io_test.go

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,3 +692,215 @@ func TestOffsetWriter_Write(t *testing.T) {
692692
checkContent(name, f)
693693
})
694694
}
695+
696+
var errLimit = errors.New("limit exceeded")
697+
698+
func TestLimitedReader(t *testing.T) {
699+
src := strings.NewReader("abc")
700+
r := LimitReader(src, 5)
701+
lr, ok := r.(*LimitedReader)
702+
if !ok {
703+
t.Fatalf("LimitReader should return *LimitedReader, got %T", r)
704+
}
705+
if lr.R != src || lr.N != 5 || lr.Err != nil {
706+
t.Fatalf("LimitReader() = {R: %v, N: %d, Err: %v}, want {R: %v, N: 5, Err: nil}", lr.R, lr.N, lr.Err, src)
707+
}
708+
709+
t.Run("WithoutCustomErr", func(t *testing.T) {
710+
tests := []struct {
711+
name string
712+
data string
713+
limit int64
714+
want1N int
715+
want1E error
716+
want2E error
717+
}{
718+
{"UnderLimit", "hello", 10, 5, nil, EOF},
719+
{"ExactLimit", "hello", 5, 5, nil, EOF},
720+
{"OverLimit", "hello world", 5, 5, nil, EOF},
721+
{"ZeroLimit", "hello", 0, 0, EOF, EOF},
722+
}
723+
724+
for _, tt := range tests {
725+
t.Run(tt.name, func(t *testing.T) {
726+
lr := &LimitedReader{R: strings.NewReader(tt.data), N: tt.limit}
727+
buf := make([]byte, 10)
728+
729+
n, err := lr.Read(buf)
730+
if n != tt.want1N || err != tt.want1E {
731+
t.Errorf("first Read() = (%d, %v), want (%d, %v)", n, err, tt.want1N, tt.want1E)
732+
}
733+
734+
n, err = lr.Read(buf)
735+
if n != 0 || err != tt.want2E {
736+
t.Errorf("second Read() = (%d, %v), want (0, %v)", n, err, tt.want2E)
737+
}
738+
})
739+
}
740+
})
741+
742+
t.Run("WithCustomErr", func(t *testing.T) {
743+
tests := []struct {
744+
name string
745+
data string
746+
limit int64
747+
err error
748+
wantFirst string
749+
wantErr1 error
750+
wantErr2 error
751+
}{
752+
{"ExactLimit", "hello", 5, errLimit, "hello", nil, EOF},
753+
{"OverLimit", "hello world", 5, errLimit, "hello", nil, errLimit},
754+
{"UnderLimit", "hi", 5, errLimit, "hi", nil, EOF},
755+
{"ZeroLimitEmpty", "", 0, errLimit, "", EOF, EOF},
756+
{"ZeroLimitNonEmpty", "hello", 0, errLimit, "", errLimit, errLimit},
757+
}
758+
759+
for _, tt := range tests {
760+
t.Run(tt.name, func(t *testing.T) {
761+
lr := &LimitedReader{R: strings.NewReader(tt.data), N: tt.limit, Err: tt.err}
762+
buf := make([]byte, 10)
763+
764+
n, err := lr.Read(buf)
765+
if n != len(tt.wantFirst) || string(buf[:n]) != tt.wantFirst || err != tt.wantErr1 {
766+
t.Errorf("first Read() = (%d, %q, %v), want (%d, %q, %v)", n, buf[:n], err, len(tt.wantFirst), tt.wantFirst, tt.wantErr1)
767+
}
768+
769+
n, err = lr.Read(buf)
770+
if n != 0 || err != tt.wantErr2 {
771+
t.Errorf("second Read() = (%d, %v), want (0, %v)", n, err, tt.wantErr2)
772+
}
773+
})
774+
}
775+
})
776+
777+
t.Run("CustomErrPersists", func(t *testing.T) {
778+
lr := &LimitedReader{R: strings.NewReader("hello world"), N: 5, Err: errLimit}
779+
buf := make([]byte, 10)
780+
781+
n, err := lr.Read(buf)
782+
if n != 5 || err != nil || string(buf[:5]) != "hello" {
783+
t.Errorf("Read() = (%d, %v, %q), want (5, nil, \"hello\")", n, err, buf[:5])
784+
}
785+
786+
n, err = lr.Read(buf)
787+
if n != 0 || err != errLimit {
788+
t.Errorf("Read() = (%d, %v), want (0, errLimit)", n, err)
789+
}
790+
791+
n, err = lr.Read(buf)
792+
if n != 0 || err != errLimit {
793+
t.Errorf("Read() = (%d, %v), want (0, errLimit)", n, err)
794+
}
795+
})
796+
797+
t.Run("ErrEOF", func(t *testing.T) {
798+
lr := &LimitedReader{R: strings.NewReader("hello world"), N: 5, Err: EOF}
799+
buf := make([]byte, 10)
800+
801+
n, err := lr.Read(buf)
802+
if n != 5 || err != nil {
803+
t.Errorf("Read() = (%d, %v), want (5, nil)", n, err)
804+
}
805+
806+
n, err = lr.Read(buf)
807+
if n != 0 || err != EOF {
808+
t.Errorf("Read() = (%d, %v), want (0, EOF)", n, err)
809+
}
810+
})
811+
812+
t.Run("NoSideEffects", func(t *testing.T) {
813+
lr := &LimitedReader{R: strings.NewReader("hello"), N: 5, Err: errLimit}
814+
buf := make([]byte, 0)
815+
816+
for i := 0; i < 3; i++ {
817+
n, err := lr.Read(buf)
818+
if n != 0 || err != nil {
819+
t.Errorf("zero-length read #%d = (%d, %v), want (0, nil)", i+1, n, err)
820+
}
821+
if lr.N != 5 {
822+
t.Errorf("N after zero-length read #%d = %d, want 5", i+1, lr.N)
823+
}
824+
}
825+
826+
buf = make([]byte, 10)
827+
n, err := lr.Read(buf)
828+
if n != 5 || string(buf[:5]) != "hello" || err != nil {
829+
t.Errorf("normal Read() = (%d, %q, %v), want (5, \"hello\", nil)", n, buf[:5], err)
830+
}
831+
})
832+
}
833+
834+
type errorReader struct {
835+
data []byte
836+
pos int
837+
err error
838+
}
839+
840+
func (r *errorReader) Read(p []byte) (int, error) {
841+
if r.pos >= len(r.data) {
842+
return 0, r.err
843+
}
844+
n := copy(p, r.data[r.pos:])
845+
r.pos += n
846+
return n, nil
847+
}
848+
849+
func TestLimitedReaderErrors(t *testing.T) {
850+
t.Run("UnderlyingError", func(t *testing.T) {
851+
underlyingErr := errors.New("boom")
852+
lr := &LimitedReader{R: &errorReader{data: []byte("hello"), err: underlyingErr}, N: 10}
853+
buf := make([]byte, 10)
854+
855+
n, err := lr.Read(buf)
856+
if n != 5 || string(buf[:5]) != "hello" || err != nil {
857+
t.Errorf("first Read() = (%d, %q, %v), want (5, \"hello\", nil)", n, buf[:5], err)
858+
}
859+
860+
n, err = lr.Read(buf)
861+
if n != 0 || err != underlyingErr {
862+
t.Errorf("second Read() = (%d, %v), want (0, %v)", n, err, underlyingErr)
863+
}
864+
})
865+
866+
t.Run("SentinelMasksProbeError", func(t *testing.T) {
867+
probeErr := errors.New("probe failed")
868+
lr := &LimitedReader{R: &errorReader{data: []byte("hello"), err: probeErr}, N: 5, Err: errLimit}
869+
buf := make([]byte, 10)
870+
871+
n, err := lr.Read(buf)
872+
if n != 5 || string(buf[:5]) != "hello" || err != nil {
873+
t.Errorf("first Read() = (%d, %q, %v), want (5, \"hello\", nil)", n, buf[:5], err)
874+
}
875+
876+
n, err = lr.Read(buf)
877+
if n != 0 || err != errLimit {
878+
t.Errorf("second Read() = (%d, %v), want (0, errLimit)", n, err)
879+
}
880+
})
881+
}
882+
883+
func TestLimitedReaderCopy(t *testing.T) {
884+
tests := []struct {
885+
name string
886+
input string
887+
limit int64
888+
wantN int64
889+
wantErr error
890+
}{
891+
{"Exact", "hello", 5, 5, nil},
892+
{"Under", "hi", 5, 2, nil},
893+
{"Over", "hello world", 5, 5, errLimit},
894+
}
895+
896+
for _, tt := range tests {
897+
t.Run(tt.name, func(t *testing.T) {
898+
lr := &LimitedReader{R: strings.NewReader(tt.input), N: tt.limit, Err: errLimit}
899+
var dst Buffer
900+
n, err := Copy(&dst, lr)
901+
if n != tt.wantN || err != tt.wantErr {
902+
t.Errorf("Copy() = (%d, %v), want (%d, %v)", n, err, tt.wantN, tt.wantErr)
903+
}
904+
})
905+
}
906+
}

0 commit comments

Comments
 (0)