mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
feat(http2): support OPTIONS * and extended CONNECT
This commit is contained in:
parent
ed44c592d3
commit
2165cc4114
8 changed files with 316 additions and 12 deletions
|
|
@ -292,6 +292,7 @@ Touka 会尽量遵循代理链语义:
|
||||||
Touka 的反向代理实现支持以下能力:
|
Touka 的反向代理实现支持以下能力:
|
||||||
|
|
||||||
- `CONNECT` 隧道转发(HTTP/1.x)
|
- `CONNECT` 隧道转发(HTTP/1.x)
|
||||||
|
- HTTP/2 extended `CONNECT`
|
||||||
- `Connection: Upgrade` / `Upgrade` 协议升级转发
|
- `Connection: Upgrade` / `Upgrade` 协议升级转发
|
||||||
- WebSocket 等 101 Switching Protocols 场景
|
- WebSocket 等 101 Switching Protocols 场景
|
||||||
- SSE(Server-Sent Events)立即刷新
|
- SSE(Server-Sent Events)立即刷新
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,8 @@ r.ANY("/any", handle)
|
||||||
r.HandleFunc([]string{"GET", "POST"}, "/multi", handle)
|
r.HandleFunc([]string{"GET", "POST"}, "/multi", handle)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。
|
||||||
|
|
||||||
## 路径参数 (Named Parameters)
|
## 路径参数 (Named Parameters)
|
||||||
|
|
||||||
使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。
|
使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。
|
||||||
|
|
|
||||||
25
engine.go
25
engine.go
|
|
@ -7,6 +7,7 @@ package touka
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -344,6 +345,11 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) {
|
||||||
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
|
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
|
||||||
if engine.serverProtocols != nil {
|
if engine.serverProtocols != nil {
|
||||||
srv.Protocols = engine.serverProtocols
|
srv.Protocols = engine.serverProtocols
|
||||||
|
if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() {
|
||||||
|
if err := configureHTTP2ExtendedConnectServer(srv); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -695,6 +701,11 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
// handleRequest 负责根据请求查找路由并执行相应的处理函数链
|
// handleRequest 负责根据请求查找路由并执行相应的处理函数链
|
||||||
// 这是路由查找和执行的核心逻辑
|
// 这是路由查找和执行的核心逻辑
|
||||||
func (engine *Engine) handleRequest(c *Context) {
|
func (engine *Engine) handleRequest(c *Context) {
|
||||||
|
if isGeneralOptionsRequest(c.Request) {
|
||||||
|
engine.handleGeneralOptions(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
httpMethod := c.Request.Method
|
httpMethod := c.Request.Method
|
||||||
requestPath := routeLookupPath(c.Request)
|
requestPath := routeLookupPath(c.Request)
|
||||||
|
|
||||||
|
|
@ -808,6 +819,20 @@ func (engine *Engine) allowedMethodsForPath(requestPath string) []string {
|
||||||
return allowedMethods
|
return allowedMethods
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) handleGeneralOptions(c *Context) {
|
||||||
|
if c == nil || c.Request == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Length", "0")
|
||||||
|
if c.Request.ContentLength != 0 {
|
||||||
|
mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10)
|
||||||
|
_, _ = io.Copy(io.Discard, mb)
|
||||||
|
}
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
|
||||||
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
|
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
|
||||||
// 它可以用于在长连接 (如 SSE) 中监听关闭信号.
|
// 它可以用于在长连接 (如 SSE) 中监听关闭信号.
|
||||||
func (engine *Engine) Context() context.Context {
|
func (engine *Engine) Context() context.Context {
|
||||||
|
|
|
||||||
3
go.mod
3
go.mod
|
|
@ -8,9 +8,10 @@ require (
|
||||||
github.com/WJQSERVER/wanf v0.0.8
|
github.com/WJQSERVER/wanf v0.0.8
|
||||||
github.com/fenthope/reco v0.0.5
|
github.com/fenthope/reco v0.0.5
|
||||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433
|
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433
|
||||||
|
golang.org/x/net v0.52.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
golang.org/x/net v0.52.0 // indirect
|
golang.org/x/text v0.35.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
|
||||||
2
go.sum
2
go.sum
|
|
@ -12,3 +12,5 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
|
||||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||||
|
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||||
|
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||||
|
|
|
||||||
53
http2xconnect.go
Normal file
53
http2xconnect.go
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
// Copyright 2026 WJQSERVER. All rights reserved.
|
||||||
|
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
_ "unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var enableHTTP2ExtendedConnectOnce sync.Once
|
||||||
|
|
||||||
|
//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol
|
||||||
|
var xnetDisableHTTP2ExtendedConnectProtocol bool
|
||||||
|
|
||||||
|
func enableHTTP2ExtendedConnectProtocol() {
|
||||||
|
enableHTTP2ExtendedConnectOnce.Do(func() {
|
||||||
|
xnetDisableHTTP2ExtendedConnectProtocol = false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureHTTP2ExtendedConnectServer(srv *http.Server) error {
|
||||||
|
if srv == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
enableHTTP2ExtendedConnectProtocol()
|
||||||
|
return http2.ConfigureServer(srv, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper {
|
||||||
|
enableHTTP2ExtendedConnectProtocol()
|
||||||
|
|
||||||
|
transport := &http2.Transport{}
|
||||||
|
if target == nil || !strings.EqualFold(target.Scheme, "http") {
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
|
transport.AllowHTTP = true
|
||||||
|
transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
|
||||||
|
var dialer net.Dialer
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
return transport
|
||||||
|
}
|
||||||
103
reverseproxy.go
103
reverseproxy.go
|
|
@ -71,6 +71,7 @@ type reverseProxyHandler struct {
|
||||||
target *url.URL
|
target *url.URL
|
||||||
receivedBy string
|
receivedBy string
|
||||||
configError error
|
configError error
|
||||||
|
extendedConnectTransport http.RoundTripper
|
||||||
}
|
}
|
||||||
|
|
||||||
type reverseProxyStatusError struct {
|
type reverseProxyStatusError struct {
|
||||||
|
|
@ -208,6 +209,9 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
|
||||||
target: target,
|
target: target,
|
||||||
receivedBy: reverseProxyReceivedBy(config.Via),
|
receivedBy: reverseProxyReceivedBy(config.Via),
|
||||||
}
|
}
|
||||||
|
if config.Transport == nil {
|
||||||
|
proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
|
||||||
|
}
|
||||||
|
|
||||||
if err := validateReverseProxyTarget(target); err != nil {
|
if err := validateReverseProxyTarget(target); err != nil {
|
||||||
proxy.configError = err
|
proxy.configError = err
|
||||||
|
|
@ -238,8 +242,12 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
|
|
||||||
transport := p.config.Transport
|
transport := p.config.Transport
|
||||||
if transport == nil {
|
if transport == nil {
|
||||||
|
if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil {
|
||||||
|
transport = p.extendedConnectTransport
|
||||||
|
} else {
|
||||||
transport = http.DefaultTransport
|
transport = http.DefaultTransport
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
|
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -280,10 +288,18 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if outreq.Method == http.MethodConnect {
|
if outreq.Method == http.MethodConnect {
|
||||||
|
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||||
|
rewriteReverseProxyURL(outreq, p.target)
|
||||||
|
if !p.config.PreserveHost {
|
||||||
|
outreq.Host = ""
|
||||||
|
}
|
||||||
|
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||||
|
} else {
|
||||||
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
|
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
|
||||||
p.handleError(c, err)
|
p.handleError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
rewriteReverseProxyURL(outreq, p.target)
|
rewriteReverseProxyURL(outreq, p.target)
|
||||||
if !p.config.PreserveHost {
|
if !p.config.PreserveHost {
|
||||||
|
|
@ -367,7 +383,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
if !p.modifyResponse(c, res, outreq) {
|
if !p.modifyResponse(c, res, outreq) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := p.handleConnectResponse(c, outreq, res, connectWriter); err != nil {
|
handleConnect := p.handleConnectResponse
|
||||||
|
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||||
|
handleConnect = p.handleExtendedConnectResponse
|
||||||
|
}
|
||||||
|
if err := handleConnect(c, outreq, res, connectWriter); err != nil {
|
||||||
p.handleError(c, err)
|
p.handleError(c, err)
|
||||||
}
|
}
|
||||||
connectWriter = nil
|
connectWriter = nil
|
||||||
|
|
@ -778,6 +798,72 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques
|
||||||
return firstErr
|
return firstErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
|
||||||
|
if c == nil || c.Request == nil {
|
||||||
|
res.Body.Close()
|
||||||
|
if backWrite != nil {
|
||||||
|
_ = backWrite.Close()
|
||||||
|
}
|
||||||
|
return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")}
|
||||||
|
}
|
||||||
|
if backWrite == nil {
|
||||||
|
res.Body.Close()
|
||||||
|
return &reverseProxyStatusError{
|
||||||
|
status: http.StatusBadGateway,
|
||||||
|
err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer))
|
||||||
|
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||||
|
res.Body.Close()
|
||||||
|
_ = backWrite.Close()
|
||||||
|
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
reverseProxyCopyHeader(c.Writer.Header(), res.Header)
|
||||||
|
c.Writer.WriteHeader(res.StatusCode)
|
||||||
|
if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||||
|
res.Body.Close()
|
||||||
|
_ = backWrite.Close()
|
||||||
|
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
errc := make(chan error, 2)
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(backWrite, c.Request.Body)
|
||||||
|
closeErr := backWrite.Close()
|
||||||
|
if err != nil && !reverseProxyIsBenignTunnelError(err) {
|
||||||
|
errc <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errc <- closeErr
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
copyErr := p.copyResponse(c.Writer, res.Body, -1)
|
||||||
|
closeErr := res.Body.Close()
|
||||||
|
if copyErr != nil {
|
||||||
|
errc <- copyErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errc <- closeErr
|
||||||
|
}()
|
||||||
|
|
||||||
|
firstErr := <-errc
|
||||||
|
_ = c.Request.Body.Close()
|
||||||
|
_ = backWrite.Close()
|
||||||
|
_ = res.Body.Close()
|
||||||
|
secondErr := <-errc
|
||||||
|
|
||||||
|
for _, err := range []error{firstErr, secondErr} {
|
||||||
|
if reverseProxyIsBenignTunnelError(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
|
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
|
||||||
if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" {
|
if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" {
|
||||||
return -1
|
return -1
|
||||||
|
|
@ -968,6 +1054,17 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
|
||||||
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
|
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func reverseProxyIsExtendedConnectRequest(req *http.Request) bool {
|
||||||
|
return reverseProxyExtendedConnectProtocol(req) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyExtendedConnectProtocol(req *http.Request) string {
|
||||||
|
if req == nil || req.Method != http.MethodConnect || req.Header == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return textproto.TrimString(req.Header.Get(":protocol"))
|
||||||
|
}
|
||||||
|
|
||||||
func isValidForwardedNodeIdentifier(value string) bool {
|
func isValidForwardedNodeIdentifier(value string) bool {
|
||||||
if value == "" {
|
if value == "" {
|
||||||
return false
|
return false
|
||||||
|
|
@ -1273,6 +1370,10 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
|
||||||
return req != nil && req.Context().Value(http.ServerContextKey) != nil
|
return req != nil && req.Context().Value(http.ServerContextKey) != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func reverseProxyIsBenignTunnelError(err error) bool {
|
||||||
|
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler)
|
||||||
|
}
|
||||||
|
|
||||||
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
|
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
|
||||||
return UnwrapResponseWriter(writer)
|
return UnwrapResponseWriter(writer)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package touka
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
@ -15,6 +16,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
||||||
|
|
@ -680,7 +683,7 @@ func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) {
|
func TestEngineHandlesOptionsAsteriskLocally(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
engine := New()
|
engine := New()
|
||||||
|
|
@ -695,9 +698,12 @@ func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) {
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
engine.ServeHTTP(rr, req)
|
engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
if rr.Code != http.StatusNotFound {
|
if rr.Code != http.StatusOK {
|
||||||
t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code)
|
t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code)
|
||||||
}
|
}
|
||||||
|
if got := rr.Header().Get("Content-Length"); got != "0" {
|
||||||
|
t.Fatalf("unexpected Content-Length header: %q", got)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReverseProxyConnectTunnel(t *testing.T) {
|
func TestReverseProxyConnectTunnel(t *testing.T) {
|
||||||
|
|
@ -848,6 +854,119 @@ func TestReverseProxyConnectNeedsHijacker(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
enableHTTP2ExtendedConnectProtocol()
|
||||||
|
|
||||||
|
errCh := make(chan error, 4)
|
||||||
|
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodConnect {
|
||||||
|
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.ProtoMajor != 2 {
|
||||||
|
errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto)
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := r.Header.Get(":protocol"); got != "websocket" {
|
||||||
|
errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got)
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := r.URL.Path; got != "/ws" {
|
||||||
|
errCh <- fmt.Errorf("unexpected upstream path: %q", got)
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
controller := http.NewResponseController(w)
|
||||||
|
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||||
|
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = controller.Flush()
|
||||||
|
|
||||||
|
line, err := bufio.NewReader(r.Body).ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := io.WriteString(w, "echo:"+line); err != nil {
|
||||||
|
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = controller.Flush()
|
||||||
|
}))
|
||||||
|
upstream.EnableHTTP2 = true
|
||||||
|
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
|
||||||
|
t.Fatalf("configure upstream HTTP/2 server: %v", err)
|
||||||
|
}
|
||||||
|
upstream.StartTLS()
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: mustParseURL(t, upstream.URL),
|
||||||
|
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||||
|
Via: "proxy.test",
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewUnstartedServer(engine)
|
||||||
|
proxy.EnableHTTP2 = true
|
||||||
|
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
||||||
|
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
||||||
|
}
|
||||||
|
proxy.StartTLS()
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||||
|
defer transport.CloseIdleConnections()
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new CONNECT request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set(":protocol", "websocket")
|
||||||
|
|
||||||
|
resp, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("round trip extended CONNECT: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" {
|
||||||
|
t.Fatalf("unexpected Via response header: %#v", gotVia)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.WriteString(pw, "ping\n"); err != nil {
|
||||||
|
t.Fatalf("write tunneled request body: %v", err)
|
||||||
|
}
|
||||||
|
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read tunneled response body: %v", err)
|
||||||
|
}
|
||||||
|
if message != "echo:ping\n" {
|
||||||
|
t.Fatalf("unexpected tunneled response body: %q", message)
|
||||||
|
}
|
||||||
|
if err := pw.Close(); err != nil {
|
||||||
|
t.Fatalf("close tunneled request body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatal(err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) {
|
func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue