mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
feat(http2): support OPTIONS * and extended CONNECT
This commit is contained in:
parent
ed44c592d3
commit
2165cc4114
8 changed files with 316 additions and 12 deletions
|
|
@ -292,6 +292,7 @@ Touka 会尽量遵循代理链语义:
|
|||
Touka 的反向代理实现支持以下能力:
|
||||
|
||||
- `CONNECT` 隧道转发(HTTP/1.x)
|
||||
- HTTP/2 extended `CONNECT`
|
||||
- `Connection: Upgrade` / `Upgrade` 协议升级转发
|
||||
- WebSocket 等 101 Switching Protocols 场景
|
||||
- SSE(Server-Sent Events)立即刷新
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ r.ANY("/any", handle)
|
|||
r.HandleFunc([]string{"GET", "POST"}, "/multi", handle)
|
||||
```
|
||||
|
||||
服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。
|
||||
|
||||
## 路径参数 (Named Parameters)
|
||||
|
||||
使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。
|
||||
|
|
|
|||
25
engine.go
25
engine.go
|
|
@ -7,6 +7,7 @@ package touka
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
|
@ -344,6 +345,11 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) {
|
|||
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
|
||||
if engine.serverProtocols != nil {
|
||||
srv.Protocols = engine.serverProtocols
|
||||
if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() {
|
||||
if err := configureHTTP2ExtendedConnectServer(srv); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -695,6 +701,11 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||
// handleRequest 负责根据请求查找路由并执行相应的处理函数链
|
||||
// 这是路由查找和执行的核心逻辑
|
||||
func (engine *Engine) handleRequest(c *Context) {
|
||||
if isGeneralOptionsRequest(c.Request) {
|
||||
engine.handleGeneralOptions(c)
|
||||
return
|
||||
}
|
||||
|
||||
httpMethod := c.Request.Method
|
||||
requestPath := routeLookupPath(c.Request)
|
||||
|
||||
|
|
@ -808,6 +819,20 @@ func (engine *Engine) allowedMethodsForPath(requestPath string) []string {
|
|||
return allowedMethods
|
||||
}
|
||||
|
||||
func (engine *Engine) handleGeneralOptions(c *Context) {
|
||||
if c == nil || c.Request == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Length", "0")
|
||||
if c.Request.ContentLength != 0 {
|
||||
mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10)
|
||||
_, _ = io.Copy(io.Discard, mb)
|
||||
}
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
|
||||
// 它可以用于在长连接 (如 SSE) 中监听关闭信号.
|
||||
func (engine *Engine) Context() context.Context {
|
||||
|
|
|
|||
3
go.mod
3
go.mod
|
|
@ -8,9 +8,10 @@ require (
|
|||
github.com/WJQSERVER/wanf v0.0.8
|
||||
github.com/fenthope/reco v0.0.5
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433
|
||||
golang.org/x/net v0.52.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
)
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -12,3 +12,5 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
|
|||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
|
|
|
|||
53
http2xconnect.go
Normal file
53
http2xconnect.go
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
// Copyright 2026 WJQSERVER. All rights reserved.
|
||||
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||
package touka
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
_ "unsafe"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var enableHTTP2ExtendedConnectOnce sync.Once
|
||||
|
||||
//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol
|
||||
var xnetDisableHTTP2ExtendedConnectProtocol bool
|
||||
|
||||
func enableHTTP2ExtendedConnectProtocol() {
|
||||
enableHTTP2ExtendedConnectOnce.Do(func() {
|
||||
xnetDisableHTTP2ExtendedConnectProtocol = false
|
||||
})
|
||||
}
|
||||
|
||||
func configureHTTP2ExtendedConnectServer(srv *http.Server) error {
|
||||
if srv == nil {
|
||||
return nil
|
||||
}
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
return http2.ConfigureServer(srv, nil)
|
||||
}
|
||||
|
||||
func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper {
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
transport := &http2.Transport{}
|
||||
if target == nil || !strings.EqualFold(target.Scheme, "http") {
|
||||
return transport
|
||||
}
|
||||
|
||||
transport.AllowHTTP = true
|
||||
transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
|
||||
var dialer net.Dialer
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
return transport
|
||||
}
|
||||
119
reverseproxy.go
119
reverseproxy.go
|
|
@ -67,10 +67,11 @@ var (
|
|||
)
|
||||
|
||||
type reverseProxyHandler struct {
|
||||
config ReverseProxyConfig
|
||||
target *url.URL
|
||||
receivedBy string
|
||||
configError error
|
||||
config ReverseProxyConfig
|
||||
target *url.URL
|
||||
receivedBy string
|
||||
configError error
|
||||
extendedConnectTransport http.RoundTripper
|
||||
}
|
||||
|
||||
type reverseProxyStatusError struct {
|
||||
|
|
@ -208,6 +209,9 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
|
|||
target: target,
|
||||
receivedBy: reverseProxyReceivedBy(config.Via),
|
||||
}
|
||||
if config.Transport == nil {
|
||||
proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
|
||||
}
|
||||
|
||||
if err := validateReverseProxyTarget(target); err != nil {
|
||||
proxy.configError = err
|
||||
|
|
@ -238,7 +242,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
|
||||
transport := p.config.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil {
|
||||
transport = p.extendedConnectTransport
|
||||
} else {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
}
|
||||
|
||||
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
|
||||
|
|
@ -280,9 +288,17 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect {
|
||||
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
|
||||
p.handleError(c, err)
|
||||
return
|
||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
rewriteReverseProxyURL(outreq, p.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
} else {
|
||||
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
|
||||
p.handleError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rewriteReverseProxyURL(outreq, p.target)
|
||||
|
|
@ -367,7 +383,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
if !p.modifyResponse(c, res, outreq) {
|
||||
return
|
||||
}
|
||||
if err := p.handleConnectResponse(c, outreq, res, connectWriter); err != nil {
|
||||
handleConnect := p.handleConnectResponse
|
||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
handleConnect = p.handleExtendedConnectResponse
|
||||
}
|
||||
if err := handleConnect(c, outreq, res, connectWriter); err != nil {
|
||||
p.handleError(c, err)
|
||||
}
|
||||
connectWriter = nil
|
||||
|
|
@ -778,6 +798,72 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques
|
|||
return firstErr
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
|
||||
if c == nil || c.Request == nil {
|
||||
res.Body.Close()
|
||||
if backWrite != nil {
|
||||
_ = backWrite.Close()
|
||||
}
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")}
|
||||
}
|
||||
if backWrite == nil {
|
||||
res.Body.Close()
|
||||
return &reverseProxyStatusError{
|
||||
status: http.StatusBadGateway,
|
||||
err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"),
|
||||
}
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer))
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
res.Body.Close()
|
||||
_ = backWrite.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
reverseProxyCopyHeader(c.Writer.Header(), res.Header)
|
||||
c.Writer.WriteHeader(res.StatusCode)
|
||||
if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
res.Body.Close()
|
||||
_ = backWrite.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
errc := make(chan error, 2)
|
||||
go func() {
|
||||
_, err := io.Copy(backWrite, c.Request.Body)
|
||||
closeErr := backWrite.Close()
|
||||
if err != nil && !reverseProxyIsBenignTunnelError(err) {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
errc <- closeErr
|
||||
}()
|
||||
go func() {
|
||||
copyErr := p.copyResponse(c.Writer, res.Body, -1)
|
||||
closeErr := res.Body.Close()
|
||||
if copyErr != nil {
|
||||
errc <- copyErr
|
||||
return
|
||||
}
|
||||
errc <- closeErr
|
||||
}()
|
||||
|
||||
firstErr := <-errc
|
||||
_ = c.Request.Body.Close()
|
||||
_ = backWrite.Close()
|
||||
_ = res.Body.Close()
|
||||
secondErr := <-errc
|
||||
|
||||
for _, err := range []error{firstErr, secondErr} {
|
||||
if reverseProxyIsBenignTunnelError(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
|
||||
if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" {
|
||||
return -1
|
||||
|
|
@ -968,6 +1054,17 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
|
|||
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
|
||||
}
|
||||
|
||||
func reverseProxyIsExtendedConnectRequest(req *http.Request) bool {
|
||||
return reverseProxyExtendedConnectProtocol(req) != ""
|
||||
}
|
||||
|
||||
func reverseProxyExtendedConnectProtocol(req *http.Request) string {
|
||||
if req == nil || req.Method != http.MethodConnect || req.Header == nil {
|
||||
return ""
|
||||
}
|
||||
return textproto.TrimString(req.Header.Get(":protocol"))
|
||||
}
|
||||
|
||||
func isValidForwardedNodeIdentifier(value string) bool {
|
||||
if value == "" {
|
||||
return false
|
||||
|
|
@ -1273,6 +1370,10 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
|
|||
return req != nil && req.Context().Value(http.ServerContextKey) != nil
|
||||
}
|
||||
|
||||
func reverseProxyIsBenignTunnelError(err error) bool {
|
||||
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler)
|
||||
}
|
||||
|
||||
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
|
||||
return UnwrapResponseWriter(writer)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package touka
|
|||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -15,6 +16,8 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
||||
|
|
@ -680,7 +683,7 @@ func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) {
|
||||
func TestEngineHandlesOptionsAsteriskLocally(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
engine := New()
|
||||
|
|
@ -695,9 +698,12 @@ func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) {
|
|||
rr := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNotFound {
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code)
|
||||
}
|
||||
if got := rr.Header().Get("Content-Length"); got != "0" {
|
||||
t.Fatalf("unexpected Content-Length header: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyConnectTunnel(t *testing.T) {
|
||||
|
|
@ -848,6 +854,119 @@ func TestReverseProxyConnectNeedsHijacker(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.ProtoMajor != 2 {
|
||||
errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get(":protocol"); got != "websocket" {
|
||||
errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.URL.Path; got != "/ws" {
|
||||
errCh <- fmt.Errorf("unexpected upstream path: %q", got)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(w)
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = controller.Flush()
|
||||
|
||||
line, err := bufio.NewReader(r.Body).ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, "echo:"+line); err != nil {
|
||||
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
|
||||
return
|
||||
}
|
||||
_ = controller.Flush()
|
||||
}))
|
||||
upstream.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
|
||||
t.Fatalf("configure upstream HTTP/2 server: %v", err)
|
||||
}
|
||||
upstream.StartTLS()
|
||||
defer upstream.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, upstream.URL),
|
||||
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
proxy := httptest.NewUnstartedServer(engine)
|
||||
proxy.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
||||
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
||||
}
|
||||
proxy.StartTLS()
|
||||
defer proxy.Close()
|
||||
|
||||
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
defer transport.CloseIdleConnections()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
||||
if err != nil {
|
||||
t.Fatalf("new CONNECT request: %v", err)
|
||||
}
|
||||
req.Header.Set(":protocol", "websocket")
|
||||
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("round trip extended CONNECT: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" {
|
||||
t.Fatalf("unexpected Via response header: %#v", gotVia)
|
||||
}
|
||||
|
||||
if _, err := io.WriteString(pw, "ping\n"); err != nil {
|
||||
t.Fatalf("write tunneled request body: %v", err)
|
||||
}
|
||||
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read tunneled response body: %v", err)
|
||||
}
|
||||
if message != "echo:ping\n" {
|
||||
t.Fatalf("unexpected tunneled response body: %q", message)
|
||||
}
|
||||
if err := pw.Close(); err != nil {
|
||||
t.Fatalf("close tunneled request body: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue