mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
fix(reverseproxy): bridge websocket extended connect upstreams
This commit is contained in:
parent
919236665b
commit
a9c1662333
5 changed files with 508 additions and 99 deletions
|
|
@ -112,7 +112,8 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
|||
t.Fatalf("unexpected body: %q", string(body))
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
||||
}
|
||||
if got.Path != "/base/api/ping" {
|
||||
t.Fatalf("unexpected upstream path: %q", got.Path)
|
||||
|
|
@ -765,6 +766,43 @@ func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyAllowH2CUpstream(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen h2c upstream: %v", err)
|
||||
}
|
||||
server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Upstream-Proto", r.Proto)
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
})}
|
||||
server.Protocols = new(http.Protocols)
|
||||
server.Protocols.SetUnencryptedHTTP2(true)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.Serve(listener)
|
||||
}()
|
||||
defer func() {
|
||||
_ = server.Close()
|
||||
<-errCh
|
||||
}()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, "http://"+listener.Addr().String()),
|
||||
AllowH2CUpstream: true,
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
|
||||
t.Fatalf("unexpected response: code=%d body=%q", rr.Code, rr.Body.String())
|
||||
}
|
||||
if got := rr.Header().Get("X-Upstream-Proto"); got != "HTTP/2.0" {
|
||||
t.Fatalf("expected h2c upstream proto, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -1363,19 +1401,29 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
|||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.ProtoMajor != 2 {
|
||||
errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto)
|
||||
if got := r.Header.Get(":protocol"); got != "" {
|
||||
errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get(":protocol"); got != "websocket" {
|
||||
errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got)
|
||||
if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") {
|
||||
errCh <- fmt.Errorf("unexpected upstream Connection header: %#v", r.Header.Values("Connection"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
||||
errCh <- fmt.Errorf("unexpected upstream Upgrade header: %q", r.Header.Get("Upgrade"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get("Sec-WebSocket-Key"); got == "" {
|
||||
errCh <- errors.New("missing upstream Sec-WebSocket-Key header")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
|
@ -1385,36 +1433,41 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(w)
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
errCh <- errors.New("upstream response writer does not support hijack")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = controller.Flush()
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
line, err := bufio.NewReader(r.Body).ReadString('\n')
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ignored\r\n\r\n")
|
||||
if err := brw.Flush(); err != nil {
|
||||
errCh <- fmt.Errorf("upstream flush failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
line, err := brw.ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, "echo:"+line); err != nil {
|
||||
if _, err := io.WriteString(brw, "echo:"+line); err != nil {
|
||||
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
|
||||
return
|
||||
}
|
||||
_ = controller.Flush()
|
||||
_ = brw.Flush()
|
||||
}))
|
||||
upstream.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
|
||||
t.Fatalf("configure upstream HTTP/2 server: %v", err)
|
||||
}
|
||||
upstream.StartTLS()
|
||||
defer upstream.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, upstream.URL),
|
||||
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||
Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
|
|
@ -1445,7 +1498,10 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
|||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" {
|
||||
if got := resp.Header.Get("Upgrade"); got != "" {
|
||||
t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got)
|
||||
}
|
||||
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" {
|
||||
t.Fatalf("unexpected Via response header: %#v", gotVia)
|
||||
}
|
||||
|
||||
|
|
@ -1470,6 +1526,116 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ProtoMajor != 1 {
|
||||
errCh <- fmt.Errorf("expected bridged upstream protocol HTTP/1.x, got %s", r.Proto)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
||||
errCh <- fmt.Errorf("unexpected websocket bridge headers: Connection=%#v Upgrade=%q", r.Header.Values("Connection"), r.Header.Get("Upgrade"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
errCh <- errors.New("upstream response writer does not support hijack")
|
||||
return
|
||||
}
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||
if err := brw.Flush(); err != nil {
|
||||
errCh <- fmt.Errorf("upstream flush failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
line, err := brw.ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(brw, "echo:"+line); err != nil {
|
||||
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
|
||||
return
|
||||
}
|
||||
_ = brw.Flush()
|
||||
}))
|
||||
upstream.EnableHTTP2 = true
|
||||
upstream.StartTLS()
|
||||
defer upstream.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, upstream.URL),
|
||||
Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
proxy := httptest.NewUnstartedServer(engine)
|
||||
proxy.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
||||
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
||||
}
|
||||
proxy.StartTLS()
|
||||
defer proxy.Close()
|
||||
|
||||
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
defer transport.CloseIdleConnections()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
||||
if err != nil {
|
||||
t.Fatalf("new CONNECT request: %v", err)
|
||||
}
|
||||
req.Header.Set(":protocol", "websocket")
|
||||
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("round trip extended CONNECT: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
||||
}
|
||||
if _, err := io.WriteString(pw, "ping\n"); err != nil {
|
||||
t.Fatalf("write tunneled request body: %v", err)
|
||||
}
|
||||
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read tunneled response body: %v", err)
|
||||
}
|
||||
if message != "echo:ping\n" {
|
||||
t.Fatalf("unexpected tunneled response body: %q", message)
|
||||
}
|
||||
_ = pw.Close()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -1477,42 +1643,62 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
|
|||
|
||||
errCh := make(chan error, 8)
|
||||
newBackend := func(name string) *httptest.Server {
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get(":protocol"); got != "websocket" {
|
||||
if got := r.Header.Get(":protocol"); got != "" {
|
||||
errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(w)
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err)
|
||||
if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") {
|
||||
errCh <- fmt.Errorf("%s unexpected upstream Connection header: %#v", name, r.Header.Values("Connection"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
||||
errCh <- fmt.Errorf("%s unexpected upstream Upgrade header: %q", name, r.Header.Get("Upgrade"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get("Sec-WebSocket-Key"); got == "" {
|
||||
errCh <- fmt.Errorf("%s missing upstream Sec-WebSocket-Key header", name)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = controller.Flush()
|
||||
|
||||
line, err := bufio.NewReader(r.Body).ReadString('\n')
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
errCh <- fmt.Errorf("%s upstream response writer does not support hijack", name)
|
||||
return
|
||||
}
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("%s upstream hijack failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||
if err := brw.Flush(); err != nil {
|
||||
errCh <- fmt.Errorf("%s upstream flush failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
|
||||
line, err := brw.ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, name+":"+line); err != nil {
|
||||
if _, err := io.WriteString(brw, name+":"+line); err != nil {
|
||||
errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
_ = controller.Flush()
|
||||
_ = brw.Flush()
|
||||
}))
|
||||
server.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil {
|
||||
t.Fatalf("configure %s HTTP/2 server: %v", name, err)
|
||||
}
|
||||
server.StartTLS()
|
||||
return server
|
||||
}
|
||||
|
||||
|
|
@ -1527,8 +1713,7 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
|
|||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBRoundRobin(),
|
||||
},
|
||||
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||
Via: "proxy.test",
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
proxy := httptest.NewUnstartedServer(engine)
|
||||
|
|
@ -1557,7 +1742,8 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
|
|||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
||||
}
|
||||
if _, err := io.WriteString(pw, payload+"\n"); err != nil {
|
||||
t.Fatalf("write tunneled request body: %v", err)
|
||||
|
|
@ -1592,55 +1778,59 @@ func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
|
|||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(w)
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
errCh <- errors.New("upstream response writer does not support hijack")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = controller.Flush()
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
reader := bufio.NewReader(r.Body)
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||
if err := brw.Flush(); err != nil {
|
||||
errCh <- fmt.Errorf("upstream flush failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(brw)
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, "ack:"+line); err != nil {
|
||||
if _, err := io.WriteString(brw, "ack:"+line); err != nil {
|
||||
errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err)
|
||||
return
|
||||
}
|
||||
_ = controller.Flush()
|
||||
_ = brw.Flush()
|
||||
|
||||
if _, err := io.Copy(io.Discard, reader); err != nil {
|
||||
errCh <- fmt.Errorf("wait for request half-close failed: %w", err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, "after-close\n"); err != nil {
|
||||
if _, err := io.WriteString(brw, "after-close\n"); err != nil {
|
||||
errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err)
|
||||
return
|
||||
}
|
||||
_ = controller.Flush()
|
||||
_ = brw.Flush()
|
||||
}))
|
||||
upstream.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
|
||||
t.Fatalf("configure upstream HTTP/2 server: %v", err)
|
||||
}
|
||||
upstream.StartTLS()
|
||||
defer upstream.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, upstream.URL),
|
||||
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||
Via: "proxy.test",
|
||||
Target: mustParseURL(t, upstream.URL),
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
proxy := httptest.NewUnstartedServer(engine)
|
||||
|
|
@ -1668,7 +1858,8 @@ func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
|
|||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
|
|
@ -1707,36 +1898,37 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi
|
|||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(w)
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
errCh <- errors.New("upstream response writer does not support hijack")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = controller.Flush()
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||
_ = brw.Flush()
|
||||
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
upstream.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
|
||||
t.Fatalf("configure upstream HTTP/2 server: %v", err)
|
||||
}
|
||||
upstream.StartTLS()
|
||||
defer upstream.Close()
|
||||
|
||||
proxyErrCh := make(chan error, 1)
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, upstream.URL),
|
||||
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||
Via: "proxy.test",
|
||||
Target: mustParseURL(t, upstream.URL),
|
||||
Via: "proxy.test",
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
select {
|
||||
case proxyErrCh <- err:
|
||||
|
|
@ -1772,7 +1964,8 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi
|
|||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
writeErrCh := make(chan error, 1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue