diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 1dfd760..959d866 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -292,6 +292,7 @@ Touka 会尽量遵循代理链语义: Touka 的反向代理实现支持以下能力: - `CONNECT` 隧道转发(HTTP/1.x) +- HTTP/2 extended `CONNECT` - `Connection: Upgrade` / `Upgrade` 协议升级转发 - WebSocket 等 101 Switching Protocols 场景 - SSE(Server-Sent Events)立即刷新 diff --git a/docs/routing.md b/docs/routing.md index e90308e..223081a 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -22,6 +22,8 @@ r.ANY("/any", handle) r.HandleFunc([]string{"GET", "POST"}, "/multi", handle) ``` +服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。 + ## 路径参数 (Named Parameters) 使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。 diff --git a/engine.go b/engine.go index a4350c0..b7cf330 100644 --- a/engine.go +++ b/engine.go @@ -7,6 +7,7 @@ package touka import ( "context" "errors" + "io" "reflect" "runtime" "strings" @@ -344,6 +345,11 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) { func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { if engine.serverProtocols != nil { srv.Protocols = engine.serverProtocols + if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() { + if err := configureHTTP2ExtendedConnectServer(srv); err != nil { + panic(err) + } + } } } @@ -695,6 +701,11 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // handleRequest 负责根据请求查找路由并执行相应的处理函数链 // 这是路由查找和执行的核心逻辑 func (engine *Engine) handleRequest(c *Context) { + if isGeneralOptionsRequest(c.Request) { + engine.handleGeneralOptions(c) + return + } + httpMethod := c.Request.Method requestPath := routeLookupPath(c.Request) @@ -808,6 +819,20 @@ func (engine *Engine) allowedMethodsForPath(requestPath string) []string { return allowedMethods } +func (engine *Engine) handleGeneralOptions(c *Context) { + if c == nil || c.Request == nil { + return + } + + c.Writer.Header().Set("Content-Length", "0") + if c.Request.ContentLength != 0 { + mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10) + _, _ = io.Copy(io.Discard, mb) + } + c.Writer.WriteHeader(http.StatusOK) + c.Abort() +} + // Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. // 它可以用于在长连接 (如 SSE) 中监听关闭信号. func (engine *Engine) Context() context.Context { diff --git a/go.mod b/go.mod index 42f4be4..bd0c046 100644 --- a/go.mod +++ b/go.mod @@ -8,9 +8,10 @@ require ( github.com/WJQSERVER/wanf v0.0.8 github.com/fenthope/reco v0.0.5 github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 + golang.org/x/net v0.52.0 ) require ( github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/net v0.52.0 // indirect + golang.org/x/text v0.35.0 // indirect ) diff --git a/go.sum b/go.sum index b49879b..6a8d0c6 100644 --- a/go.sum +++ b/go.sum @@ -12,3 +12,5 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= diff --git a/http2xconnect.go b/http2xconnect.go new file mode 100644 index 0000000..b3b12a0 --- /dev/null +++ b/http2xconnect.go @@ -0,0 +1,53 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2026 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "net/url" + "strings" + "sync" + _ "unsafe" + + "golang.org/x/net/http2" +) + +var enableHTTP2ExtendedConnectOnce sync.Once + +//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol +var xnetDisableHTTP2ExtendedConnectProtocol bool + +func enableHTTP2ExtendedConnectProtocol() { + enableHTTP2ExtendedConnectOnce.Do(func() { + xnetDisableHTTP2ExtendedConnectProtocol = false + }) +} + +func configureHTTP2ExtendedConnectServer(srv *http.Server) error { + if srv == nil { + return nil + } + enableHTTP2ExtendedConnectProtocol() + return http2.ConfigureServer(srv, nil) +} + +func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper { + enableHTTP2ExtendedConnectProtocol() + + transport := &http2.Transport{} + if target == nil || !strings.EqualFold(target.Scheme, "http") { + return transport + } + + transport.AllowHTTP = true + transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + var dialer net.Dialer + return dialer.DialContext(ctx, network, addr) + } + return transport +} diff --git a/reverseproxy.go b/reverseproxy.go index 977402b..e01f4d0 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -67,10 +67,11 @@ var ( ) type reverseProxyHandler struct { - config ReverseProxyConfig - target *url.URL - receivedBy string - configError error + config ReverseProxyConfig + target *url.URL + receivedBy string + configError error + extendedConnectTransport http.RoundTripper } type reverseProxyStatusError struct { @@ -208,6 +209,9 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { target: target, receivedBy: reverseProxyReceivedBy(config.Via), } + if config.Transport == nil { + proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target) + } if err := validateReverseProxyTarget(target); err != nil { proxy.configError = err @@ -238,7 +242,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { transport := p.config.Transport if transport == nil { - transport = http.DefaultTransport + if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil { + transport = p.extendedConnectTransport + } else { + transport = http.DefaultTransport + } } updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c) @@ -280,9 +288,17 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { } if outreq.Method == http.MethodConnect { - if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil { - p.handleError(c, err) - return + if reverseProxyIsExtendedConnectRequest(outreq) { + rewriteReverseProxyURL(outreq, p.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } else { + if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil { + p.handleError(c, err) + return + } } } else { rewriteReverseProxyURL(outreq, p.target) @@ -367,7 +383,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { if !p.modifyResponse(c, res, outreq) { return } - if err := p.handleConnectResponse(c, outreq, res, connectWriter); err != nil { + handleConnect := p.handleConnectResponse + if reverseProxyIsExtendedConnectRequest(outreq) { + handleConnect = p.handleExtendedConnectResponse + } + if err := handleConnect(c, outreq, res, connectWriter); err != nil { p.handleError(c, err) } connectWriter = nil @@ -778,6 +798,72 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques return firstErr } +func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { + if c == nil || c.Request == nil { + res.Body.Close() + if backWrite != nil { + _ = backWrite.Close() + } + return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")} + } + if backWrite == nil { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"), + } + } + + controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer)) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + res.Body.Close() + _ = backWrite.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + reverseProxyCopyHeader(c.Writer.Header(), res.Header) + c.Writer.WriteHeader(res.StatusCode) + if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + res.Body.Close() + _ = backWrite.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(backWrite, c.Request.Body) + closeErr := backWrite.Close() + if err != nil && !reverseProxyIsBenignTunnelError(err) { + errc <- err + return + } + errc <- closeErr + }() + go func() { + copyErr := p.copyResponse(c.Writer, res.Body, -1) + closeErr := res.Body.Close() + if copyErr != nil { + errc <- copyErr + return + } + errc <- closeErr + }() + + firstErr := <-errc + _ = c.Request.Body.Close() + _ = backWrite.Close() + _ = res.Body.Close() + secondErr := <-errc + + for _, err := range []error{firstErr, secondErr} { + if reverseProxyIsBenignTunnelError(err) { + continue + } + return err + } + return nil +} + func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { return -1 @@ -968,6 +1054,17 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { return policy == ForwardedBoth || policy == ForwardedRFC7239Only } +func reverseProxyIsExtendedConnectRequest(req *http.Request) bool { + return reverseProxyExtendedConnectProtocol(req) != "" +} + +func reverseProxyExtendedConnectProtocol(req *http.Request) string { + if req == nil || req.Method != http.MethodConnect || req.Header == nil { + return "" + } + return textproto.TrimString(req.Header.Get(":protocol")) +} + func isValidForwardedNodeIdentifier(value string) bool { if value == "" { return false @@ -1273,6 +1370,10 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { return req != nil && req.Context().Value(http.ServerContextKey) != nil } +func reverseProxyIsBenignTunnelError(err error) bool { + return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) +} + func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { return UnwrapResponseWriter(writer) } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index b7df512..345dd97 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -3,6 +3,7 @@ package touka import ( "bufio" "context" + "crypto/tls" "errors" "fmt" "io" @@ -15,6 +16,8 @@ import ( "strings" "testing" "time" + + "golang.org/x/net/http2" ) func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { @@ -680,7 +683,7 @@ func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) { } } -func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) { +func TestEngineHandlesOptionsAsteriskLocally(t *testing.T) { t.Helper() engine := New() @@ -695,9 +698,12 @@ func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) { rr := httptest.NewRecorder() engine.ServeHTTP(rr, req) - if rr.Code != http.StatusNotFound { + if rr.Code != http.StatusOK { t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code) } + if got := rr.Header().Get("Content-Length"); got != "0" { + t.Fatalf("unexpected Content-Length header: %q", got) + } } func TestReverseProxyConnectTunnel(t *testing.T) { @@ -848,6 +854,119 @@ func TestReverseProxyConnectNeedsHijacker(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if r.ProtoMajor != 2 { + errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.Header.Get(":protocol"); got != "websocket" { + errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.URL.Path; got != "/ws" { + errCh <- fmt.Errorf("unexpected upstream path: %q", got) + w.WriteHeader(http.StatusBadRequest) + return + } + + controller := http.NewResponseController(w) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + errCh <- fmt.Errorf("enable full duplex failed: %w", err) + return + } + w.WriteHeader(http.StatusOK) + _ = controller.Flush() + + line, err := bufio.NewReader(r.Body).ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("read tunneled request body failed: %w", err) + return + } + if _, err := io.WriteString(w, "echo:"+line); err != nil { + errCh <- fmt.Errorf("write tunneled response body failed: %w", err) + return + } + _ = controller.Flush() + })) + upstream.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { + t.Fatalf("configure upstream HTTP/2 server: %v", err) + } + upstream.StartTLS() + defer upstream.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" { + t.Fatalf("unexpected Via response header: %#v", gotVia) + } + + if _, err := io.WriteString(pw, "ping\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + t.Fatalf("read tunneled response body: %v", err) + } + if message != "echo:ping\n" { + t.Fatalf("unexpected tunneled response body: %q", message) + } + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) { t.Helper()