mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
feat(http2): support OPTIONS * and extended CONNECT
This commit is contained in:
parent
ed44c592d3
commit
2165cc4114
8 changed files with 316 additions and 12 deletions
119
reverseproxy.go
119
reverseproxy.go
|
|
@ -67,10 +67,11 @@ var (
|
|||
)
|
||||
|
||||
type reverseProxyHandler struct {
|
||||
config ReverseProxyConfig
|
||||
target *url.URL
|
||||
receivedBy string
|
||||
configError error
|
||||
config ReverseProxyConfig
|
||||
target *url.URL
|
||||
receivedBy string
|
||||
configError error
|
||||
extendedConnectTransport http.RoundTripper
|
||||
}
|
||||
|
||||
type reverseProxyStatusError struct {
|
||||
|
|
@ -208,6 +209,9 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
|
|||
target: target,
|
||||
receivedBy: reverseProxyReceivedBy(config.Via),
|
||||
}
|
||||
if config.Transport == nil {
|
||||
proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
|
||||
}
|
||||
|
||||
if err := validateReverseProxyTarget(target); err != nil {
|
||||
proxy.configError = err
|
||||
|
|
@ -238,7 +242,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
|
||||
transport := p.config.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil {
|
||||
transport = p.extendedConnectTransport
|
||||
} else {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
}
|
||||
|
||||
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
|
||||
|
|
@ -280,9 +288,17 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect {
|
||||
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
|
||||
p.handleError(c, err)
|
||||
return
|
||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
rewriteReverseProxyURL(outreq, p.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
} else {
|
||||
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
|
||||
p.handleError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rewriteReverseProxyURL(outreq, p.target)
|
||||
|
|
@ -367,7 +383,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
if !p.modifyResponse(c, res, outreq) {
|
||||
return
|
||||
}
|
||||
if err := p.handleConnectResponse(c, outreq, res, connectWriter); err != nil {
|
||||
handleConnect := p.handleConnectResponse
|
||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
handleConnect = p.handleExtendedConnectResponse
|
||||
}
|
||||
if err := handleConnect(c, outreq, res, connectWriter); err != nil {
|
||||
p.handleError(c, err)
|
||||
}
|
||||
connectWriter = nil
|
||||
|
|
@ -778,6 +798,72 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques
|
|||
return firstErr
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
|
||||
if c == nil || c.Request == nil {
|
||||
res.Body.Close()
|
||||
if backWrite != nil {
|
||||
_ = backWrite.Close()
|
||||
}
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")}
|
||||
}
|
||||
if backWrite == nil {
|
||||
res.Body.Close()
|
||||
return &reverseProxyStatusError{
|
||||
status: http.StatusBadGateway,
|
||||
err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"),
|
||||
}
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer))
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
res.Body.Close()
|
||||
_ = backWrite.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
reverseProxyCopyHeader(c.Writer.Header(), res.Header)
|
||||
c.Writer.WriteHeader(res.StatusCode)
|
||||
if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
res.Body.Close()
|
||||
_ = backWrite.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
errc := make(chan error, 2)
|
||||
go func() {
|
||||
_, err := io.Copy(backWrite, c.Request.Body)
|
||||
closeErr := backWrite.Close()
|
||||
if err != nil && !reverseProxyIsBenignTunnelError(err) {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
errc <- closeErr
|
||||
}()
|
||||
go func() {
|
||||
copyErr := p.copyResponse(c.Writer, res.Body, -1)
|
||||
closeErr := res.Body.Close()
|
||||
if copyErr != nil {
|
||||
errc <- copyErr
|
||||
return
|
||||
}
|
||||
errc <- closeErr
|
||||
}()
|
||||
|
||||
firstErr := <-errc
|
||||
_ = c.Request.Body.Close()
|
||||
_ = backWrite.Close()
|
||||
_ = res.Body.Close()
|
||||
secondErr := <-errc
|
||||
|
||||
for _, err := range []error{firstErr, secondErr} {
|
||||
if reverseProxyIsBenignTunnelError(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
|
||||
if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" {
|
||||
return -1
|
||||
|
|
@ -968,6 +1054,17 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
|
|||
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
|
||||
}
|
||||
|
||||
func reverseProxyIsExtendedConnectRequest(req *http.Request) bool {
|
||||
return reverseProxyExtendedConnectProtocol(req) != ""
|
||||
}
|
||||
|
||||
func reverseProxyExtendedConnectProtocol(req *http.Request) string {
|
||||
if req == nil || req.Method != http.MethodConnect || req.Header == nil {
|
||||
return ""
|
||||
}
|
||||
return textproto.TrimString(req.Header.Get(":protocol"))
|
||||
}
|
||||
|
||||
func isValidForwardedNodeIdentifier(value string) bool {
|
||||
if value == "" {
|
||||
return false
|
||||
|
|
@ -1273,6 +1370,10 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
|
|||
return req != nil && req.Context().Value(http.ServerContextKey) != nil
|
||||
}
|
||||
|
||||
func reverseProxyIsBenignTunnelError(err error) bool {
|
||||
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler)
|
||||
}
|
||||
|
||||
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
|
||||
return UnwrapResponseWriter(writer)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue