fix(reverseproxy): bridge websocket extended connect upstreams

This commit is contained in:
wjqserver 2026-04-02 18:19:41 +08:00
parent 919236665b
commit a9c1662333
5 changed files with 508 additions and 99 deletions

View file

@ -68,6 +68,7 @@ type ReverseProxyConfig struct {
Transport http.RoundTripper
FlushInterval time.Duration
BufferPool BufferPool
AllowH2CUpstream bool
ModifyRequest func(*http.Request)
ModifyResponse func(*http.Response) error
@ -191,6 +192,24 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
}))
```
### `AllowH2CUpstream`
允许代理使用未加密 HTTP/2h2c`http://` upstream 通信。
- 默认关闭
- 这是一个显式配置项
- 启用后Touka 会为该 upstream 使用 h2c prior-knowledge 方式连接上游
- 这意味着上游本身也必须显式支持 h2c它不是“先试 h2c失败再自动回退到 h1”的协商模式
```go
r.GET("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Target: target,
AllowH2CUpstream: true,
}))
```
对于下游 HTTP/2 extended `CONNECT` websocket 场景Touka 会只在该特殊桥接路径上强制与上游使用 HTTP/1.1 websocket upgrade以匹配 Caddy 风格的桥接语义;普通 HTTP 请求不会因为这个特性而被强制降级为 HTTP/1.1。
### `Transport`
可选。用于自定义底层转发所使用的 `http.RoundTripper`

View file

@ -5,12 +5,8 @@
package touka
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
"strings"
"sync"
_ "unsafe"
@ -36,18 +32,36 @@ func configureHTTP2ExtendedConnectServer(srv *http.Server) error {
return http2.ConfigureServer(srv, nil)
}
func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper {
func newHTTP2ExtendedConnectTransport() http.RoundTripper {
enableHTTP2ExtendedConnectProtocol()
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.Protocols = new(http.Protocols)
transport.Protocols.SetHTTP1(true)
transport.Protocols.SetHTTP2(true)
return transport
}
transport := &http2.Transport{}
if target == nil || !strings.EqualFold(target.Scheme, "http") {
return transport
func newHTTP1BridgeTransport() http.RoundTripper {
return newHTTP1BridgeTransportWithTLSConfig(&tls.Config{NextProtos: []string{"http/1.1"}})
}
func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.Protocols = new(http.Protocols)
transport.Protocols.SetHTTP1(true)
transport.TLSClientConfig = tlsConfig
if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}
transport.AllowHTTP = true
transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
var dialer net.Dialer
return dialer.DialContext(ctx, network, addr)
if len(transport.TLSClientConfig.NextProtos) == 0 {
transport.TLSClientConfig.NextProtos = []string{"http/1.1"}
}
return transport
}
func newH2CTransport() http.RoundTripper {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.Protocols = new(http.Protocols)
transport.Protocols.SetUnencryptedHTTP2(true)
return transport
}

View file

@ -5,7 +5,10 @@
package touka
import (
"bufio"
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
@ -52,9 +55,10 @@ type ReverseProxyConfig struct {
LoadBalancing ReverseProxyLoadBalancingConfig
PassiveHealth ReverseProxyPassiveHealthConfig
Transport http.RoundTripper
FlushInterval time.Duration
BufferPool BufferPool
Transport http.RoundTripper
FlushInterval time.Duration
BufferPool BufferPool
AllowH2CUpstream bool
ModifyRequest func(*http.Request)
ModifyResponse func(*http.Response) error
@ -86,6 +90,33 @@ type reverseProxyStatusError struct {
err error
}
type reverseProxyExtendedConnectBridge struct {
body io.ReadCloser
}
type reverseProxyH2ReadWriteCloser struct {
io.ReadCloser
ResponseWriter
}
func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) {
n, err := rwc.ResponseWriter.Write(p)
if err != nil {
return 0, err
}
if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
return 0, err
}
return n, nil
}
func (rwc reverseProxyH2ReadWriteCloser) Close() error {
if rwc.ReadCloser == nil {
return nil
}
return rwc.ReadCloser.Close()
}
func (e *reverseProxyStatusError) Error() string {
if e == nil || e.err == nil {
return ""
@ -314,7 +345,7 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte
}
defer cleanup()
transport := p.transportForUpstream(c.Request, upstream)
transport := p.transportForUpstream(outreq, upstream)
rawWriter := reverseProxyBaseResponseWriter(c.Writer)
var (
roundTripMu sync.Mutex
@ -353,6 +384,20 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
}
if bridge := reverseProxyExtendedConnectBridgeFromContext(outreq.Context()); bridge != nil {
if res.StatusCode == http.StatusSwitchingProtocols {
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
if !p.modifyResponse(c, res, outreq) {
return true, nil, false
}
if err := p.handleBridgedExtendedConnectResponse(c, outreq, res, bridge); err != nil {
return false, err, false
}
return true, nil, false
}
return false, &reverseProxyStatusError{status: http.StatusBadGateway, err: fmt.Errorf("extended CONNECT backend returned status %d instead of 101", res.StatusCode)}, false
}
if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
removeHopByHopHeaders(res.Header)
res.Header.Del("Content-Length")
@ -435,6 +480,10 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte
func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) {
outreq := c.Request.Clone(ctx)
bridgeCtx, bridged := reverseProxyPrepareExtendedConnectBridge(outreq)
if bridged {
outreq = outreq.WithContext(bridgeCtx)
}
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
outreq.Body = nil
} else if c.Request.GetBody != nil {
@ -451,7 +500,7 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte
}
outreq.Close = false
var connectWriter *io.PipeWriter
if outreq.Method == http.MethodConnect {
if outreq.Method == http.MethodConnect && !bridged {
pipeReader, pipeWriter := io.Pipe()
outreq.Body = pipeReader
outreq.ContentLength = -1
@ -467,7 +516,13 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte
}
if outreq.Method == http.MethodConnect {
if reverseProxyIsExtendedConnectRequest(outreq) {
if bridged {
rewriteReverseProxyURL(outreq, upstream.target)
if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
} else if reverseProxyIsExtendedConnectRequest(outreq) {
rewriteReverseProxyURL(outreq, upstream.target)
if !p.config.PreserveHost {
outreq.Host = ""
@ -526,6 +581,15 @@ func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *
if p.config.Transport != nil {
return p.config.Transport
}
if reverseProxyExtendedConnectBridgeFromContext(req.Context()) != nil {
if upstream.bridgeTransport != nil {
return upstream.bridgeTransport
}
return http.DefaultTransport
}
if upstream.useH2C && upstream.h2cTransport != nil {
return upstream.h2cTransport
}
if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil {
return upstream.extendedConnectTransport
}
@ -915,6 +979,71 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques
return firstErr
}
func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, bridge *reverseProxyExtendedConnectBridge) error {
if c == nil || c.Request == nil {
res.Body.Close()
return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT bridge requires a valid request context")}
}
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
res.Body.Close()
return &reverseProxyStatusError{
status: http.StatusBadGateway,
err: errors.New("backend returned bridged websocket response without writable body"),
}
}
controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer))
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
backConn.Close()
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
}
responseHeader := c.Writer.Header()
reverseProxyCopyHeader(responseHeader, res.Header)
responseHeader.Del("Upgrade")
responseHeader.Del("Connection")
responseHeader.Del("Sec-WebSocket-Accept")
c.Writer.WriteHeader(http.StatusOK)
if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
backConn.Close()
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
}
conn := reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer}
brw := bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1))
backConnClosed := make(chan struct{})
go func() {
select {
case <-req.Context().Done():
case <-backConnClosed:
}
backConn.Close()
}()
defer close(backConnClosed)
defer conn.Close()
defer backConn.Close()
if err := brw.Flush(); err != nil {
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
}
errc := make(chan error, 2)
copyer := switchProtocolCopier{user: conn, backend: backConn}
go copyer.copyToBackend(errc)
go copyer.copyFromBackend(errc)
firstErr := <-errc
if firstErr == nil {
firstErr = <-errc
}
if reverseProxyIsBenignTunnelError(firstErr) {
return nil
}
return firstErr
}
func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
if c == nil || c.Request == nil {
res.Body.Close()
@ -1128,13 +1257,23 @@ func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstr
upstreams := make([]*reverseProxyUpstream, 0, len(targets))
for i, target := range targets {
useH2C := strings.EqualFold(target.Scheme, "h2c")
if useH2C {
target = cloneReverseProxyURL(target)
target.Scheme = "http"
}
upstream := &reverseProxyUpstream{
key: fmt.Sprintf("%d:%s", i, target.String()),
target: target,
index: i,
useH2C: useH2C || config.AllowH2CUpstream,
}
if config.Transport == nil {
upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport()
upstream.bridgeTransport = newHTTP1BridgeTransport()
if upstream.useH2C {
upstream.h2cTransport = newH2CTransport()
}
}
upstreams = append(upstreams, upstream)
}
@ -1237,6 +1376,47 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
}
func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool) {
protocol := reverseProxyExtendedConnectProtocol(req)
if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") {
if req == nil {
return context.Background(), false
}
return req.Context(), false
}
bridge := &reverseProxyExtendedConnectBridge{body: req.Body}
ctx := context.WithValue(req.Context(), reverseProxyExtendedConnectBridge{}, bridge)
req.Header.Del(":protocol")
req.Method = http.MethodGet
req.Body = http.NoBody
req.ContentLength = 0
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Version", "13")
key, err := reverseProxyGenerateWebSocketKey()
if err == nil {
req.Header.Set("Sec-WebSocket-Key", key)
}
return ctx, true
}
func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge {
if ctx == nil {
return nil
}
bridge, _ := ctx.Value(reverseProxyExtendedConnectBridge{}).(*reverseProxyExtendedConnectBridge)
return bridge
}
func reverseProxyGenerateWebSocketKey() (string, error) {
key := make([]byte, 16)
if _, err := rand.Read(key); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(key), nil
}
func reverseProxyIsExtendedConnectRequest(req *http.Request) bool {
return reverseProxyExtendedConnectProtocol(req) != ""
}

View file

@ -57,7 +57,10 @@ type reverseProxyUpstream struct {
key string
target *url.URL
index int
useH2C bool
extendedConnectTransport http.RoundTripper
bridgeTransport http.RoundTripper
h2cTransport http.RoundTripper
inFlight atomic.Int64
passiveMu sync.Mutex

View file

@ -112,7 +112,8 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
t.Fatalf("unexpected body: %q", string(body))
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
}
if got.Path != "/base/api/ping" {
t.Fatalf("unexpected upstream path: %q", got.Path)
@ -765,6 +766,43 @@ func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) {
}
}
func TestReverseProxyAllowH2CUpstream(t *testing.T) {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen h2c upstream: %v", err)
}
server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Upstream-Proto", r.Proto)
_, _ = io.WriteString(w, "ok")
})}
server.Protocols = new(http.Protocols)
server.Protocols.SetUnencryptedHTTP2(true)
errCh := make(chan error, 1)
go func() {
errCh <- server.Serve(listener)
}()
defer func() {
_ = server.Close()
<-errCh
}()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, "http://"+listener.Addr().String()),
AllowH2CUpstream: true,
}))
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
t.Fatalf("unexpected response: code=%d body=%q", rr.Code, rr.Body.String())
}
if got := rr.Header().Get("X-Upstream-Proto"); got != "HTTP/2.0" {
t.Fatalf("expected h2c upstream proto, got %q", got)
}
}
func TestReverseProxyCustomErrorHandler(t *testing.T) {
t.Helper()
@ -1363,19 +1401,29 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
enableHTTP2ExtendedConnectProtocol()
errCh := make(chan error, 4)
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if r.ProtoMajor != 2 {
errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto)
if got := r.Header.Get(":protocol"); got != "" {
errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got)
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.Header.Get(":protocol"); got != "websocket" {
errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got)
if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") {
errCh <- fmt.Errorf("unexpected upstream Connection header: %#v", r.Header.Values("Connection"))
w.WriteHeader(http.StatusBadRequest)
return
}
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
errCh <- fmt.Errorf("unexpected upstream Upgrade header: %q", r.Header.Get("Upgrade"))
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.Header.Get("Sec-WebSocket-Key"); got == "" {
errCh <- errors.New("missing upstream Sec-WebSocket-Key header")
w.WriteHeader(http.StatusBadRequest)
return
}
@ -1385,36 +1433,41 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
return
}
controller := http.NewResponseController(w)
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
hj, ok := w.(http.Hijacker)
if !ok {
errCh <- errors.New("upstream response writer does not support hijack")
return
}
w.WriteHeader(http.StatusOK)
_ = controller.Flush()
conn, brw, err := hj.Hijack()
if err != nil {
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
return
}
defer conn.Close()
line, err := bufio.NewReader(r.Body).ReadString('\n')
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ignored\r\n\r\n")
if err := brw.Flush(); err != nil {
errCh <- fmt.Errorf("upstream flush failed: %w", err)
return
}
line, err := brw.ReadString('\n')
if err != nil {
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
return
}
if _, err := io.WriteString(w, "echo:"+line); err != nil {
if _, err := io.WriteString(brw, "echo:"+line); err != nil {
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
return
}
_ = controller.Flush()
_ = brw.Flush()
}))
upstream.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
t.Fatalf("configure upstream HTTP/2 server: %v", err)
}
upstream.StartTLS()
defer upstream.Close()
engine := New()
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, upstream.URL),
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
Via: "proxy.test",
}))
@ -1445,7 +1498,10 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" {
if got := resp.Header.Get("Upgrade"); got != "" {
t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got)
}
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" {
t.Fatalf("unexpected Via response header: %#v", gotVia)
}
@ -1470,6 +1526,116 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
}
}
func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) {
t.Helper()
enableHTTP2ExtendedConnectProtocol()
errCh := make(chan error, 4)
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ProtoMajor != 1 {
errCh <- fmt.Errorf("expected bridged upstream protocol HTTP/1.x, got %s", r.Proto)
w.WriteHeader(http.StatusBadRequest)
return
}
if r.Method != http.MethodGet {
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
errCh <- fmt.Errorf("unexpected websocket bridge headers: Connection=%#v Upgrade=%q", r.Header.Values("Connection"), r.Header.Get("Upgrade"))
w.WriteHeader(http.StatusBadRequest)
return
}
hj, ok := w.(http.Hijacker)
if !ok {
errCh <- errors.New("upstream response writer does not support hijack")
return
}
conn, brw, err := hj.Hijack()
if err != nil {
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
return
}
defer conn.Close()
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
if err := brw.Flush(); err != nil {
errCh <- fmt.Errorf("upstream flush failed: %w", err)
return
}
line, err := brw.ReadString('\n')
if err != nil {
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
return
}
if _, err := io.WriteString(brw, "echo:"+line); err != nil {
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
return
}
_ = brw.Flush()
}))
upstream.EnableHTTP2 = true
upstream.StartTLS()
defer upstream.Close()
engine := New()
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, upstream.URL),
Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
Via: "proxy.test",
}))
proxy := httptest.NewUnstartedServer(engine)
proxy.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
t.Fatalf("configure proxy HTTP/2 server: %v", err)
}
proxy.StartTLS()
defer proxy.Close()
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
defer transport.CloseIdleConnections()
pr, pw := io.Pipe()
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
if err != nil {
t.Fatalf("new CONNECT request: %v", err)
}
req.Header.Set(":protocol", "websocket")
resp, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("round trip extended CONNECT: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
}
if _, err := io.WriteString(pw, "ping\n"); err != nil {
t.Fatalf("write tunneled request body: %v", err)
}
message, err := bufio.NewReader(resp.Body).ReadString('\n')
if err != nil {
t.Fatalf("read tunneled response body: %v", err)
}
if message != "echo:ping\n" {
t.Fatalf("unexpected tunneled response body: %q", message)
}
_ = pw.Close()
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}
func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
t.Helper()
@ -1477,42 +1643,62 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
errCh := make(chan error, 8)
newBackend := func(name string) *httptest.Server {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if got := r.Header.Get(":protocol"); got != "websocket" {
if got := r.Header.Get(":protocol"); got != "" {
errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got)
w.WriteHeader(http.StatusBadRequest)
return
}
controller := http.NewResponseController(w)
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err)
if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") {
errCh <- fmt.Errorf("%s unexpected upstream Connection header: %#v", name, r.Header.Values("Connection"))
w.WriteHeader(http.StatusBadRequest)
return
}
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
errCh <- fmt.Errorf("%s unexpected upstream Upgrade header: %q", name, r.Header.Get("Upgrade"))
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.Header.Get("Sec-WebSocket-Key"); got == "" {
errCh <- fmt.Errorf("%s missing upstream Sec-WebSocket-Key header", name)
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
_ = controller.Flush()
line, err := bufio.NewReader(r.Body).ReadString('\n')
hj, ok := w.(http.Hijacker)
if !ok {
errCh <- fmt.Errorf("%s upstream response writer does not support hijack", name)
return
}
conn, brw, err := hj.Hijack()
if err != nil {
errCh <- fmt.Errorf("%s upstream hijack failed: %w", name, err)
return
}
defer conn.Close()
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
if err := brw.Flush(); err != nil {
errCh <- fmt.Errorf("%s upstream flush failed: %w", name, err)
return
}
line, err := brw.ReadString('\n')
if err != nil {
errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err)
return
}
if _, err := io.WriteString(w, name+":"+line); err != nil {
if _, err := io.WriteString(brw, name+":"+line); err != nil {
errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err)
return
}
_ = controller.Flush()
_ = brw.Flush()
}))
server.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil {
t.Fatalf("configure %s HTTP/2 server: %v", name, err)
}
server.StartTLS()
return server
}
@ -1527,8 +1713,7 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBRoundRobin(),
},
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
Via: "proxy.test",
Via: "proxy.test",
}))
proxy := httptest.NewUnstartedServer(engine)
@ -1557,7 +1742,8 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
}
if _, err := io.WriteString(pw, payload+"\n"); err != nil {
t.Fatalf("write tunneled request body: %v", err)
@ -1592,55 +1778,59 @@ func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
enableHTTP2ExtendedConnectProtocol()
errCh := make(chan error, 4)
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
controller := http.NewResponseController(w)
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
hj, ok := w.(http.Hijacker)
if !ok {
errCh <- errors.New("upstream response writer does not support hijack")
return
}
w.WriteHeader(http.StatusOK)
_ = controller.Flush()
conn, brw, err := hj.Hijack()
if err != nil {
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
return
}
defer conn.Close()
reader := bufio.NewReader(r.Body)
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
if err := brw.Flush(); err != nil {
errCh <- fmt.Errorf("upstream flush failed: %w", err)
return
}
reader := bufio.NewReader(brw)
line, err := reader.ReadString('\n')
if err != nil {
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
return
}
if _, err := io.WriteString(w, "ack:"+line); err != nil {
if _, err := io.WriteString(brw, "ack:"+line); err != nil {
errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err)
return
}
_ = controller.Flush()
_ = brw.Flush()
if _, err := io.Copy(io.Discard, reader); err != nil {
errCh <- fmt.Errorf("wait for request half-close failed: %w", err)
return
}
if _, err := io.WriteString(w, "after-close\n"); err != nil {
if _, err := io.WriteString(brw, "after-close\n"); err != nil {
errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err)
return
}
_ = controller.Flush()
_ = brw.Flush()
}))
upstream.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
t.Fatalf("configure upstream HTTP/2 server: %v", err)
}
upstream.StartTLS()
defer upstream.Close()
engine := New()
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, upstream.URL),
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
Via: "proxy.test",
Target: mustParseURL(t, upstream.URL),
Via: "proxy.test",
}))
proxy := httptest.NewUnstartedServer(engine)
@ -1668,7 +1858,8 @@ func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
}
reader := bufio.NewReader(resp.Body)
@ -1707,36 +1898,37 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi
enableHTTP2ExtendedConnectProtocol()
errCh := make(chan error, 4)
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
controller := http.NewResponseController(w)
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
hj, ok := w.(http.Hijacker)
if !ok {
errCh <- errors.New("upstream response writer does not support hijack")
return
}
w.WriteHeader(http.StatusOK)
_ = controller.Flush()
conn, brw, err := hj.Hijack()
if err != nil {
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
return
}
defer conn.Close()
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
_ = brw.Flush()
<-r.Context().Done()
}))
upstream.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
t.Fatalf("configure upstream HTTP/2 server: %v", err)
}
upstream.StartTLS()
defer upstream.Close()
proxyErrCh := make(chan error, 1)
engine := New()
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, upstream.URL),
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
Via: "proxy.test",
Target: mustParseURL(t, upstream.URL),
Via: "proxy.test",
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
select {
case proxyErrCh <- err:
@ -1772,7 +1964,8 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
}
writeErrCh := make(chan error, 1)