Skip to content

Commit a7770cb

Browse files
safeio Add support for cancelling readers that make blocking kernel reads during copying (#703)
<!-- Copyright (C) 2020-2022 Arm Limited or its affiliates and Contributors. All rights reserved. SPDX-License-Identifier: Apache-2.0 --> ### Description <!-- Please add any detail or context that would be useful to a reviewer. --> Add support for cancelling readers that make blocking kernel reads during copying to safeio submodule. When copying, a lot of the time things were in the kernel read and therefore since the cancel in the other contextual readers only cancel before a read, they weren't be cancelled. This allows readclosers to be used and makes it so that context cancellation will call their close functions which will stop the kernel read. https://man7.org/linux/man-pages/man2/read.2.html ### Test Coverage <!-- Please put an `x` in the correct box e.g. `[x]` to indicate the testing coverage of this change. --> - [x] This change is covered by existing or additional automated tests. - [ ] Manual testing has been performed (and evidence provided) as automated testing was not feasible. - [ ] Additional tests are not required for this change (e.g. documentation update).
1 parent 7f6a579 commit a7770cb

File tree

6 files changed

+272
-0
lines changed

6 files changed

+272
-0
lines changed

changes/20250909150027.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:sparkles: `safeio` Add support for cancelling readers that make blocking kernel reads during copying

utils/safeio/copy.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,18 @@ func Cat(ctx context.Context, dst io.Writer, src ...io.Reader) (copied int64, er
2828
return CopyDataWithContext(ctx, NewContextualMultipleReader(ctx, src...), dst)
2929
}
3030

31+
// SafeCopyDataWithContext copies from src to dst similarly to io.Copy but with context control to stop when asked.
32+
// Unlike CopyWithContext it requires a ReadCloser, this allows it to stop even if the system is doing a kernel read.
33+
func SafeCopyDataWithContext(ctx context.Context, src io.ReadCloser, dst io.Writer) (copied int64, err error) {
34+
return safeCopyDataWithContext(ctx, src, dst, func(dst io.Writer, src io.ReadCloser) (int64, error) { return io.Copy(dst, src) })
35+
}
36+
37+
// SafeCopyNWithContext copies n bytes from src to dst similarly to io.CopyN but with context control to stop when asked.
38+
// Unlike CopyNWithContext it requires a ReadCloser, this allows it to stop even if the system is doing a kernel read.
39+
func SafeCopyNWithContext(ctx context.Context, src io.ReadCloser, dst io.Writer, n int64) (copied int64, err error) {
40+
return safeCopyDataWithContext(ctx, src, dst, func(dst io.Writer, src io.ReadCloser) (int64, error) { return io.CopyN(dst, src, n) })
41+
}
42+
3143
func copyDataWithContext(ctx context.Context, src io.Reader, dst io.Writer, copyFunc func(io.Writer, io.Reader) (int64, error)) (copied int64, err error) {
3244
err = parallelisation.DetermineContextError(ctx)
3345
if err != nil {
@@ -37,8 +49,23 @@ func copyDataWithContext(ctx context.Context, src io.Reader, dst io.Writer, copy
3749
return
3850
}
3951

52+
func safeCopyDataWithContext(ctx context.Context, src io.ReadCloser, dst io.Writer, copyFunc func(io.Writer, io.ReadCloser) (int64, error)) (copied int64, err error) {
53+
err = parallelisation.DetermineContextError(ctx)
54+
if err != nil {
55+
return
56+
}
57+
copied, err = reallySafeCopy(ContextualWriter(ctx, dst), NewContextualReadCloser(ctx, src), copyFunc)
58+
return
59+
}
60+
4061
func safeCopy(w io.Writer, r io.Reader, iocopyFunc func(io.Writer, io.Reader) (int64, error)) (int64, error) {
4162
copied, err := iocopyFunc(w, r)
4263
err = ConvertIOError(err)
4364
return copied, err
4465
}
66+
67+
func reallySafeCopy(w io.Writer, r io.ReadCloser, iocopyFunc func(io.Writer, io.ReadCloser) (int64, error)) (int64, error) {
68+
copied, err := iocopyFunc(w, r)
69+
err = ConvertIOError(err)
70+
return copied, err
71+
}

utils/safeio/copy_test.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@ package safeio
33
import (
44
"bytes"
55
"context"
6+
"io"
7+
"os"
68
"testing"
9+
"time"
710

811
"github.com/go-faker/faker/v4"
912
"github.com/stretchr/testify/assert"
1013
"github.com/stretchr/testify/require"
14+
"go.uber.org/goleak"
1115

1216
"github.com/ARM-software/golang-utils/utils/commonerrors"
1317
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
@@ -81,6 +85,135 @@ func TestCopyNWithContext(t *testing.T) {
8185
assert.Equal(t, safecast.ToInt64(len(text)-1), n2)
8286
}
8387

88+
func TestSafeCopyDataWithContext(t *testing.T) {
89+
defer goleak.VerifyNone(t)
90+
var buf1, buf2 bytes.Buffer
91+
text := faker.Sentence()
92+
n, err := WriteString(context.Background(), &buf1, text)
93+
require.NoError(t, err)
94+
require.NotZero(t, n)
95+
assert.Equal(t, len(text), n)
96+
rc := io.NopCloser(bytes.NewReader(buf1.Bytes())) // make it an io.ReadCloser
97+
n2, err := SafeCopyDataWithContext(context.Background(), rc, &buf2)
98+
require.NoError(t, err)
99+
require.NotZero(t, n2)
100+
assert.Equal(t, safecast.ToInt64(len(text)), n2)
101+
assert.Equal(t, text, buf2.String())
102+
103+
ctx, cancel := context.WithCancel(context.Background())
104+
buf1.Reset()
105+
buf2.Reset()
106+
n, err = WriteString(context.Background(), &buf1, text)
107+
require.NoError(t, err)
108+
require.NotZero(t, n)
109+
assert.Equal(t, len(text), n)
110+
111+
cancel()
112+
rc = io.NopCloser(bytes.NewReader(buf1.Bytes()))
113+
n2, err = SafeCopyDataWithContext(ctx, rc, &buf2)
114+
require.Error(t, err)
115+
errortest.AssertError(t, err, commonerrors.ErrCancelled)
116+
assert.Zero(t, n2)
117+
assert.Empty(t, buf2.String())
118+
119+
r, w, err := os.Pipe()
120+
require.NoError(t, err)
121+
defer func() { _ = w.Close() }()
122+
ctx2, unblock := context.WithCancel(context.Background())
123+
done := make(chan struct{})
124+
125+
go func() {
126+
_, errCopy := SafeCopyDataWithContext(ctx2, r, io.Discard)
127+
_ = r.Close()
128+
_ = errCopy
129+
close(done)
130+
}()
131+
132+
time.Sleep(50 * time.Millisecond) // let it enter read(2) https://man7.org/linux/man-pages/man2/read.2.html
133+
unblock()
134+
135+
select {
136+
case <-done:
137+
// Expected case: unblocked
138+
case <-time.After(2 * time.Second):
139+
assert.FailNow(t, "context cancel should have unblocked copy")
140+
}
141+
}
142+
143+
func TestSafeCopyNWithContext(t *testing.T) {
144+
defer goleak.VerifyNone(t)
145+
var buf1, buf2 bytes.Buffer
146+
text := faker.Sentence()
147+
n, err := WriteString(context.Background(), &buf1, text)
148+
require.NoError(t, err)
149+
require.NotZero(t, n)
150+
assert.Equal(t, len(text), n)
151+
rc := io.NopCloser(bytes.NewReader(buf1.Bytes()))
152+
n2, err := SafeCopyNWithContext(context.Background(), rc, &buf2, safecast.ToInt64(len(text)))
153+
require.NoError(t, err)
154+
require.NotZero(t, n2)
155+
assert.Equal(t, safecast.ToInt64(len(text)), n2)
156+
assert.Equal(t, text, buf2.String())
157+
158+
ctx, cancel := context.WithCancel(context.Background())
159+
160+
buf1.Reset()
161+
buf2.Reset()
162+
n, err = WriteString(context.Background(), &buf1, text)
163+
require.NoError(t, err)
164+
require.NotZero(t, n)
165+
assert.Equal(t, len(text), n)
166+
167+
cancel()
168+
rc = io.NopCloser(bytes.NewReader(buf1.Bytes()))
169+
n2, err = SafeCopyNWithContext(ctx, rc, &buf2, safecast.ToInt64(len(text)))
170+
require.Error(t, err)
171+
errortest.AssertError(t, err, commonerrors.ErrCancelled)
172+
assert.Zero(t, n2)
173+
assert.Empty(t, buf2.String())
174+
175+
buf1.Reset()
176+
buf2.Reset()
177+
n, err = WriteString(context.Background(), &buf1, text)
178+
require.NoError(t, err)
179+
require.NotZero(t, n)
180+
rc = io.NopCloser(bytes.NewReader(buf1.Bytes()))
181+
182+
wantN := safecast.ToInt64(len(text) - 1)
183+
n2, err = SafeCopyNWithContext(context.Background(), rc, &buf2, wantN)
184+
require.NoError(t, err)
185+
require.NotZero(t, n2)
186+
assert.Equal(t, wantN, n2)
187+
assert.Equal(t, text[:len(text)-1], buf2.String())
188+
189+
r, w, err := os.Pipe()
190+
require.NoError(t, err)
191+
defer func() { _ = w.Close() }()
192+
ctx2, unblock := context.WithCancel(context.Background())
193+
done := make(chan struct{})
194+
var (
195+
copied int64
196+
copyErr error
197+
)
198+
199+
go func() {
200+
copied, copyErr = SafeCopyNWithContext(ctx2, r, io.Discard, 1024) // nothing to read means it blocks
201+
_ = r.Close()
202+
close(done)
203+
}()
204+
205+
time.Sleep(50 * time.Millisecond) // let it enter read(2) https://man7.org/linux/man-pages/man2/read.2.html
206+
unblock()
207+
208+
select {
209+
case <-done:
210+
errortest.AssertError(t, copyErr, commonerrors.ErrCancelled)
211+
assert.Zero(t, copied)
212+
case <-time.After(2 * time.Second):
213+
assert.FailNow(t, "context cancel should have unblocked copy")
214+
}
215+
}
216+
84217
func TestCat(t *testing.T) {
85218
var buf1, buf2, buf3 bytes.Buffer
86219
text1 := faker.Sentence()

utils/safeio/error.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package safeio
22

33
import (
44
"io"
5+
"os"
56

67
"github.com/ARM-software/golang-utils/utils/commonerrors"
78
)
@@ -16,6 +17,9 @@ func ConvertIOError(err error) (newErr error) {
1617
case commonerrors.Any(newErr, commonerrors.ErrEOF):
1718
case commonerrors.Any(newErr, io.EOF, io.ErrUnexpectedEOF):
1819
newErr = commonerrors.WrapError(commonerrors.ErrEOF, newErr, "")
20+
case commonerrors.Any(newErr, os.ErrClosed):
21+
// cancelling a reader on a copy will cause it to close the file and return os.ErrClosed so map it to cancelled for this package
22+
newErr = commonerrors.WrapError(commonerrors.ErrCancelled, newErr, "")
1923
}
2024
return
2125
}

utils/safeio/read.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"io"
77

88
"github.com/dolmen-go/contextio"
9+
"go.uber.org/atomic"
910

1011
"github.com/ARM-software/golang-utils/utils/commonerrors"
1112
"github.com/ARM-software/golang-utils/utils/parallelisation"
@@ -76,6 +77,42 @@ func NewContextualReader(ctx context.Context, reader io.Reader) io.Reader {
7677
return contextio.NewReader(ctx, reader)
7778
}
7879

80+
type safeReadCloser struct {
81+
reader io.Reader // use reader to ensure idempotency since you can't call close on the reader itself, only via the wrapper
82+
close parallelisation.CloseFunc
83+
closed *atomic.Bool
84+
}
85+
86+
func (r safeReadCloser) Read(p []byte) (int, error) {
87+
return r.reader.Read(p)
88+
}
89+
90+
func (r safeReadCloser) Close() error {
91+
if r.closed.Swap(true) {
92+
return nil
93+
}
94+
95+
return r.close()
96+
}
97+
98+
// NewContextualReadCloser returns a readcloser which is context aware.
99+
// Context state is checked during the read and close is called if the context is cancelled
100+
// This allows for readers that block on syscalls to be stopped via a context
101+
func NewContextualReadCloser(ctx context.Context, reader io.ReadCloser) io.ReadCloser {
102+
stop := context.AfterFunc(ctx, func() { _ = reader.Close() })
103+
104+
r := safeReadCloser{
105+
reader: contextio.NewReader(ctx, reader),
106+
close: func() error {
107+
_ = stop()
108+
return nil
109+
},
110+
closed: atomic.NewBool(false),
111+
}
112+
113+
return r
114+
}
115+
79116
func NewContextualMultipleReader(ctx context.Context, reader ...io.Reader) io.Reader {
80117
readers := make([]io.Reader, len(reader))
81118
for i := range reader {

utils/safeio/read_closer_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package safeio
2+
3+
import (
4+
"context"
5+
"io"
6+
"os"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
"go.uber.org/goleak"
13+
)
14+
15+
func TestNewContextualReadCloser(t *testing.T) {
16+
t.Run("Normal contextual reader blocks even after cancel", func(t *testing.T) {
17+
defer goleak.VerifyNone(t)
18+
19+
r, w, err := os.Pipe()
20+
require.NoError(t, err)
21+
defer func() { _ = r.Close(); _ = w.Close() }()
22+
23+
ctx, cancel := context.WithCancel(context.Background())
24+
reader := NewContextualReader(ctx, r)
25+
26+
done := make(chan struct{})
27+
go func() {
28+
_, _ = io.Copy(io.Discard, reader) // will block in read(2) https://man7.org/linux/man-pages/man2/read.2.html
29+
close(done)
30+
}()
31+
32+
// Allow io.Copy to enter kernel read then try to cancel
33+
time.Sleep(50 * time.Millisecond)
34+
cancel()
35+
36+
select {
37+
case <-done:
38+
assert.FailNow(t, "cancelling context shouldn't unblock a blocking Read in io.Copy")
39+
case <-time.After(200 * time.Millisecond):
40+
// Expected case: still blocked
41+
}
42+
})
43+
44+
t.Run("Contextual read closer does not block even on long running copies", func(t *testing.T) {
45+
defer goleak.VerifyNone(t)
46+
47+
r, w, err := os.Pipe()
48+
require.NoError(t, err)
49+
defer func() { _ = w.Close() }()
50+
51+
ctx, cancel := context.WithCancel(context.Background())
52+
rc := NewContextualReadCloser(ctx, r)
53+
54+
done := make(chan struct{})
55+
go func() {
56+
_, _ = io.Copy(io.Discard, rc) // will block in read(2) https://man7.org/linux/man-pages/man2/read.2.html
57+
close(done)
58+
}()
59+
60+
time.Sleep(50 * time.Millisecond)
61+
cancel()
62+
63+
select {
64+
case <-done:
65+
// Expected case: successfully unblocked
66+
case <-time.After(2 * time.Second):
67+
assert.FailNow(t, "copy should have been unblocked by context cancel")
68+
}
69+
})
70+
}

0 commit comments

Comments
 (0)