mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
fix(http2): preserve extended CONNECT tunnel shutdown semantics
This commit is contained in:
parent
2165cc4114
commit
59f190ce3a
2 changed files with 258 additions and 11 deletions
|
|
@ -829,6 +829,19 @@ func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *htt
|
||||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var closeOnce sync.Once
|
||||||
|
closeTunnel := func() {
|
||||||
|
closeOnce.Do(func() {
|
||||||
|
_ = c.Request.Body.Close()
|
||||||
|
_ = backWrite.Close()
|
||||||
|
_ = res.Body.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
<-req.Context().Done()
|
||||||
|
closeTunnel()
|
||||||
|
}()
|
||||||
|
|
||||||
errc := make(chan error, 2)
|
errc := make(chan error, 2)
|
||||||
go func() {
|
go func() {
|
||||||
_, err := io.Copy(backWrite, c.Request.Body)
|
_, err := io.Copy(backWrite, c.Request.Body)
|
||||||
|
|
@ -849,19 +862,24 @@ func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *htt
|
||||||
errc <- closeErr
|
errc <- closeErr
|
||||||
}()
|
}()
|
||||||
|
|
||||||
firstErr := <-errc
|
var firstErr error
|
||||||
_ = c.Request.Body.Close()
|
for i := 0; i < 2; i++ {
|
||||||
_ = backWrite.Close()
|
err := <-errc
|
||||||
_ = res.Body.Close()
|
|
||||||
secondErr := <-errc
|
|
||||||
|
|
||||||
for _, err := range []error{firstErr, secondErr} {
|
|
||||||
if reverseProxyIsBenignTunnelError(err) {
|
if reverseProxyIsBenignTunnelError(err) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
if firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
closeTunnel()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
closeTunnel()
|
||||||
|
if reverseProxyIsBenignTunnelError(firstErr) {
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return firstErr
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
|
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
|
||||||
|
|
@ -902,7 +920,7 @@ func (p *reverseProxyHandler) copyBuffer(dst io.Writer, src io.Reader, buf []byt
|
||||||
var written int64
|
var written int64
|
||||||
for {
|
for {
|
||||||
nr, rerr := src.Read(buf)
|
nr, rerr := src.Read(buf)
|
||||||
if rerr != nil && !errors.Is(rerr, io.EOF) && !errors.Is(rerr, context.Canceled) {
|
if rerr != nil && !errors.Is(rerr, io.EOF) && !reverseProxyIsBenignTunnelError(rerr) {
|
||||||
p.logf(nil, "reverse proxy read error during body copy: %v", rerr)
|
p.logf(nil, "reverse proxy read error during body copy: %v", rerr)
|
||||||
}
|
}
|
||||||
if nr > 0 {
|
if nr > 0 {
|
||||||
|
|
@ -1371,7 +1389,19 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func reverseProxyIsBenignTunnelError(err error) bool {
|
func reverseProxyIsBenignTunnelError(err error) bool {
|
||||||
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler)
|
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyIsClosedBodyError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch err.Error() {
|
||||||
|
case "body closed by handler", "http2: response body closed", "response body closed":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
|
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
|
||||||
|
|
|
||||||
|
|
@ -967,6 +967,223 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
enableHTTP2ExtendedConnectProtocol()
|
||||||
|
|
||||||
|
errCh := make(chan error, 4)
|
||||||
|
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodConnect {
|
||||||
|
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
controller := http.NewResponseController(w)
|
||||||
|
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||||
|
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = controller.Flush()
|
||||||
|
|
||||||
|
reader := bufio.NewReader(r.Body)
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := io.WriteString(w, "ack:"+line); err != nil {
|
||||||
|
errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = controller.Flush()
|
||||||
|
|
||||||
|
if _, err := io.Copy(io.Discard, reader); err != nil {
|
||||||
|
errCh <- fmt.Errorf("wait for request half-close failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := io.WriteString(w, "after-close\n"); err != nil {
|
||||||
|
errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = controller.Flush()
|
||||||
|
}))
|
||||||
|
upstream.EnableHTTP2 = true
|
||||||
|
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
|
||||||
|
t.Fatalf("configure upstream HTTP/2 server: %v", err)
|
||||||
|
}
|
||||||
|
upstream.StartTLS()
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: mustParseURL(t, upstream.URL),
|
||||||
|
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||||
|
Via: "proxy.test",
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewUnstartedServer(engine)
|
||||||
|
proxy.EnableHTTP2 = true
|
||||||
|
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
||||||
|
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
||||||
|
}
|
||||||
|
proxy.StartTLS()
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||||
|
defer transport.CloseIdleConnections()
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new CONNECT request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set(":protocol", "websocket")
|
||||||
|
|
||||||
|
resp, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("round trip extended CONNECT: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bufio.NewReader(resp.Body)
|
||||||
|
if _, err := io.WriteString(pw, "ping\n"); err != nil {
|
||||||
|
t.Fatalf("write tunneled request body: %v", err)
|
||||||
|
}
|
||||||
|
message, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read immediate tunneled response: %v", err)
|
||||||
|
}
|
||||||
|
if message != "ack:ping\n" {
|
||||||
|
t.Fatalf("unexpected immediate tunneled response: %q", message)
|
||||||
|
}
|
||||||
|
if err := pw.Close(); err != nil {
|
||||||
|
t.Fatalf("close tunneled request body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
message, err = reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read post-close tunneled response: %v", err)
|
||||||
|
}
|
||||||
|
if message != "after-close\n" {
|
||||||
|
t.Fatalf("unexpected post-close tunneled response: %q", message)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatal(err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
enableHTTP2ExtendedConnectProtocol()
|
||||||
|
|
||||||
|
errCh := make(chan error, 4)
|
||||||
|
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodConnect {
|
||||||
|
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
controller := http.NewResponseController(w)
|
||||||
|
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||||
|
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = controller.Flush()
|
||||||
|
|
||||||
|
<-r.Context().Done()
|
||||||
|
}))
|
||||||
|
upstream.EnableHTTP2 = true
|
||||||
|
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
|
||||||
|
t.Fatalf("configure upstream HTTP/2 server: %v", err)
|
||||||
|
}
|
||||||
|
upstream.StartTLS()
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
proxyErrCh := make(chan error, 1)
|
||||||
|
engine := New()
|
||||||
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: mustParseURL(t, upstream.URL),
|
||||||
|
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||||
|
Via: "proxy.test",
|
||||||
|
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
|
select {
|
||||||
|
case proxyErrCh <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewUnstartedServer(engine)
|
||||||
|
proxy.EnableHTTP2 = true
|
||||||
|
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
||||||
|
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
||||||
|
}
|
||||||
|
proxy.StartTLS()
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||||
|
defer transport.CloseIdleConnections()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodConnect, proxy.URL+"/ws", pr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new CONNECT request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set(":protocol", "websocket")
|
||||||
|
|
||||||
|
resp, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("round trip extended CONNECT: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
writeErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := io.WriteString(pw, strings.Repeat("x", 1<<20))
|
||||||
|
writeErrCh <- err
|
||||||
|
}()
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
_ = pw.CloseWithError(context.Canceled)
|
||||||
|
select {
|
||||||
|
case <-writeErrCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for request body writer to unblock")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-proxyErrCh:
|
||||||
|
t.Fatalf("proxy error handler should not be called on cancellation, got: %v", err)
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatal(err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) {
|
func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue