mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
fix(reverseproxy): align forwarding and tunnel semantics
This commit is contained in:
parent
c019f24e99
commit
ed44c592d3
6 changed files with 864 additions and 26 deletions
368
reverseproxy.go
368
reverseproxy.go
|
|
@ -14,6 +14,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
|
|
@ -217,6 +218,12 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
|
|||
default:
|
||||
proxy.config.ForwardedHeaders = ForwardedBoth
|
||||
}
|
||||
proxy.config.ForwardedBy = strings.TrimSpace(proxy.config.ForwardedBy)
|
||||
if reverseProxyUsesForwardedHeader(proxy.config.ForwardedHeaders) {
|
||||
if err := validateReverseProxyForwardedBy(proxy.config.ForwardedBy); err != nil {
|
||||
proxy.configError = err
|
||||
}
|
||||
}
|
||||
|
||||
return proxy
|
||||
}
|
||||
|
|
@ -234,11 +241,20 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
transport = http.DefaultTransport
|
||||
}
|
||||
|
||||
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
|
||||
if err != nil {
|
||||
p.handleError(c, err)
|
||||
return
|
||||
}
|
||||
if handledLocally {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := p.requestContext(c)
|
||||
defer cancel()
|
||||
|
||||
outreq := c.Request.Clone(ctx)
|
||||
if c.Request.ContentLength == 0 {
|
||||
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
|
||||
outreq.Body = nil
|
||||
}
|
||||
if outreq.Body != nil {
|
||||
|
|
@ -249,12 +265,35 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
outreq.Header = make(http.Header)
|
||||
}
|
||||
outreq.Close = false
|
||||
|
||||
rewriteReverseProxyURL(outreq, p.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
var connectWriter *io.PipeWriter
|
||||
defer func() {
|
||||
if connectWriter != nil {
|
||||
_ = connectWriter.Close()
|
||||
}
|
||||
}()
|
||||
if outreq.Method == http.MethodConnect {
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
outreq.Body = pipeReader
|
||||
outreq.ContentLength = -1
|
||||
defer outreq.Body.Close()
|
||||
connectWriter = pipeWriter
|
||||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect {
|
||||
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
|
||||
p.handleError(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
rewriteReverseProxyURL(outreq, p.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
}
|
||||
if updatedMaxForwards != "" {
|
||||
outreq.Header.Set("Max-Forwards", updatedMaxForwards)
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
|
||||
reqUpType := reverseProxyUpgradeType(outreq.Header)
|
||||
if reqUpType != "" && !isPrintableASCII(reqUpType) {
|
||||
|
|
@ -318,6 +357,23 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
return
|
||||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
|
||||
removeHopByHopHeaders(res.Header)
|
||||
res.Header.Del("Content-Length")
|
||||
res.Header.Del("Transfer-Encoding")
|
||||
res.ContentLength = -1
|
||||
res.TransferEncoding = nil
|
||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||
if !p.modifyResponse(c, res, outreq) {
|
||||
return
|
||||
}
|
||||
if err := p.handleConnectResponse(c, outreq, res, connectWriter); err != nil {
|
||||
p.handleError(c, err)
|
||||
}
|
||||
connectWriter = nil
|
||||
return
|
||||
}
|
||||
|
||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||
if !p.modifyResponse(c, res, outreq) {
|
||||
|
|
@ -353,6 +409,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
defer res.Body.Close()
|
||||
c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err))
|
||||
p.logf(c, "reverse proxy body copy failed: %v", err)
|
||||
if reverseProxyShouldPanicOnCopyError(c.Request) {
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
return
|
||||
}
|
||||
res.Body.Close()
|
||||
|
|
@ -378,6 +437,86 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) {
|
||||
if c == nil || c.Request == nil {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
switch c.Request.Method {
|
||||
case http.MethodOptions, http.MethodTrace:
|
||||
default:
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
rawValue := textproto.TrimString(c.Request.Header.Get("Max-Forwards"))
|
||||
if rawValue == "" {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
value, err := strconv.Atoi(rawValue)
|
||||
if err != nil || value < 0 {
|
||||
return "", false, &reverseProxyStatusError{
|
||||
status: http.StatusBadRequest,
|
||||
err: fmt.Errorf("invalid Max-Forwards value %q", rawValue),
|
||||
}
|
||||
}
|
||||
if value == 0 {
|
||||
switch c.Request.Method {
|
||||
case http.MethodTrace:
|
||||
return "", true, p.writeLocalTraceResponse(c)
|
||||
case http.MethodOptions:
|
||||
p.writeLocalOptionsResponse(c)
|
||||
return "", true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return strconv.Itoa(value - 1), false, nil
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) writeLocalTraceResponse(c *Context) error {
|
||||
if c == nil || c.Request == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
traceReq := c.Request.Clone(c.Request.Context())
|
||||
traceReq.Body = nil
|
||||
traceReq.ContentLength = 0
|
||||
traceReq.TransferEncoding = nil
|
||||
traceReq.RequestURI = c.Request.RequestURI
|
||||
if traceReq.RequestURI == "" && traceReq.URL != nil {
|
||||
traceReq.RequestURI = traceReq.URL.RequestURI()
|
||||
}
|
||||
traceReq.Header = traceReq.Header.Clone()
|
||||
for _, key := range []string{"Authorization", "Proxy-Authorization", "Cookie", "Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto", "Content-Length", "Transfer-Encoding", "Trailer"} {
|
||||
traceReq.Header.Del(key)
|
||||
}
|
||||
|
||||
dump, err := httputil.DumpRequest(traceReq, false)
|
||||
if err != nil {
|
||||
return &reverseProxyStatusError{status: http.StatusInternalServerError, err: err}
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "message/http")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, err = c.Writer.Write(dump)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.engine != nil {
|
||||
if c.Request != nil && c.Request.RequestURI != "*" {
|
||||
if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 {
|
||||
c.Writer.Header().Set("Allow", strings.Join(allow, ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) {
|
||||
ctx := c.Request.Context()
|
||||
if ctx.Done() != nil {
|
||||
|
|
@ -522,7 +661,11 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques
|
|||
clientConn, brw, err := c.Writer.Hijack()
|
||||
if err != nil {
|
||||
backConn.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
status := http.StatusBadGateway
|
||||
if errors.Is(err, http.ErrNotSupported) {
|
||||
status = http.StatusNotImplemented
|
||||
}
|
||||
return &reverseProxyStatusError{status: status, err: err}
|
||||
}
|
||||
|
||||
defer clientConn.Close()
|
||||
|
|
@ -561,6 +704,80 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques
|
|||
return firstErr
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
|
||||
if backWrite == nil {
|
||||
res.Body.Close()
|
||||
return &reverseProxyStatusError{
|
||||
status: http.StatusBadGateway,
|
||||
err: errors.New("reverse proxy CONNECT tunnel is missing backend writer"),
|
||||
}
|
||||
}
|
||||
backRead := res.Body
|
||||
|
||||
clientConn, brw, err := c.Writer.Hijack()
|
||||
if err != nil {
|
||||
backRead.Close()
|
||||
_ = backWrite.Close()
|
||||
status := http.StatusBadGateway
|
||||
if errors.Is(err, http.ErrNotSupported) {
|
||||
status = http.StatusNotImplemented
|
||||
}
|
||||
return &reverseProxyStatusError{status: status, err: err}
|
||||
}
|
||||
|
||||
defer clientConn.Close()
|
||||
defer backRead.Close()
|
||||
defer backWrite.Close()
|
||||
|
||||
backConnClosed := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnClosed:
|
||||
}
|
||||
backRead.Close()
|
||||
_ = backWrite.Close()
|
||||
}()
|
||||
defer close(backConnClosed)
|
||||
|
||||
res.Body = nil
|
||||
if err := res.Write(brw); err != nil {
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
if err := brw.Flush(); err != nil {
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
errc := make(chan error, 2)
|
||||
go func() {
|
||||
if _, err := io.Copy(clientConn, backRead); err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
if cw, ok := clientConn.(interface{ CloseWrite() error }); ok {
|
||||
errc <- cw.CloseWrite()
|
||||
return
|
||||
}
|
||||
errc <- errReverseProxyCopyDone
|
||||
}()
|
||||
go func() {
|
||||
if _, err := io.Copy(backWrite, clientConn); err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
errc <- backWrite.Close()
|
||||
}()
|
||||
|
||||
firstErr := <-errc
|
||||
if firstErr == nil {
|
||||
firstErr = <-errc
|
||||
}
|
||||
if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
|
||||
if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" {
|
||||
return -1
|
||||
|
|
@ -638,6 +855,10 @@ func reverseProxyStatusCode(err error) int {
|
|||
if errors.As(err, &statusErr) && statusErr.status > 0 {
|
||||
return statusErr.status
|
||||
}
|
||||
var netErr net.Error
|
||||
if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) {
|
||||
return http.StatusGatewayTimeout
|
||||
}
|
||||
return http.StatusBadGateway
|
||||
}
|
||||
|
||||
|
|
@ -651,6 +872,17 @@ func validateReverseProxyTarget(target *url.URL) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func validateReverseProxyForwardedBy(value string) error {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
if !isValidForwardedNodeIdentifier(trimmed) {
|
||||
return fmt.Errorf("reverse proxy ForwardedBy must be an RFC 7239 node identifier, got %q", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeReverseProxyTarget(target *url.URL) {
|
||||
switch strings.ToLower(target.Scheme) {
|
||||
case "ws":
|
||||
|
|
@ -732,6 +964,83 @@ func buildForwardedHeaderValue(clientIP, by, host, scheme string) string {
|
|||
return strings.Join(pairs, ";")
|
||||
}
|
||||
|
||||
func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
|
||||
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
|
||||
}
|
||||
|
||||
func isValidForwardedNodeIdentifier(value string) bool {
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(value, "[") {
|
||||
closing := strings.IndexByte(value, ']')
|
||||
if closing <= 1 {
|
||||
return false
|
||||
}
|
||||
addr, err := netip.ParseAddr(value[1:closing])
|
||||
if err != nil || !addr.Is6() {
|
||||
return false
|
||||
}
|
||||
if closing == len(value)-1 {
|
||||
return true
|
||||
}
|
||||
if value[closing+1] != ':' {
|
||||
return false
|
||||
}
|
||||
return isValidForwardedNodePort(value[closing+2:])
|
||||
}
|
||||
|
||||
host, port, hasPort := strings.Cut(value, ":")
|
||||
if hasPort {
|
||||
switch {
|
||||
case host == "unknown", isValidForwardedObfuscatedIdentifier(host):
|
||||
return isValidForwardedNodePort(port)
|
||||
default:
|
||||
addr, err := netip.ParseAddr(host)
|
||||
return err == nil && addr.Is4() && isValidForwardedNodePort(port)
|
||||
}
|
||||
}
|
||||
|
||||
if value == "unknown" || isValidForwardedObfuscatedIdentifier(value) {
|
||||
return true
|
||||
}
|
||||
addr, err := netip.ParseAddr(value)
|
||||
return err == nil && addr.Is4()
|
||||
}
|
||||
|
||||
func isValidForwardedNodePort(value string) bool {
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
if isValidForwardedObfuscatedIdentifier(value) {
|
||||
return true
|
||||
}
|
||||
if len(value) > 5 {
|
||||
return false
|
||||
}
|
||||
port, err := strconv.Atoi(value)
|
||||
return err == nil && port > 0 && port <= 65535
|
||||
}
|
||||
|
||||
func isValidForwardedObfuscatedIdentifier(value string) bool {
|
||||
if len(value) < 2 || value[0] != '_' {
|
||||
return false
|
||||
}
|
||||
for i := 1; i < len(value); i++ {
|
||||
b := value[i]
|
||||
if (b >= '0' && b <= '9') || (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') {
|
||||
continue
|
||||
}
|
||||
switch b {
|
||||
case '.', '_', '-':
|
||||
continue
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatForwardedFor(clientIP string) string {
|
||||
addr, err := netip.ParseAddr(clientIP)
|
||||
if err != nil {
|
||||
|
|
@ -817,6 +1126,47 @@ func rewriteReverseProxyURL(req *http.Request, target *url.URL) {
|
|||
}
|
||||
}
|
||||
|
||||
func rewriteReverseProxyConnectRequest(req *http.Request, target *url.URL) error {
|
||||
connectTarget, err := reverseProxyConnectTarget(target)
|
||||
if err != nil {
|
||||
return &reverseProxyStatusError{status: http.StatusBadRequest, err: err}
|
||||
}
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
req.URL.Path = ""
|
||||
req.URL.RawPath = ""
|
||||
req.URL.RawQuery = ""
|
||||
req.URL.Opaque = connectTarget
|
||||
req.Host = connectTarget
|
||||
return nil
|
||||
}
|
||||
|
||||
func reverseProxyConnectTarget(target *url.URL) (string, error) {
|
||||
if target == nil {
|
||||
return "", errReverseProxyNilTarget
|
||||
}
|
||||
host := target.Hostname()
|
||||
if host == "" {
|
||||
return "", errReverseProxyInvalidTarget
|
||||
}
|
||||
port := target.Port()
|
||||
if port == "" {
|
||||
switch strings.ToLower(target.Scheme) {
|
||||
case "http":
|
||||
port = "80"
|
||||
case "https":
|
||||
port = "443"
|
||||
default:
|
||||
return "", fmt.Errorf("reverse proxy CONNECT target requires a supported scheme, got %q", target.Scheme)
|
||||
}
|
||||
}
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil || portNum <= 0 || portNum > 65535 {
|
||||
return "", fmt.Errorf("reverse proxy CONNECT target has invalid port %q", port)
|
||||
}
|
||||
return net.JoinHostPort(host, port), nil
|
||||
}
|
||||
|
||||
func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) {
|
||||
if base.RawPath == "" && incoming.RawPath == "" {
|
||||
return reverseProxySingleJoiningSlash(base.Path, incoming.Path), ""
|
||||
|
|
@ -919,6 +1269,10 @@ func cleanReverseProxyQueryParams(rawQuery string) string {
|
|||
return values.Encode()
|
||||
}
|
||||
|
||||
func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
|
||||
return req != nil && req.Context().Value(http.ServerContextKey) != nil
|
||||
}
|
||||
|
||||
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
|
||||
return UnwrapResponseWriter(writer)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue