mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
fix(reverseproxy): align forwarding and tunnel semantics
This commit is contained in:
parent
c019f24e99
commit
ed44c592d3
6 changed files with 864 additions and 26 deletions
|
|
@ -2,6 +2,7 @@ package touka
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -70,7 +71,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
|||
engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
|
||||
Target: target,
|
||||
ForwardedHeaders: ForwardedBoth,
|
||||
ForwardedBy: "proxy-node",
|
||||
ForwardedBy: "_proxy-node",
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
|
|
@ -144,7 +145,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
|||
if !strings.Contains(got.Forwarded, "for=198.51.100.10") {
|
||||
t.Fatalf("forwarded header missing client ip: %q", got.Forwarded)
|
||||
}
|
||||
if !strings.Contains(got.Forwarded, "by=proxy-node") {
|
||||
if !strings.Contains(got.Forwarded, "by=_proxy-node") {
|
||||
t.Fatalf("forwarded header missing by token: %q", got.Forwarded)
|
||||
}
|
||||
if !strings.Contains(got.Forwarded, "host=client.example") {
|
||||
|
|
@ -170,6 +171,61 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyRejectsInvalidForwardedBy(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, "http://example.com"),
|
||||
ForwardedHeaders: ForwardedBoth,
|
||||
ForwardedBy: "proxy-node",
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyForwardedByTrimsWhitespace(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
forwardedCh := make(chan string, 1)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
forwardedCh <- r.Header.Get("Forwarded")
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, backend.URL),
|
||||
ForwardedHeaders: ForwardedBoth,
|
||||
ForwardedBy: " _proxy-node ",
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
|
||||
req.RemoteAddr = "198.51.100.10:4567"
|
||||
rr := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNoContent {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
|
||||
select {
|
||||
case forwarded := <-forwardedCh:
|
||||
if !strings.Contains(forwarded, "by=_proxy-node") {
|
||||
t.Fatalf("unexpected Forwarded header: %q", forwarded)
|
||||
}
|
||||
if strings.Contains(forwarded, `by=" _proxy-node "`) {
|
||||
t.Fatalf("forwarded header should not preserve surrounding whitespace: %q", forwarded)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for backend Forwarded header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyDefaultViaFallback(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -229,6 +285,23 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, "http://example.com"),
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, context.DeadlineExceeded
|
||||
}),
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusGatewayTimeout {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -452,6 +525,362 @@ func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyUpgradeNeedsHijacker(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
t.Fatal("backend response writer does not support hijack")
|
||||
}
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
t.Fatalf("backend hijack failed: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||
_ = brw.Flush()
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://client.example/ws", nil)
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
rr := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNotImplemented {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyMaxForwardsTraceHandledLocally(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
called := make(chan struct{}, 1)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called <- struct{}{}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil)
|
||||
req.RequestURI = "/trace"
|
||||
req.Header.Set("Max-Forwards", "0")
|
||||
req.Header.Set("Authorization", "secret")
|
||||
req.Header.Set("Cookie", "a=b")
|
||||
req.Header.Set("Forwarded", "for=192.0.2.1")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rr, req)
|
||||
|
||||
resp := rr.Result()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read body: %v", err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
if got := resp.Header.Get("Content-Type"); got != "message/http" {
|
||||
t.Fatalf("unexpected content type: %q", got)
|
||||
}
|
||||
if !strings.Contains(string(body), "TRACE /trace HTTP/1.1") {
|
||||
t.Fatalf("trace body missing request line: %q", string(body))
|
||||
}
|
||||
if strings.Contains(string(body), "Authorization:") {
|
||||
t.Fatalf("trace body leaked authorization header: %q", string(body))
|
||||
}
|
||||
if strings.Contains(string(body), "Cookie:") {
|
||||
t.Fatalf("trace body leaked cookie header: %q", string(body))
|
||||
}
|
||||
if strings.Contains(string(body), "Forwarded:") {
|
||||
t.Fatalf("trace body leaked forwarded header: %q", string(body))
|
||||
}
|
||||
|
||||
select {
|
||||
case <-called:
|
||||
t.Fatal("backend should not be called when Max-Forwards is zero")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyMaxForwardsTraceDecrementsBeforeForwarding(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
maxForwardsCh := make(chan string, 1)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
maxForwardsCh <- r.Header.Get("Max-Forwards")
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil)
|
||||
req.Header.Set("Max-Forwards", "2")
|
||||
rr := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNoContent {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-maxForwardsCh:
|
||||
if got != "1" {
|
||||
t.Fatalf("unexpected Max-Forwards header: %q", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for backend Max-Forwards")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
called := make(chan struct{}, 1)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called <- struct{}{}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", func(c *Context) { c.Status(http.StatusNoContent) })
|
||||
engine.OPTIONS("/proxy", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "http://client.example/proxy", nil)
|
||||
req.Header.Set("Max-Forwards", "0")
|
||||
rr := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
allow := rr.Header().Get("Allow")
|
||||
if !strings.Contains(allow, http.MethodGet) || !strings.Contains(allow, http.MethodOptions) {
|
||||
t.Fatalf("unexpected Allow header: %q", allow)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-called:
|
||||
t.Fatal("backend should not be called when Max-Forwards is zero")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
engine := New()
|
||||
engine.OPTIONS("/", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "http://client.example/", nil)
|
||||
req.RequestURI = "*"
|
||||
req.URL.Path = ""
|
||||
req.URL.RawPath = ""
|
||||
rr := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNotFound {
|
||||
t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyConnectTunnel(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendAddr := ""
|
||||
errCh := make(chan error, 4)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
errCh <- fmt.Errorf("unexpected method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if got, want := r.RequestURI, backendAddr; got != want {
|
||||
errCh <- fmt.Errorf("unexpected CONNECT target %q, want %q", got, want)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
errCh <- errors.New("backend response writer does not support hijack")
|
||||
return
|
||||
}
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("backend hijack failed: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\nVia: 1.1 upstream\r\n\r\n")
|
||||
if err := brw.Flush(); err != nil {
|
||||
errCh <- fmt.Errorf("backend flush failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
line, err := brw.ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("backend read failed: %w", err)
|
||||
return
|
||||
}
|
||||
_, _ = io.WriteString(brw, strings.ToUpper(line))
|
||||
if err := brw.Flush(); err != nil {
|
||||
errCh <- fmt.Errorf("backend write failed: %w", err)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
backendAddr = strings.TrimPrefix(backend.URL, "http://")
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/:authority", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, backend.URL),
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial proxy: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
t.Fatalf("set deadline: %v", err)
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(conn, "CONNECT origin.example:443 HTTP/1.1\r\nHost: origin.example:443\r\n\r\n")
|
||||
if err != nil {
|
||||
t.Fatalf("write connect request: %v", err)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
statusLine, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read status line: %v", err)
|
||||
}
|
||||
if !strings.Contains(statusLine, "200") {
|
||||
t.Fatalf("unexpected status line: %q", statusLine)
|
||||
}
|
||||
|
||||
headers, err := textproto.NewReader(reader).ReadMIMEHeader()
|
||||
if err != nil {
|
||||
t.Fatalf("read headers: %v", err)
|
||||
}
|
||||
respHeader := http.Header(headers)
|
||||
if got := respHeader.Get("Content-Length"); got != "" {
|
||||
t.Fatalf("CONNECT response should not include Content-Length, got %q", got)
|
||||
}
|
||||
if got := respHeader.Get("Transfer-Encoding"); got != "" {
|
||||
t.Fatalf("CONNECT response should not include Transfer-Encoding, got %q", got)
|
||||
}
|
||||
if gotVia := respHeader.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.1 upstream" || gotVia[1] != "1.1 proxy.test" {
|
||||
t.Fatalf("unexpected Via response header: %#v", gotVia)
|
||||
}
|
||||
|
||||
if _, err := io.WriteString(conn, "ping\n"); err != nil {
|
||||
t.Fatalf("write tunneled payload: %v", err)
|
||||
}
|
||||
message, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read tunneled payload: %v", err)
|
||||
}
|
||||
if message != "PING\n" {
|
||||
t.Fatalf("unexpected tunneled payload: %q", message)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyConnectNeedsHijacker(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
t.Fatal("backend response writer does not support hijack")
|
||||
}
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
t.Fatalf("backend hijack failed: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\n\r\n")
|
||||
_ = brw.Flush()
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/tunnel", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodConnect, "http://client.example/tunnel", nil)
|
||||
req.URL.Path = "/tunnel"
|
||||
req.RequestURI = "/tunnel"
|
||||
rr := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNotImplemented {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, "http://example.com"),
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/plain"},
|
||||
},
|
||||
Body: &failingReadCloser{chunks: []string{"ok"}, err: errors.New("boom")},
|
||||
ContentLength: -1,
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
resp, err := proxy.Client().Get(proxy.URL + "/proxy")
|
||||
if err != nil {
|
||||
t.Fatalf("perform request: %v", err)
|
||||
}
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if err == nil {
|
||||
t.Fatal("expected body read to fail after upstream copy error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -568,3 +997,21 @@ func mustParseURL(t *testing.T, raw string) *url.URL {
|
|||
}
|
||||
return u
|
||||
}
|
||||
|
||||
type failingReadCloser struct {
|
||||
chunks []string
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *failingReadCloser) Read(p []byte) (int, error) {
|
||||
if len(r.chunks) == 0 {
|
||||
return 0, r.err
|
||||
}
|
||||
n := copy(p, r.chunks[0])
|
||||
r.chunks = r.chunks[1:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *failingReadCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue