diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 15ebafd..7d05290 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -68,6 +68,7 @@ type ReverseProxyConfig struct { Transport http.RoundTripper FlushInterval time.Duration BufferPool BufferPool + AllowH2CUpstream bool ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error @@ -191,6 +192,24 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ })) ``` +### `AllowH2CUpstream` + +允许代理使用未加密 HTTP/2(h2c)与 `http://` upstream 通信。 + +- 默认关闭 +- 这是一个显式配置项 +- 启用后,Touka 会为该 upstream 使用 h2c prior-knowledge 方式连接上游 +- 这意味着上游本身也必须显式支持 h2c;它不是“先试 h2c,失败再自动回退到 h1”的协商模式 + +```go +r.GET("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + AllowH2CUpstream: true, +})) +``` + +对于下游 HTTP/2 extended `CONNECT` websocket 场景,Touka 会只在该特殊桥接路径上强制与上游使用 HTTP/1.1 websocket upgrade,以匹配 Caddy 风格的桥接语义;普通 HTTP 请求不会因为这个特性而被强制降级为 HTTP/1.1。 + ### `Transport` 可选。用于自定义底层转发所使用的 `http.RoundTripper`。 diff --git a/http2xconnect.go b/http2xconnect.go index b3b12a0..872f5b3 100644 --- a/http2xconnect.go +++ b/http2xconnect.go @@ -5,12 +5,8 @@ package touka import ( - "context" "crypto/tls" - "net" "net/http" - "net/url" - "strings" "sync" _ "unsafe" @@ -36,18 +32,36 @@ func configureHTTP2ExtendedConnectServer(srv *http.Server) error { return http2.ConfigureServer(srv, nil) } -func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper { +func newHTTP2ExtendedConnectTransport() http.RoundTripper { enableHTTP2ExtendedConnectProtocol() + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Protocols = new(http.Protocols) + transport.Protocols.SetHTTP1(true) + transport.Protocols.SetHTTP2(true) + return transport +} - transport := &http2.Transport{} - if target == nil || !strings.EqualFold(target.Scheme, "http") { - return transport +func newHTTP1BridgeTransport() http.RoundTripper { + return newHTTP1BridgeTransportWithTLSConfig(&tls.Config{NextProtos: []string{"http/1.1"}}) +} + +func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Protocols = new(http.Protocols) + transport.Protocols.SetHTTP1(true) + transport.TLSClientConfig = tlsConfig + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} } - - transport.AllowHTTP = true - transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { - var dialer net.Dialer - return dialer.DialContext(ctx, network, addr) + if len(transport.TLSClientConfig.NextProtos) == 0 { + transport.TLSClientConfig.NextProtos = []string{"http/1.1"} } return transport } + +func newH2CTransport() http.RoundTripper { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Protocols = new(http.Protocols) + transport.Protocols.SetUnencryptedHTTP2(true) + return transport +} diff --git a/reverseproxy.go b/reverseproxy.go index 186e163..13ffe89 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -5,7 +5,10 @@ package touka import ( + "bufio" "context" + "crypto/rand" + "encoding/base64" "errors" "fmt" "io" @@ -52,9 +55,10 @@ type ReverseProxyConfig struct { LoadBalancing ReverseProxyLoadBalancingConfig PassiveHealth ReverseProxyPassiveHealthConfig - Transport http.RoundTripper - FlushInterval time.Duration - BufferPool BufferPool + Transport http.RoundTripper + FlushInterval time.Duration + BufferPool BufferPool + AllowH2CUpstream bool ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error @@ -86,6 +90,33 @@ type reverseProxyStatusError struct { err error } +type reverseProxyExtendedConnectBridge struct { + body io.ReadCloser +} + +type reverseProxyH2ReadWriteCloser struct { + io.ReadCloser + ResponseWriter +} + +func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { + n, err := rwc.ResponseWriter.Write(p) + if err != nil { + return 0, err + } + if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + return 0, err + } + return n, nil +} + +func (rwc reverseProxyH2ReadWriteCloser) Close() error { + if rwc.ReadCloser == nil { + return nil + } + return rwc.ReadCloser.Close() +} + func (e *reverseProxyStatusError) Error() string { if e == nil || e.err == nil { return "" @@ -314,7 +345,7 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte } defer cleanup() - transport := p.transportForUpstream(c.Request, upstream) + transport := p.transportForUpstream(outreq, upstream) rawWriter := reverseProxyBaseResponseWriter(c.Writer) var ( roundTripMu sync.Mutex @@ -353,6 +384,20 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte upstream.recordFailure(time.Now(), p.config.PassiveHealth) } + if bridge := reverseProxyExtendedConnectBridgeFromContext(outreq.Context()); bridge != nil { + if res.StatusCode == http.StatusSwitchingProtocols { + appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) + if !p.modifyResponse(c, res, outreq) { + return true, nil, false + } + if err := p.handleBridgedExtendedConnectResponse(c, outreq, res, bridge); err != nil { + return false, err, false + } + return true, nil, false + } + return false, &reverseProxyStatusError{status: http.StatusBadGateway, err: fmt.Errorf("extended CONNECT backend returned status %d instead of 101", res.StatusCode)}, false + } + if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices { removeHopByHopHeaders(res.Header) res.Header.Del("Content-Length") @@ -435,6 +480,10 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) { outreq := c.Request.Clone(ctx) + bridgeCtx, bridged := reverseProxyPrepareExtendedConnectBridge(outreq) + if bridged { + outreq = outreq.WithContext(bridgeCtx) + } if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { outreq.Body = nil } else if c.Request.GetBody != nil { @@ -451,7 +500,7 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte } outreq.Close = false var connectWriter *io.PipeWriter - if outreq.Method == http.MethodConnect { + if outreq.Method == http.MethodConnect && !bridged { pipeReader, pipeWriter := io.Pipe() outreq.Body = pipeReader outreq.ContentLength = -1 @@ -467,7 +516,13 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte } if outreq.Method == http.MethodConnect { - if reverseProxyIsExtendedConnectRequest(outreq) { + if bridged { + rewriteReverseProxyURL(outreq, upstream.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } else if reverseProxyIsExtendedConnectRequest(outreq) { rewriteReverseProxyURL(outreq, upstream.target) if !p.config.PreserveHost { outreq.Host = "" @@ -526,6 +581,15 @@ func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream * if p.config.Transport != nil { return p.config.Transport } + if reverseProxyExtendedConnectBridgeFromContext(req.Context()) != nil { + if upstream.bridgeTransport != nil { + return upstream.bridgeTransport + } + return http.DefaultTransport + } + if upstream.useH2C && upstream.h2cTransport != nil { + return upstream.h2cTransport + } if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil { return upstream.extendedConnectTransport } @@ -915,6 +979,71 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques return firstErr } +func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, bridge *reverseProxyExtendedConnectBridge) error { + if c == nil || c.Request == nil { + res.Body.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT bridge requires a valid request context")} + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("backend returned bridged websocket response without writable body"), + } + } + + controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer)) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + backConn.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + responseHeader := c.Writer.Header() + reverseProxyCopyHeader(responseHeader, res.Header) + responseHeader.Del("Upgrade") + responseHeader.Del("Connection") + responseHeader.Del("Sec-WebSocket-Accept") + c.Writer.WriteHeader(http.StatusOK) + if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + backConn.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + conn := reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer} + brw := bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1)) + + backConnClosed := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + case <-backConnClosed: + } + backConn.Close() + }() + defer close(backConnClosed) + defer conn.Close() + defer backConn.Close() + + if err := brw.Flush(); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + errc := make(chan error, 2) + copyer := switchProtocolCopier{user: conn, backend: backConn} + go copyer.copyToBackend(errc) + go copyer.copyFromBackend(errc) + + firstErr := <-errc + if firstErr == nil { + firstErr = <-errc + } + if reverseProxyIsBenignTunnelError(firstErr) { + return nil + } + return firstErr +} + func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { if c == nil || c.Request == nil { res.Body.Close() @@ -1128,13 +1257,23 @@ func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstr upstreams := make([]*reverseProxyUpstream, 0, len(targets)) for i, target := range targets { + useH2C := strings.EqualFold(target.Scheme, "h2c") + if useH2C { + target = cloneReverseProxyURL(target) + target.Scheme = "http" + } upstream := &reverseProxyUpstream{ key: fmt.Sprintf("%d:%s", i, target.String()), target: target, index: i, + useH2C: useH2C || config.AllowH2CUpstream, } if config.Transport == nil { - upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target) + upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport() + upstream.bridgeTransport = newHTTP1BridgeTransport() + if upstream.useH2C { + upstream.h2cTransport = newH2CTransport() + } } upstreams = append(upstreams, upstream) } @@ -1237,6 +1376,47 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { return policy == ForwardedBoth || policy == ForwardedRFC7239Only } +func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool) { + protocol := reverseProxyExtendedConnectProtocol(req) + if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { + if req == nil { + return context.Background(), false + } + return req.Context(), false + } + + bridge := &reverseProxyExtendedConnectBridge{body: req.Body} + ctx := context.WithValue(req.Context(), reverseProxyExtendedConnectBridge{}, bridge) + req.Header.Del(":protocol") + req.Method = http.MethodGet + req.Body = http.NoBody + req.ContentLength = 0 + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Sec-WebSocket-Version", "13") + key, err := reverseProxyGenerateWebSocketKey() + if err == nil { + req.Header.Set("Sec-WebSocket-Key", key) + } + return ctx, true +} + +func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge { + if ctx == nil { + return nil + } + bridge, _ := ctx.Value(reverseProxyExtendedConnectBridge{}).(*reverseProxyExtendedConnectBridge) + return bridge +} + +func reverseProxyGenerateWebSocketKey() (string, error) { + key := make([]byte, 16) + if _, err := rand.Read(key); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(key), nil +} + func reverseProxyIsExtendedConnectRequest(req *http.Request) bool { return reverseProxyExtendedConnectProtocol(req) != "" } diff --git a/reverseproxy_lb.go b/reverseproxy_lb.go index 9b41af0..d2d45ab 100644 --- a/reverseproxy_lb.go +++ b/reverseproxy_lb.go @@ -57,7 +57,10 @@ type reverseProxyUpstream struct { key string target *url.URL index int + useH2C bool extendedConnectTransport http.RoundTripper + bridgeTransport http.RoundTripper + h2cTransport http.RoundTripper inFlight atomic.Int64 passiveMu sync.Mutex diff --git a/reverseproxy_test.go b/reverseproxy_test.go index b68f74e..b05f426 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -112,7 +112,8 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { t.Fatalf("unexpected body: %q", string(body)) } if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } if got.Path != "/base/api/ping" { t.Fatalf("unexpected upstream path: %q", got.Path) @@ -765,6 +766,43 @@ func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) { } } +func TestReverseProxyAllowH2CUpstream(t *testing.T) { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen h2c upstream: %v", err) + } + server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Upstream-Proto", r.Proto) + _, _ = io.WriteString(w, "ok") + })} + server.Protocols = new(http.Protocols) + server.Protocols.SetUnencryptedHTTP2(true) + errCh := make(chan error, 1) + go func() { + errCh <- server.Serve(listener) + }() + defer func() { + _ = server.Close() + <-errCh + }() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://"+listener.Addr().String()), + AllowH2CUpstream: true, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK || rr.Body.String() != "ok" { + t.Fatalf("unexpected response: code=%d body=%q", rr.Code, rr.Body.String()) + } + if got := rr.Header().Get("X-Upstream-Proto"); got != "HTTP/2.0" { + t.Fatalf("expected h2c upstream proto, got %q", got) + } +} + func TestReverseProxyCustomErrorHandler(t *testing.T) { t.Helper() @@ -1363,19 +1401,29 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { enableHTTP2ExtendedConnectProtocol() errCh := make(chan error, 4) - upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } - if r.ProtoMajor != 2 { - errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto) + if got := r.Header.Get(":protocol"); got != "" { + errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got) w.WriteHeader(http.StatusBadRequest) return } - if got := r.Header.Get(":protocol"); got != "websocket" { - errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got) + if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") { + errCh <- fmt.Errorf("unexpected upstream Connection header: %#v", r.Header.Values("Connection")) + w.WriteHeader(http.StatusBadRequest) + return + } + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("unexpected upstream Upgrade header: %q", r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.Header.Get("Sec-WebSocket-Key"); got == "" { + errCh <- errors.New("missing upstream Sec-WebSocket-Key header") w.WriteHeader(http.StatusBadRequest) return } @@ -1385,36 +1433,41 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { return } - controller := http.NewResponseController(w) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - errCh <- fmt.Errorf("enable full duplex failed: %w", err) + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") return } - w.WriteHeader(http.StatusOK) - _ = controller.Flush() + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() - line, err := bufio.NewReader(r.Body).ReadString('\n') + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ignored\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("upstream flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') if err != nil { errCh <- fmt.Errorf("read tunneled request body failed: %w", err) return } - if _, err := io.WriteString(w, "echo:"+line); err != nil { + if _, err := io.WriteString(brw, "echo:"+line); err != nil { errCh <- fmt.Errorf("write tunneled response body failed: %w", err) return } - _ = controller.Flush() + _ = brw.Flush() })) - upstream.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { - t.Fatalf("configure upstream HTTP/2 server: %v", err) - } - upstream.StartTLS() defer upstream.Close() engine := New() engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ Target: mustParseURL(t, upstream.URL), - Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}), Via: "proxy.test", })) @@ -1445,7 +1498,10 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status: %d", resp.StatusCode) } - if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" { + if got := resp.Header.Get("Upgrade"); got != "" { + t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got) + } + if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" { t.Fatalf("unexpected Via response header: %#v", gotVia) } @@ -1470,6 +1526,116 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ProtoMajor != 1 { + errCh <- fmt.Errorf("expected bridged upstream protocol HTTP/1.x, got %s", r.Proto) + w.WriteHeader(http.StatusBadRequest) + return + } + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("unexpected websocket bridge headers: Connection=%#v Upgrade=%q", r.Header.Values("Connection"), r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("upstream flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("read tunneled request body failed: %w", err) + return + } + if _, err := io.WriteString(brw, "echo:"+line); err != nil { + errCh <- fmt.Errorf("write tunneled response body failed: %w", err) + return + } + _ = brw.Flush() + })) + upstream.EnableHTTP2 = true + upstream.StartTLS() + defer upstream.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}), + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) + } + if _, err := io.WriteString(pw, "ping\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + t.Fatalf("read tunneled response body: %v", err) + } + if message != "echo:ping\n" { + t.Fatalf("unexpected tunneled response body: %q", message) + } + _ = pw.Close() + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { t.Helper() @@ -1477,42 +1643,62 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { errCh := make(chan error, 8) newBackend := func(name string) *httptest.Server { - server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } - if got := r.Header.Get(":protocol"); got != "websocket" { + if got := r.Header.Get(":protocol"); got != "" { errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got) w.WriteHeader(http.StatusBadRequest) return } - - controller := http.NewResponseController(w) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err) + if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") { + errCh <- fmt.Errorf("%s unexpected upstream Connection header: %#v", name, r.Header.Values("Connection")) + w.WriteHeader(http.StatusBadRequest) + return + } + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("%s unexpected upstream Upgrade header: %q", name, r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.Header.Get("Sec-WebSocket-Key"); got == "" { + errCh <- fmt.Errorf("%s missing upstream Sec-WebSocket-Key header", name) + w.WriteHeader(http.StatusBadRequest) return } - w.WriteHeader(http.StatusOK) - _ = controller.Flush() - line, err := bufio.NewReader(r.Body).ReadString('\n') + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- fmt.Errorf("%s upstream response writer does not support hijack", name) + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("%s upstream hijack failed: %w", name, err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("%s upstream flush failed: %w", name, err) + return + } + + line, err := brw.ReadString('\n') if err != nil { errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err) return } - if _, err := io.WriteString(w, name+":"+line); err != nil { + if _, err := io.WriteString(brw, name+":"+line); err != nil { errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err) return } - _ = controller.Flush() + _ = brw.Flush() })) - server.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil { - t.Fatalf("configure %s HTTP/2 server: %v", name, err) - } - server.StartTLS() return server } @@ -1527,8 +1713,7 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { LoadBalancing: ReverseProxyLoadBalancingConfig{ Policy: LBRoundRobin(), }, - Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, - Via: "proxy.test", + Via: "proxy.test", })) proxy := httptest.NewUnstartedServer(engine) @@ -1557,7 +1742,8 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } if _, err := io.WriteString(pw, payload+"\n"); err != nil { t.Fatalf("write tunneled request body: %v", err) @@ -1592,55 +1778,59 @@ func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { enableHTTP2ExtendedConnectProtocol() errCh := make(chan error, 4) - upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } - controller := http.NewResponseController(w) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - errCh <- fmt.Errorf("enable full duplex failed: %w", err) + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") return } - w.WriteHeader(http.StatusOK) - _ = controller.Flush() + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() - reader := bufio.NewReader(r.Body) + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("upstream flush failed: %w", err) + return + } + + reader := bufio.NewReader(brw) line, err := reader.ReadString('\n') if err != nil { errCh <- fmt.Errorf("read tunneled request body failed: %w", err) return } - if _, err := io.WriteString(w, "ack:"+line); err != nil { + if _, err := io.WriteString(brw, "ack:"+line); err != nil { errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err) return } - _ = controller.Flush() + _ = brw.Flush() if _, err := io.Copy(io.Discard, reader); err != nil { errCh <- fmt.Errorf("wait for request half-close failed: %w", err) return } - if _, err := io.WriteString(w, "after-close\n"); err != nil { + if _, err := io.WriteString(brw, "after-close\n"); err != nil { errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err) return } - _ = controller.Flush() + _ = brw.Flush() })) - upstream.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { - t.Fatalf("configure upstream HTTP/2 server: %v", err) - } - upstream.StartTLS() defer upstream.Close() engine := New() engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, upstream.URL), - Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, - Via: "proxy.test", + Target: mustParseURL(t, upstream.URL), + Via: "proxy.test", })) proxy := httptest.NewUnstartedServer(engine) @@ -1668,7 +1858,8 @@ func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } reader := bufio.NewReader(resp.Body) @@ -1707,36 +1898,37 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi enableHTTP2ExtendedConnectProtocol() errCh := make(chan error, 4) - upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } - controller := http.NewResponseController(w) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - errCh <- fmt.Errorf("enable full duplex failed: %w", err) + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") return } - w.WriteHeader(http.StatusOK) - _ = controller.Flush() + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + _ = brw.Flush() <-r.Context().Done() })) - upstream.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { - t.Fatalf("configure upstream HTTP/2 server: %v", err) - } - upstream.StartTLS() defer upstream.Close() proxyErrCh := make(chan error, 1) engine := New() engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, upstream.URL), - Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, - Via: "proxy.test", + Target: mustParseURL(t, upstream.URL), + Via: "proxy.test", ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { select { case proxyErrCh <- err: @@ -1772,7 +1964,8 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } writeErrCh := make(chan error, 1)