Skip to content

Commit daf88a9

Browse files
authored
[http/proxy] Add helpers for proxying requests and responses (#743)
<!-- Copyright (C) 2020-2022 Arm Limited or its affiliates and Contributors. All rights reserved. SPDX-License-Identifier: Apache-2.0 --> ### Description Those helpers should ease proxying requests and responses and are inspired from the [lura project](https://github.com/luraproject/lura/blob/master/proxy/http.go) ### Test Coverage <!-- Please put an `x` in the correct box e.g. `[x]` to indicate the testing coverage of this change. --> - [x] This change is covered by existing or additional automated tests. - [ ] Manual testing has been performed (and evidence provided) as automated testing was not feasible. - [ ] Additional tests are not required for this change (e.g. documentation update).
1 parent 5610a29 commit daf88a9

File tree

4 files changed

+324
-1
lines changed

4 files changed

+324
-1
lines changed

changes/20251031150814.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:sparkles: `[http/proxy]` Add helpers for proxying requests and responses

utils/http/httptest/testing.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ import (
1919

2020
// NewTestServer creates a test server
2121
func NewTestServer(t *testing.T, ctx context.Context, handler http.Handler, port string) {
22+
t.Helper()
2223
list, err := net.Listen("tcp", fmt.Sprintf(":%v", port))
23-
require.Nil(t, err)
24+
require.NoError(t, err)
2425
srv := &http.Server{
2526
Handler: handler,
2627
ReadHeaderTimeout: time.Minute,

utils/http/proxy/proxy.go

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package proxy
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"io"
7+
"net/http"
8+
"strconv"
9+
"strings"
10+
11+
"github.com/go-http-utils/headers"
12+
13+
"github.com/ARM-software/golang-utils/utils/commonerrors"
14+
httpheaders "github.com/ARM-software/golang-utils/utils/http/headers"
15+
"github.com/ARM-software/golang-utils/utils/reflection"
16+
"github.com/ARM-software/golang-utils/utils/safecast"
17+
"github.com/ARM-software/golang-utils/utils/safeio"
18+
)
19+
20+
// ProxyDisallowList describes headers which are not proxied back.
21+
var ProxyDisallowList = []string{
22+
headers.AccessControlAllowOrigin,
23+
headers.AccessControlAllowMethods,
24+
headers.AccessControlAllowHeaders,
25+
headers.AccessControlExposeHeaders,
26+
headers.AccessControlMaxAge,
27+
headers.AccessControlAllowCredentials,
28+
}
29+
30+
// ProxyRequest proxies a request to a new endpoint. The method can also be changed. Headers are sanitised during the process.
31+
func ProxyRequest(r *http.Request, proxyMethod, endpoint string) (proxiedRequest *http.Request, err error) {
32+
if reflection.IsEmpty(r) {
33+
err = commonerrors.UndefinedVariable("request to proxy")
34+
return
35+
}
36+
ctx := r.Context()
37+
// Note: It is important to know that an 0 or -1 content length does not mean there is no body. This is likely the case but it could also be because the body was never read and its size never assessed.
38+
contentLength := determineRequestContentLength(r)
39+
h := httpheaders.FromRequest(r).AllowList(headers.Authorization)
40+
if reflection.IsEmpty(proxyMethod) {
41+
proxyMethod = http.MethodGet
42+
}
43+
proxiedRequest, err = http.NewRequestWithContext(ctx, proxyMethod, endpoint, r.Body)
44+
if err != nil {
45+
err = commonerrors.WrapError(commonerrors.ErrUnexpected, err, "could not create a proxied request")
46+
return
47+
}
48+
49+
if proxiedRequest.ContentLength <= 0 {
50+
if proxiedRequest.Body == nil || proxiedRequest.Body == http.NoBody {
51+
if contentLength > 0 {
52+
// In this case, NewRequestWithContext does not understand/expect the request body type (not a string/byte buffer as it may be wrapped into a bigger structure) and so, the body of the proxied request is set to nil
53+
// This makes sure this does not happen without performing a copy of the body and the use of unnecessary memory.
54+
proxiedRequest.Body = r.Body
55+
proxiedRequest.GetBody = r.GetBody
56+
} else {
57+
// In this case, it will attempt a copy of the request body which should not be costly as the request is unlikely to have a body. Although it may still do as contentlength may not have actually been evaluated. However, we want to make sure it is set to the same type as the original request.
58+
proxiedRequest, err = http.NewRequestWithContext(ctx, proxyMethod, endpoint, convertBody(ctx, r.Body))
59+
if err != nil {
60+
err = commonerrors.WrapError(commonerrors.ErrUnexpected, err, "could not create a proxied request")
61+
return
62+
}
63+
}
64+
} else {
65+
// In this case, the original request is unlikely to have a body but we want to make sure that the body is of the same type.
66+
if contentLength <= 0 {
67+
proxiedRequest, err = http.NewRequestWithContext(ctx, proxyMethod, endpoint, convertBody(ctx, r.Body))
68+
if err != nil {
69+
err = commonerrors.WrapError(commonerrors.ErrUnexpected, err, "could not create a proxied request")
70+
return
71+
}
72+
}
73+
}
74+
if contentLength > 0 && proxiedRequest.ContentLength <= 0 {
75+
proxiedRequest.ContentLength = contentLength
76+
h.AppendHeader(headers.ContentLength, strconv.FormatInt(contentLength, 10))
77+
}
78+
}
79+
if contentLength > 0 && contentLength != proxiedRequest.ContentLength {
80+
err = commonerrors.Newf(commonerrors.ErrUnexpected, "proxied request does not have the same content length `%v` as original request `%v`", proxiedRequest.ContentLength, contentLength)
81+
return
82+
}
83+
h.AppendToRequest(proxiedRequest)
84+
return
85+
}
86+
87+
func determineRequestContentLength(r *http.Request) int64 {
88+
if reflection.IsEmpty(r) {
89+
return -1
90+
}
91+
if r.ContentLength > 0 {
92+
return r.ContentLength
93+
}
94+
// Following what was done in https://github.com/luraproject/lura/blob/b9ad9ab654dd6149aeb58a5d6ffe731aba41717e/proxy/http.go#L99C1-L105C4
95+
v := r.Header.Values(headers.ContentLength)
96+
if len(v) == 1 && v[0] != "chunked" {
97+
if size, err := strconv.Atoi(v[0]); err == nil {
98+
return safecast.ToInt64(size)
99+
}
100+
}
101+
return -1
102+
}
103+
104+
func convertBody(_ context.Context, body io.Reader) io.Reader {
105+
if body == nil || body == http.NoBody {
106+
return http.NoBody
107+
}
108+
switch v := body.(type) {
109+
case *bytes.Buffer:
110+
return body
111+
case *bytes.Reader:
112+
return body
113+
case *strings.Reader:
114+
return body
115+
default:
116+
// see example https://github.com/luraproject/lura/blob/b9ad9ab654dd6149aeb58a5d6ffe731aba41717e/proxy/http.go#L73
117+
buf := new(bytes.Buffer)
118+
_, err := buf.ReadFrom(v)
119+
if err != nil {
120+
return http.NoBody
121+
}
122+
if b, ok := body.(io.ReadCloser); ok {
123+
_ = b.Close()
124+
}
125+
return buf
126+
}
127+
}
128+
129+
// ProxyResponse proxies a response to a writer. Headers are sanitised and some headers such as CORS headers will be removed from the response.
130+
func ProxyResponse(ctx context.Context, resp *http.Response, w http.ResponseWriter) (err error) {
131+
if w == nil {
132+
err = commonerrors.UndefinedVariable("response writer")
133+
return
134+
}
135+
if reflection.IsEmpty(resp) {
136+
err = commonerrors.UndefinedVariable("response")
137+
return
138+
}
139+
h := httpheaders.FromResponse(resp)
140+
h.Sanitise()
141+
142+
var written int64
143+
_, err = safeio.CopyDataWithContext(ctx, resp.Body, w)
144+
if resp.Body != nil && resp.Body != http.NoBody {
145+
written, err = safeio.CopyDataWithContext(ctx, resp.Body, w)
146+
if err != nil {
147+
err = commonerrors.DescribeCircumstance(err, "failed copying response body")
148+
}
149+
}
150+
if written >= 0 {
151+
h.AppendHeader(headers.ContentLength, strconv.FormatInt(written, 10))
152+
}
153+
h.RemoveHeaders(ProxyDisallowList...)
154+
h.AppendToResponse(w)
155+
w.WriteHeader(resp.StatusCode)
156+
return
157+
}

utils/http/proxy/proxy_test.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
package proxy
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"io"
7+
"net/http"
8+
"net/http/httptest"
9+
"strconv"
10+
"strings"
11+
"testing"
12+
13+
"github.com/go-faker/faker/v4"
14+
"github.com/go-http-utils/headers"
15+
"github.com/stretchr/testify/assert"
16+
"github.com/stretchr/testify/require"
17+
18+
"github.com/ARM-software/golang-utils/utils/commonerrors"
19+
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
20+
"github.com/ARM-software/golang-utils/utils/safecast"
21+
"github.com/ARM-software/golang-utils/utils/safeio"
22+
)
23+
24+
func TestProxy(t *testing.T) {
25+
content := faker.Paragraph()
26+
path := faker.URL()
27+
password := faker.Password()
28+
tests := []struct {
29+
request *http.Request
30+
}{
31+
{
32+
request: httptest.NewRequest(http.MethodGet, faker.URL(), io.NopCloser(strings.NewReader(content))),
33+
},
34+
{
35+
request: httptest.NewRequest(http.MethodGet, faker.URL(), strings.NewReader(content)),
36+
},
37+
{
38+
request: httptest.NewRequest(http.MethodGet, faker.URL(), io.NopCloser(bytes.NewReader([]byte(content)))),
39+
},
40+
{
41+
request: httptest.NewRequest(http.MethodGet, faker.URL(), bytes.NewReader([]byte(content))),
42+
},
43+
{
44+
request: httptest.NewRequest(http.MethodGet, faker.URL(), io.NopCloser(bytes.NewBuffer([]byte(content)))),
45+
},
46+
{
47+
request: httptest.NewRequest(http.MethodGet, faker.URL(), bytes.NewBuffer([]byte(content))),
48+
},
49+
}
50+
for i := range tests {
51+
test := tests[i]
52+
t.Run(strconv.Itoa(i), func(t *testing.T) {
53+
req := test.request
54+
req.Header.Set(headers.AccessControlAllowOrigin, faker.Word())
55+
req.Header.Set(headers.XHTTPMethodOverride, http.MethodPut)
56+
req.Header.Set(headers.Authorization, password)
57+
assert.NotEqual(t, req.URL.String(), path)
58+
_, err := ProxyRequest(nil, http.MethodPost, "/")
59+
errortest.AssertError(t, err, commonerrors.ErrUndefined)
60+
preq, err := ProxyRequest(req, " ", path)
61+
require.NoError(t, err)
62+
require.NotNil(t, preq)
63+
assert.Equal(t, path, preq.URL.String())
64+
assert.Equal(t, http.MethodGet, preq.Method)
65+
assert.NotEmpty(t, preq.Header.Get(headers.AccessControlAllowOrigin))
66+
assert.NotEmpty(t, preq.Header.Get(headers.Authorization))
67+
assert.NotZero(t, preq.ContentLength)
68+
resp := generateTestResponseBasedOnRequest(t, preq)
69+
defer func() {
70+
if resp != nil {
71+
_ = resp.Body.Close()
72+
}
73+
}()
74+
w := httptest.NewRecorder()
75+
require.NoError(t, ProxyResponse(context.Background(), resp, w))
76+
proxiedResp := w.Result()
77+
defer func() { _ = proxiedResp.Body.Close() }()
78+
assert.Empty(t, w.Header().Get(headers.AccessControlAllowOrigin))
79+
assert.Equal(t, http.MethodPut, w.Header().Get(headers.XHTTPMethodOverride))
80+
assert.Equal(t, http.StatusOK, resp.StatusCode)
81+
responseContent, err := safeio.ReadAll(context.Background(), proxiedResp.Body)
82+
require.NoError(t, err)
83+
assert.Equal(t, content, string(responseContent))
84+
})
85+
}
86+
}
87+
88+
func TestEmptyResponse(t *testing.T) {
89+
path := faker.URL()
90+
tests := []struct {
91+
request *http.Request
92+
}{
93+
{
94+
httptest.NewRequest(http.MethodGet, faker.URL(), nil),
95+
},
96+
{
97+
request: httptest.NewRequest(http.MethodGet, faker.URL(), http.NoBody),
98+
},
99+
{
100+
request: httptest.NewRequest(http.MethodGet, faker.URL(), io.NopCloser(http.NoBody)),
101+
},
102+
{
103+
request: httptest.NewRequest(http.MethodGet, faker.URL(), bytes.NewReader(nil)),
104+
},
105+
{
106+
request: httptest.NewRequest(http.MethodGet, faker.URL(), io.NopCloser(bytes.NewBuffer(nil))),
107+
},
108+
{
109+
request: httptest.NewRequest(http.MethodGet, faker.URL(), strings.NewReader("")),
110+
},
111+
{
112+
request: httptest.NewRequest(http.MethodGet, faker.URL(), io.NopCloser(strings.NewReader(""))),
113+
},
114+
}
115+
for i := range tests {
116+
test := tests[i]
117+
t.Run(strconv.Itoa(i), func(t *testing.T) {
118+
req := test.request
119+
assert.NotEqual(t, req.URL.String(), path)
120+
preq, err := ProxyRequest(req, http.MethodPost, path)
121+
require.NoError(t, err)
122+
require.NotNil(t, preq)
123+
assert.Equal(t, path, preq.URL.String())
124+
assert.Equal(t, http.MethodPost, preq.Method)
125+
assert.Zero(t, preq.ContentLength)
126+
127+
resp := generateTestResponseBasedOnRequest(t, preq)
128+
defer func() {
129+
if resp != nil {
130+
_ = resp.Body.Close()
131+
}
132+
}()
133+
w := httptest.NewRecorder()
134+
require.NoError(t, ProxyResponse(context.Background(), resp, w))
135+
require.NoError(t, err)
136+
returnedResp := w.Result()
137+
assert.LessOrEqual(t, returnedResp.ContentLength, safecast.ToInt64(0))
138+
assert.Equal(t, http.StatusOK, returnedResp.StatusCode)
139+
})
140+
}
141+
}
142+
143+
func loopTestHandler(t *testing.T, w http.ResponseWriter, r *http.Request) {
144+
t.Helper()
145+
require.NotNil(t, r)
146+
require.NotNil(t, w)
147+
for k, v := range r.Header {
148+
for h := range v {
149+
w.Header().Add(k, v[h])
150+
}
151+
}
152+
written, err := safeio.CopyDataWithContext(r.Context(), r.Body, w)
153+
require.NoError(t, err)
154+
w.Header().Add(headers.ContentLength, strconv.FormatInt(written, 10))
155+
w.WriteHeader(http.StatusOK)
156+
}
157+
158+
func generateTestResponseBasedOnRequest(t *testing.T, r *http.Request) *http.Response {
159+
t.Helper()
160+
require.NotNil(t, r)
161+
w := httptest.NewRecorder()
162+
loopTestHandler(t, w, r)
163+
return w.Result()
164+
}

0 commit comments

Comments
 (0)