feat(http2): support OPTIONS * and extended CONNECT

This commit is contained in:
wjqserver 2026-04-02 03:53:17 +08:00
parent ed44c592d3
commit 2165cc4114
8 changed files with 316 additions and 12 deletions

View file

@ -292,6 +292,7 @@ Touka 会尽量遵循代理链语义:
Touka 的反向代理实现支持以下能力: Touka 的反向代理实现支持以下能力:
- `CONNECT` 隧道转发HTTP/1.x - `CONNECT` 隧道转发HTTP/1.x
- HTTP/2 extended `CONNECT`
- `Connection: Upgrade` / `Upgrade` 协议升级转发 - `Connection: Upgrade` / `Upgrade` 协议升级转发
- WebSocket 等 101 Switching Protocols 场景 - WebSocket 等 101 Switching Protocols 场景
- SSEServer-Sent Events立即刷新 - SSEServer-Sent Events立即刷新

View file

@ -22,6 +22,8 @@ r.ANY("/any", handle)
r.HandleFunc([]string{"GET", "POST"}, "/multi", handle) r.HandleFunc([]string{"GET", "POST"}, "/multi", handle)
``` ```
服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。
## 路径参数 (Named Parameters) ## 路径参数 (Named Parameters)
使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。 使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。

View file

@ -7,6 +7,7 @@ package touka
import ( import (
"context" "context"
"errors" "errors"
"io"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
@ -344,6 +345,11 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) {
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
if engine.serverProtocols != nil { if engine.serverProtocols != nil {
srv.Protocols = engine.serverProtocols srv.Protocols = engine.serverProtocols
if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() {
if err := configureHTTP2ExtendedConnectServer(srv); err != nil {
panic(err)
}
}
} }
} }
@ -695,6 +701,11 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// handleRequest 负责根据请求查找路由并执行相应的处理函数链 // handleRequest 负责根据请求查找路由并执行相应的处理函数链
// 这是路由查找和执行的核心逻辑 // 这是路由查找和执行的核心逻辑
func (engine *Engine) handleRequest(c *Context) { func (engine *Engine) handleRequest(c *Context) {
if isGeneralOptionsRequest(c.Request) {
engine.handleGeneralOptions(c)
return
}
httpMethod := c.Request.Method httpMethod := c.Request.Method
requestPath := routeLookupPath(c.Request) requestPath := routeLookupPath(c.Request)
@ -808,6 +819,20 @@ func (engine *Engine) allowedMethodsForPath(requestPath string) []string {
return allowedMethods return allowedMethods
} }
func (engine *Engine) handleGeneralOptions(c *Context) {
if c == nil || c.Request == nil {
return
}
c.Writer.Header().Set("Content-Length", "0")
if c.Request.ContentLength != 0 {
mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10)
_, _ = io.Copy(io.Discard, mb)
}
c.Writer.WriteHeader(http.StatusOK)
c.Abort()
}
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. // Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
// 它可以用于在长连接 (如 SSE) 中监听关闭信号. // 它可以用于在长连接 (如 SSE) 中监听关闭信号.
func (engine *Engine) Context() context.Context { func (engine *Engine) Context() context.Context {

3
go.mod
View file

@ -8,9 +8,10 @@ require (
github.com/WJQSERVER/wanf v0.0.8 github.com/WJQSERVER/wanf v0.0.8
github.com/fenthope/reco v0.0.5 github.com/fenthope/reco v0.0.5
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433
golang.org/x/net v0.52.0
) )
require ( require (
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
golang.org/x/net v0.52.0 // indirect golang.org/x/text v0.35.0 // indirect
) )

2
go.sum
View file

@ -12,3 +12,5 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=

53
http2xconnect.go Normal file
View file

@ -0,0 +1,53 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2026 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
"strings"
"sync"
_ "unsafe"
"golang.org/x/net/http2"
)
var enableHTTP2ExtendedConnectOnce sync.Once
//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol
var xnetDisableHTTP2ExtendedConnectProtocol bool
func enableHTTP2ExtendedConnectProtocol() {
enableHTTP2ExtendedConnectOnce.Do(func() {
xnetDisableHTTP2ExtendedConnectProtocol = false
})
}
func configureHTTP2ExtendedConnectServer(srv *http.Server) error {
if srv == nil {
return nil
}
enableHTTP2ExtendedConnectProtocol()
return http2.ConfigureServer(srv, nil)
}
func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper {
enableHTTP2ExtendedConnectProtocol()
transport := &http2.Transport{}
if target == nil || !strings.EqualFold(target.Scheme, "http") {
return transport
}
transport.AllowHTTP = true
transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
var dialer net.Dialer
return dialer.DialContext(ctx, network, addr)
}
return transport
}

View file

@ -67,10 +67,11 @@ var (
) )
type reverseProxyHandler struct { type reverseProxyHandler struct {
config ReverseProxyConfig config ReverseProxyConfig
target *url.URL target *url.URL
receivedBy string receivedBy string
configError error configError error
extendedConnectTransport http.RoundTripper
} }
type reverseProxyStatusError struct { type reverseProxyStatusError struct {
@ -208,6 +209,9 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
target: target, target: target,
receivedBy: reverseProxyReceivedBy(config.Via), receivedBy: reverseProxyReceivedBy(config.Via),
} }
if config.Transport == nil {
proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
}
if err := validateReverseProxyTarget(target); err != nil { if err := validateReverseProxyTarget(target); err != nil {
proxy.configError = err proxy.configError = err
@ -238,7 +242,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
transport := p.config.Transport transport := p.config.Transport
if transport == nil { if transport == nil {
transport = http.DefaultTransport if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil {
transport = p.extendedConnectTransport
} else {
transport = http.DefaultTransport
}
} }
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c) updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
@ -280,9 +288,17 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
} }
if outreq.Method == http.MethodConnect { if outreq.Method == http.MethodConnect {
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil { if reverseProxyIsExtendedConnectRequest(outreq) {
p.handleError(c, err) rewriteReverseProxyURL(outreq, p.target)
return if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
} else {
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
p.handleError(c, err)
return
}
} }
} else { } else {
rewriteReverseProxyURL(outreq, p.target) rewriteReverseProxyURL(outreq, p.target)
@ -367,7 +383,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
if !p.modifyResponse(c, res, outreq) { if !p.modifyResponse(c, res, outreq) {
return return
} }
if err := p.handleConnectResponse(c, outreq, res, connectWriter); err != nil { handleConnect := p.handleConnectResponse
if reverseProxyIsExtendedConnectRequest(outreq) {
handleConnect = p.handleExtendedConnectResponse
}
if err := handleConnect(c, outreq, res, connectWriter); err != nil {
p.handleError(c, err) p.handleError(c, err)
} }
connectWriter = nil connectWriter = nil
@ -778,6 +798,72 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques
return firstErr return firstErr
} }
func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
if c == nil || c.Request == nil {
res.Body.Close()
if backWrite != nil {
_ = backWrite.Close()
}
return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")}
}
if backWrite == nil {
res.Body.Close()
return &reverseProxyStatusError{
status: http.StatusBadGateway,
err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"),
}
}
controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer))
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
res.Body.Close()
_ = backWrite.Close()
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
}
reverseProxyCopyHeader(c.Writer.Header(), res.Header)
c.Writer.WriteHeader(res.StatusCode)
if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
res.Body.Close()
_ = backWrite.Close()
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
}
errc := make(chan error, 2)
go func() {
_, err := io.Copy(backWrite, c.Request.Body)
closeErr := backWrite.Close()
if err != nil && !reverseProxyIsBenignTunnelError(err) {
errc <- err
return
}
errc <- closeErr
}()
go func() {
copyErr := p.copyResponse(c.Writer, res.Body, -1)
closeErr := res.Body.Close()
if copyErr != nil {
errc <- copyErr
return
}
errc <- closeErr
}()
firstErr := <-errc
_ = c.Request.Body.Close()
_ = backWrite.Close()
_ = res.Body.Close()
secondErr := <-errc
for _, err := range []error{firstErr, secondErr} {
if reverseProxyIsBenignTunnelError(err) {
continue
}
return err
}
return nil
}
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" {
return -1 return -1
@ -968,6 +1054,17 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
return policy == ForwardedBoth || policy == ForwardedRFC7239Only return policy == ForwardedBoth || policy == ForwardedRFC7239Only
} }
func reverseProxyIsExtendedConnectRequest(req *http.Request) bool {
return reverseProxyExtendedConnectProtocol(req) != ""
}
func reverseProxyExtendedConnectProtocol(req *http.Request) string {
if req == nil || req.Method != http.MethodConnect || req.Header == nil {
return ""
}
return textproto.TrimString(req.Header.Get(":protocol"))
}
func isValidForwardedNodeIdentifier(value string) bool { func isValidForwardedNodeIdentifier(value string) bool {
if value == "" { if value == "" {
return false return false
@ -1273,6 +1370,10 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
return req != nil && req.Context().Value(http.ServerContextKey) != nil return req != nil && req.Context().Value(http.ServerContextKey) != nil
} }
func reverseProxyIsBenignTunnelError(err error) bool {
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler)
}
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
return UnwrapResponseWriter(writer) return UnwrapResponseWriter(writer)
} }

View file

@ -3,6 +3,7 @@ package touka
import ( import (
"bufio" "bufio"
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -15,6 +16,8 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"golang.org/x/net/http2"
) )
func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
@ -680,7 +683,7 @@ func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) {
} }
} }
func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) { func TestEngineHandlesOptionsAsteriskLocally(t *testing.T) {
t.Helper() t.Helper()
engine := New() engine := New()
@ -695,9 +698,12 @@ func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
engine.ServeHTTP(rr, req) engine.ServeHTTP(rr, req)
if rr.Code != http.StatusNotFound { if rr.Code != http.StatusOK {
t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code) t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code)
} }
if got := rr.Header().Get("Content-Length"); got != "0" {
t.Fatalf("unexpected Content-Length header: %q", got)
}
} }
func TestReverseProxyConnectTunnel(t *testing.T) { func TestReverseProxyConnectTunnel(t *testing.T) {
@ -848,6 +854,119 @@ func TestReverseProxyConnectNeedsHijacker(t *testing.T) {
} }
} }
func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
t.Helper()
enableHTTP2ExtendedConnectProtocol()
errCh := make(chan error, 4)
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if r.ProtoMajor != 2 {
errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto)
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.Header.Get(":protocol"); got != "websocket" {
errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got)
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.URL.Path; got != "/ws" {
errCh <- fmt.Errorf("unexpected upstream path: %q", got)
w.WriteHeader(http.StatusBadRequest)
return
}
controller := http.NewResponseController(w)
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
return
}
w.WriteHeader(http.StatusOK)
_ = controller.Flush()
line, err := bufio.NewReader(r.Body).ReadString('\n')
if err != nil {
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
return
}
if _, err := io.WriteString(w, "echo:"+line); err != nil {
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
return
}
_ = controller.Flush()
}))
upstream.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
t.Fatalf("configure upstream HTTP/2 server: %v", err)
}
upstream.StartTLS()
defer upstream.Close()
engine := New()
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, upstream.URL),
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
Via: "proxy.test",
}))
proxy := httptest.NewUnstartedServer(engine)
proxy.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
t.Fatalf("configure proxy HTTP/2 server: %v", err)
}
proxy.StartTLS()
defer proxy.Close()
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
defer transport.CloseIdleConnections()
pr, pw := io.Pipe()
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
if err != nil {
t.Fatalf("new CONNECT request: %v", err)
}
req.Header.Set(":protocol", "websocket")
resp, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("round trip extended CONNECT: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" {
t.Fatalf("unexpected Via response header: %#v", gotVia)
}
if _, err := io.WriteString(pw, "ping\n"); err != nil {
t.Fatalf("write tunneled request body: %v", err)
}
message, err := bufio.NewReader(resp.Body).ReadString('\n')
if err != nil {
t.Fatalf("read tunneled response body: %v", err)
}
if message != "echo:ping\n" {
t.Fatalf("unexpected tunneled response body: %q", message)
}
if err := pw.Close(); err != nil {
t.Fatalf("close tunneled request body: %v", err)
}
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}
func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) { func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) {
t.Helper() t.Helper()