// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2026 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.

package touka

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"net/http/httptrace"
	"net/http/httputil"
	"net/netip"
	"net/textproto"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/net/http2"
)

// ForwardedHeadersPolicy controls how forwarding headers are generated.
// The zero value uses both X-Forwarded-* and RFC 7239 Forwarded headers.
type ForwardedHeadersPolicy int

const (
	// ForwardedBoth emits the X-Forwarded-* headers and the RFC 7239
	// Forwarded header. It is the zero value and also the fallback used when
	// a configuration carries an unknown policy value.
	ForwardedBoth ForwardedHeadersPolicy = iota
	// ForwardedNone disables generation of forwarding headers.
	ForwardedNone
	// ForwardedXForwardedOnly emits only X-Forwarded-For, X-Forwarded-Host
	// and X-Forwarded-Proto.
	ForwardedXForwardedOnly
	// ForwardedRFC7239Only emits only the RFC 7239 Forwarded header.
	ForwardedRFC7239Only
)

// BufferPool provides temporary buffers for response body copying.
type BufferPool interface {
	// Get returns a buffer to copy with. When the returned buffer is empty the
	// proxy allocates its own 32 KiB buffer for that copy instead.
	Get() []byte
	// Put hands back a buffer obtained from Get once the copy has finished.
	Put([]byte)
}

// ReverseProxyConfig configures the reverse proxy handler.
type ReverseProxyConfig struct { Target *url.URL Targets []string LoadBalancing ReverseProxyLoadBalancingConfig PassiveHealth ReverseProxyPassiveHealthConfig Transport http.RoundTripper FlushInterval time.Duration BufferPool BufferPool AllowH2CUpstream bool ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error ErrorHandler func(http.ResponseWriter, *http.Request, error) ForwardedHeaders ForwardedHeadersPolicy ForwardedBy string Via string PreserveHost bool } var ( errReverseProxyNilTarget = errors.New("reverse proxy target is nil") errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams") ) type reverseProxyHandler struct { config ReverseProxyConfig upstreams []*reverseProxyUpstream receivedBy string configError error roundRobin atomic.Uint64 } type reverseProxyStatusError struct { status int err error } type reverseProxyExtendedConnectBridge struct { body io.ReadCloser } type reverseProxyH2ReadWriteCloser struct { io.ReadCloser ResponseWriter controller *http.ResponseController } func (rwc *reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { n, err := rwc.ResponseWriter.Write(p) if err != nil { return n, err } if err := rwc.controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { return n, err } return n, nil } func (rwc *reverseProxyH2ReadWriteCloser) Close() error { if rwc.ReadCloser == nil { return nil } return rwc.ReadCloser.Close() } func (e *reverseProxyStatusError) Error() string { if e == nil || e.err == nil { return "" } return e.err.Error() } func (e *reverseProxyStatusError) Unwrap() error { if e == nil { return nil } return e.err } type noopCloseReader struct { readCloser io.ReadCloser closed atomic.Bool } func (n *noopCloseReader) Read(p []byte) (int, error) { if n.closed.Load() { return 0, 
errors.New("reverse proxy read on closed body") } return n.readCloser.Read(p) } func (n *noopCloseReader) Close() error { n.closed.Store(true) return nil } type maxLatencyWriter struct { dst ResponseWriter latency time.Duration mu sync.Mutex t *time.Timer flushPending bool } func (m *maxLatencyWriter) Write(p []byte) (int, error) { m.mu.Lock() defer m.mu.Unlock() n, err := m.dst.Write(p) if m.latency < 0 { m.dst.Flush() return n, err } if m.flushPending { return n, err } if m.t == nil { m.t = time.AfterFunc(m.latency, m.delayedFlush) } else { m.t.Reset(m.latency) } m.flushPending = true return n, err } func (m *maxLatencyWriter) delayedFlush() { m.mu.Lock() defer m.mu.Unlock() if !m.flushPending { return } m.dst.Flush() m.flushPending = false } func (m *maxLatencyWriter) stop() { m.mu.Lock() defer m.mu.Unlock() m.flushPending = false if m.t != nil { m.t.Stop() } } type switchProtocolCopier struct { user io.ReadWriter backend io.ReadWriter } func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { if _, err := io.Copy(c.user, c.backend); err != nil { errc <- err return } if cw, ok := c.user.(interface{ CloseWrite() error }); ok { errc <- cw.CloseWrite() return } errc <- errReverseProxyCopyDone } func (c switchProtocolCopier) copyToBackend(errc chan<- error) { if _, err := io.Copy(c.backend, c.user); err != nil { errc <- err return } if cw, ok := c.backend.(interface{ CloseWrite() error }); ok { errc <- cw.CloseWrite() return } errc <- errReverseProxyCopyDone } // ReverseProxy returns a handler that proxies requests to the configured backend. 
func ReverseProxy(config ReverseProxyConfig) HandlerFunc { proxy := newReverseProxyHandler(config) return func(c *Context) { proxy.ServeHTTP(c) } } func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { proxy := &reverseProxyHandler{ config: config, receivedBy: reverseProxyReceivedBy(config.Via), } upstreams, err := buildReverseProxyUpstreams(config) if err != nil { proxy.configError = err } else { proxy.upstreams = upstreams } switch config.ForwardedHeaders { case ForwardedBoth, ForwardedNone, ForwardedXForwardedOnly, ForwardedRFC7239Only: default: proxy.config.ForwardedHeaders = ForwardedBoth } proxy.config.ForwardedBy = strings.TrimSpace(proxy.config.ForwardedBy) if reverseProxyUsesForwardedHeader(proxy.config.ForwardedHeaders) { if err := validateReverseProxyForwardedBy(proxy.config.ForwardedBy); err != nil { proxy.configError = err } } if proxy.configError == nil { if err := validateReverseProxyLBPolicy(proxy.config.LoadBalancing.Policy); err != nil { proxy.configError = err } } return proxy } func (p *reverseProxyHandler) ServeHTTP(c *Context) { defer c.Abort() if p.configError != nil { p.handleError(c, &reverseProxyStatusError{status: http.StatusInternalServerError, err: p.configError}) return } updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c) if err != nil { p.handleError(c, err) return } if handledLocally { return } ctx, cancel := p.requestContext(c) defer cancel() attempted := make(map[string]struct{}, len(p.upstreams)) attempts := 0 started := time.Now() var lastErr error for { upstream, err := p.selectUpstream(c, attempted) if err != nil { if lastErr != nil { p.handleError(c, lastErr) return } p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: err}) return } attempts++ upstream.inFlight.Add(1) served, attemptErr, retriable := p.serveUpstreamAttempt(c, ctx, upstream, updatedMaxForwards) upstream.inFlight.Add(-1) if served { return } if attemptErr != nil { lastErr = attemptErr } if 
retriable && p.shouldRetryAttempt(c.Request, attempts, started) { attempted[upstream.key] = struct{}{} if !p.waitRetryInterval(ctx, started) { if lastErr != nil { p.handleError(c, lastErr) } return } continue } if attemptErr != nil { p.handleError(c, attemptErr) return } if lastErr != nil { p.handleError(c, lastErr) return } p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: errReverseProxyNoAvailableUpstreams}) return } } func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (bool, error, bool) { outreq, connectWriter, cleanup, err := p.buildOutgoingRequest(c, ctx, upstream, updatedMaxForwards) if err != nil { return false, err, false } defer cleanup() transport := p.transportForUpstream(outreq, upstream) rawWriter := reverseProxyBaseResponseWriter(c.Writer) var ( roundTripMu sync.Mutex roundTripDone bool ) trace := &httptrace.ClientTrace{ Got1xxResponse: func(code int, header textproto.MIMEHeader) error { roundTripMu.Lock() defer roundTripMu.Unlock() if roundTripDone { return nil } h := c.Writer.Header() saved := h.Clone() clear(h) reverseProxyCopyHeader(h, http.Header(header)) rawWriter.WriteHeader(code) clear(h) reverseProxyCopyHeader(h, saved) return nil }, } outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) res, err := transport.RoundTrip(outreq) roundTripMu.Lock() roundTripDone = true roundTripMu.Unlock() if err != nil { if reverseProxyShouldCountPassiveFailure(outreq, err) { upstream.recordFailure(time.Now(), p.config.PassiveHealth) } return false, err, true } if reverseProxyStatusIsUnhealthy(p.config.PassiveHealth, res.StatusCode) { upstream.recordFailure(time.Now(), p.config.PassiveHealth) } if bridge := reverseProxyExtendedConnectBridgeFromContext(outreq.Context()); bridge != nil { if res.StatusCode == http.StatusSwitchingProtocols { appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, 
res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { return true, nil, false } if err := p.handleBridgedExtendedConnectResponse(c, outreq, res, bridge); err != nil { return false, err, false } return true, nil, false } return false, &reverseProxyStatusError{status: http.StatusBadGateway, err: fmt.Errorf("extended CONNECT backend returned status %d instead of 101", res.StatusCode)}, false } if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices { removeHopByHopHeaders(res.Header) res.Header.Del("Content-Length") res.Header.Del("Transfer-Encoding") res.ContentLength = -1 res.TransferEncoding = nil appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { return true, nil, false } handleConnect := p.handleConnectResponse if reverseProxyIsExtendedConnectRequest(outreq) { handleConnect = p.handleExtendedConnectResponse } if err := handleConnect(c, outreq, res, connectWriter); err != nil { return false, err, false } return true, nil, false } if res.StatusCode == http.StatusSwitchingProtocols { appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { return true, nil, false } if err := p.handleUpgradeResponse(c, outreq, res); err != nil { return false, err, false } return true, nil, false } removeHopByHopHeaders(res.Header) appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { return true, nil, false } reverseProxyCopyHeader(c.Writer.Header(), res.Header) announcedTrailers := len(res.Trailer) if announcedTrailers > 0 { trailerKeys := make([]string, 0, len(res.Trailer)) for key := range res.Trailer { trailerKeys = append(trailerKeys, key) } c.Writer.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) } 
c.Writer.WriteHeader(res.StatusCode) if err := p.copyResponse(c.Writer, res.Body, p.flushInterval(res)); err != nil { defer res.Body.Close() c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err)) p.logf(c, "reverse proxy body copy failed: %v", err) if reverseProxyShouldPanicOnCopyError(c.Request) { panic(http.ErrAbortHandler) } return true, nil, false } res.Body.Close() if len(res.Trailer) > 0 { c.Writer.Flush() } if len(res.Trailer) == announcedTrailers { reverseProxyCopyHeader(c.Writer.Header(), res.Trailer) return true, nil, false } for key, values := range res.Trailer { prefixedKey := http.TrailerPrefix + key for _, value := range values { c.Writer.Header().Add(prefixedKey, value) } } return true, nil, false } func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) { outreq := c.Request.Clone(ctx) bridgeCtx, bridged, err := reverseProxyPrepareExtendedConnectBridge(outreq) if err != nil { return nil, nil, nil, err } if bridged { outreq = outreq.WithContext(bridgeCtx) } if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { outreq.Body = nil } else if c.Request.GetBody != nil { body, err := c.Request.GetBody() if err != nil { return nil, nil, nil, fmt.Errorf("reverse proxy failed to replay request body: %w", err) } outreq.Body = body } else if outreq.Body != nil { outreq.Body = &noopCloseReader{readCloser: outreq.Body} } if outreq.Header == nil { outreq.Header = make(http.Header) } outreq.Close = false var connectWriter *io.PipeWriter if outreq.Method == http.MethodConnect && !bridged { pipeReader, pipeWriter := io.Pipe() outreq.Body = pipeReader outreq.ContentLength = -1 connectWriter = pipeWriter } cleanup := func() { if outreq.Body != nil { _ = outreq.Body.Close() } if connectWriter != nil { _ = connectWriter.Close() } } if outreq.Method == http.MethodConnect && 
!reverseProxyIsExtendedConnectRequest(outreq) { if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil { cleanup() return nil, nil, nil, err } } else { rewriteReverseProxyURL(outreq, upstream.target) if !p.config.PreserveHost { outreq.Host = "" } outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) } if updatedMaxForwards != "" { outreq.Header.Set("Max-Forwards", updatedMaxForwards) } reqUpType := reverseProxyUpgradeType(outreq.Header) if reqUpType != "" && !isPrintableASCII(reqUpType) { cleanup() return nil, nil, nil, &reverseProxyStatusError{ status: http.StatusBadRequest, err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), } } removeHopByHopHeaders(outreq.Header) if headerValuesContainToken(c.Request.Header["Te"], "trailers") { outreq.Header.Set("Te", "trailers") } if reqUpType != "" { outreq.Header.Set("Connection", "Upgrade") outreq.Header.Set("Upgrade", reqUpType) } p.addForwardingHeaders(c.Request, outreq) appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy) if _, ok := outreq.Header["User-Agent"]; !ok { outreq.Header.Set("User-Agent", "") } if p.config.ModifyRequest != nil { p.config.ModifyRequest(outreq) } return outreq, connectWriter, cleanup, nil } func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *reverseProxyUpstream) http.RoundTripper { if p.config.Transport != nil { return p.config.Transport } if reverseProxyExtendedConnectBridgeFromContext(req.Context()) != nil { if upstream.bridgeTransport != nil { return upstream.bridgeTransport } return http.DefaultTransport } if upstream.useH2C && upstream.h2cTransport != nil { return upstream.h2cTransport } if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil { return upstream.extendedConnectTransport } return http.DefaultTransport } func (p *reverseProxyHandler) shouldRetryAttempt(req *http.Request, 
attempts int, started time.Time) bool { if req == nil || req.Context().Err() != nil || !reverseProxyCanRetryRequest(req) { return false } lb := p.config.LoadBalancing if lb.TryDuration > 0 { return time.Since(started) < lb.TryDuration } return attempts <= lb.Retries } func (p *reverseProxyHandler) waitRetryInterval(ctx context.Context, started time.Time) bool { interval := p.config.LoadBalancing.TryInterval tryDuration := p.config.LoadBalancing.TryDuration if tryDuration > 0 && interval == 0 { interval = 250 * time.Millisecond } if tryDuration > 0 { remaining := tryDuration - time.Since(started) if remaining <= 0 { return false } if interval <= 0 { return ctx.Err() == nil } if interval > remaining { return false } } if interval <= 0 { return ctx.Err() == nil } timer := time.NewTimer(interval) defer timer.Stop() select { case <-ctx.Done(): return false case <-timer.C: return true } } func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) { if c == nil || c.Request == nil { return "", false, nil } switch c.Request.Method { case http.MethodOptions, http.MethodTrace: default: return "", false, nil } rawValue := textproto.TrimString(c.Request.Header.Get("Max-Forwards")) if rawValue == "" { return "", false, nil } value, err := strconv.Atoi(rawValue) if err != nil || value < 0 { return "", false, &reverseProxyStatusError{ status: http.StatusBadRequest, err: fmt.Errorf("invalid Max-Forwards value %q", rawValue), } } if value == 0 { switch c.Request.Method { case http.MethodTrace: return "", true, p.writeLocalTraceResponse(c) case http.MethodOptions: p.writeLocalOptionsResponse(c) return "", true, nil } } return strconv.Itoa(value - 1), false, nil } func (p *reverseProxyHandler) writeLocalTraceResponse(c *Context) error { if c == nil || c.Request == nil { return nil } traceReq := c.Request.Clone(c.Request.Context()) traceReq.Body = nil traceReq.ContentLength = 0 traceReq.TransferEncoding = nil traceReq.RequestURI = c.Request.RequestURI if 
traceReq.RequestURI == "" && traceReq.URL != nil { traceReq.RequestURI = traceReq.URL.RequestURI() } traceReq.Header = traceReq.Header.Clone() for _, key := range []string{"Authorization", "Proxy-Authorization", "Cookie", "Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto", "Content-Length", "Transfer-Encoding", "Trailer"} { traceReq.Header.Del(key) } dump, err := httputil.DumpRequest(traceReq, false) if err != nil { return &reverseProxyStatusError{status: http.StatusInternalServerError, err: err} } c.Writer.Header().Set("Content-Type", "message/http") c.Writer.WriteHeader(http.StatusOK) _, err = c.Writer.Write(dump) return err } func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) { if c == nil { return } if c.engine != nil { if c.Request != nil && c.Request.RequestURI != "*" { if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request), c.allowedMethodsBuf[:0]); len(allow) > 0 { c.allowedMethodsBuf = allow[:0] allowHeader := c.allowHeaderBuf[:0] for i, method := range allow { if i > 0 { allowHeader = append(allowHeader, ',', ' ') } allowHeader = append(allowHeader, method...) } c.allowHeaderBuf = allowHeader[:0] c.Writer.Header().Set("Allow", BytesToString(allowHeader)) } } } c.Writer.WriteHeader(http.StatusOK) } func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) { ctx := c.Request.Context() if ctx.Done() != nil { return ctx, func() {} } // Follow the same compatibility path as net/http/httputil.ReverseProxy: // request contexts are normally cancelable, but middleware can still replace // c.Request with one backed by context.Background/TODO or another context with // a nil Done channel. In that case CloseNotifier still provides disconnect // propagation for the upstream round trip. 
rawWriter := reverseProxyBaseResponseWriter(c.Writer) cn, ok := rawWriter.(http.CloseNotifier) if !ok { return ctx, func() {} } ctx, cancel := context.WithCancel(ctx) notifyChan := cn.CloseNotify() go func() { select { case <-notifyChan: cancel() case <-ctx.Done(): } }() return ctx, cancel } func (p *reverseProxyHandler) addForwardingHeaders(in *http.Request, out *http.Request) { if p.config.ForwardedHeaders == ForwardedNone { return } clientIP := reverseProxyClientIP(in.RemoteAddr) scheme := reverseProxyRequestScheme(in) host := in.Host if p.config.ForwardedHeaders == ForwardedBoth || p.config.ForwardedHeaders == ForwardedXForwardedOnly { if clientIP != "" { appendXForwardedFor(out.Header, clientIP) } if host != "" { if len(out.Header.Values("X-Forwarded-Host")) == 0 { out.Header.Set("X-Forwarded-Host", host) } } if scheme != "" { if len(out.Header.Values("X-Forwarded-Proto")) == 0 { out.Header.Set("X-Forwarded-Proto", scheme) } } } if p.config.ForwardedHeaders == ForwardedBoth || p.config.ForwardedHeaders == ForwardedRFC7239Only { if forwardedValue := buildForwardedHeaderValue(clientIP, p.config.ForwardedBy, host, scheme); forwardedValue != "" { if prior := out.Header.Values("Forwarded"); len(prior) > 0 { forwardedValue = strings.Join(prior, ", ") + ", " + forwardedValue out.Header.Del("Forwarded") } out.Header.Add("Forwarded", forwardedValue) } } } func appendXForwardedFor(header http.Header, clientIP string) { if clientIP == "" { return } prior := header.Values("X-Forwarded-For") if len(prior) == 0 { header.Set("X-Forwarded-For", clientIP) return } header.Set("X-Forwarded-For", strings.Join(prior, ", ")+", "+clientIP) } func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool { if p.config.ModifyResponse == nil { return true } if err := p.config.ModifyResponse(res); err != nil { res.Body.Close() p.handleError(c, err) return false } return true } func (p *reverseProxyHandler) handleError(c *Context, err error) { if err 
== nil { return } c.AddError(err) if c.Writer.IsHijacked() { p.logf(c, "reverse proxy error after hijack: %v", err) return } if p.config.ErrorHandler != nil { p.config.ErrorHandler(c.Writer, c.Request, err) if c.Writer.Written() || c.Writer.IsHijacked() { return } } c.ErrorUseHandle(reverseProxyStatusCode(err), err) } func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Request, res *http.Response) error { reqUpType := reverseProxyUpgradeType(req.Header) resUpType := reverseProxyUpgradeType(res.Header) if reqUpType == "" || resUpType == "" { res.Body.Close() return &reverseProxyStatusError{ status: http.StatusBadGateway, err: fmt.Errorf("invalid upgrade negotiation: request protocol=%q, response protocol=%q", reqUpType, resUpType), } } if !isPrintableASCII(resUpType) { res.Body.Close() return &reverseProxyStatusError{ status: http.StatusBadGateway, err: fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), } } if !strings.EqualFold(reqUpType, resUpType) { res.Body.Close() return &reverseProxyStatusError{ status: http.StatusBadGateway, err: fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), } } backConn, ok := res.Body.(io.ReadWriteCloser) if !ok { res.Body.Close() return &reverseProxyStatusError{ status: http.StatusBadGateway, err: errors.New("backend returned 101 response without writable body"), } } clientConn, brw, err := c.Writer.Hijack() if err != nil { backConn.Close() status := http.StatusBadGateway if errors.Is(err, http.ErrNotSupported) { status = http.StatusNotImplemented } return &reverseProxyStatusError{status: status, err: err} } defer clientConn.Close() defer backConn.Close() backConnClosed := make(chan struct{}) go func() { select { case <-req.Context().Done(): case <-backConnClosed: } backConn.Close() }() defer close(backConnClosed) res.Body = nil if err := res.Write(brw); err != nil { return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } if err 
:= brw.Flush(); err != nil { return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } errc := make(chan error, 2) copyer := switchProtocolCopier{user: clientConn, backend: backConn} go copyer.copyToBackend(errc) go copyer.copyFromBackend(errc) firstErr := <-errc if firstErr == nil { firstErr = <-errc } if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) { return nil } return firstErr } func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { if backWrite == nil { res.Body.Close() return &reverseProxyStatusError{ status: http.StatusBadGateway, err: errors.New("reverse proxy CONNECT tunnel is missing backend writer"), } } backRead := res.Body clientConn, brw, err := c.Writer.Hijack() if err != nil { backRead.Close() _ = backWrite.Close() status := http.StatusBadGateway if errors.Is(err, http.ErrNotSupported) { status = http.StatusNotImplemented } return &reverseProxyStatusError{status: status, err: err} } defer clientConn.Close() defer backRead.Close() defer backWrite.Close() backConnClosed := make(chan struct{}) go func() { select { case <-req.Context().Done(): case <-backConnClosed: } backRead.Close() _ = backWrite.Close() }() defer close(backConnClosed) res.Body = nil if err := res.Write(brw); err != nil { return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } if err := brw.Flush(); err != nil { return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } errc := make(chan error, 2) go func() { if _, err := io.Copy(clientConn, backRead); err != nil { errc <- err return } if cw, ok := clientConn.(interface{ CloseWrite() error }); ok { errc <- cw.CloseWrite() return } errc <- errReverseProxyCopyDone }() go func() { if _, err := io.Copy(backWrite, clientConn); err != nil { errc <- err return } errc <- backWrite.Close() }() firstErr 
:= <-errc if firstErr == nil { firstErr = <-errc } if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) { return nil } return firstErr } func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, bridge *reverseProxyExtendedConnectBridge) error { if c == nil || c.Request == nil { res.Body.Close() return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT bridge requires a valid request context")} } backConn, ok := res.Body.(io.ReadWriteCloser) if !ok { res.Body.Close() return &reverseProxyStatusError{ status: http.StatusBadGateway, err: errors.New("backend returned bridged websocket response without writable body"), } } controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer)) if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { backConn.Close() return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } responseHeader := c.Writer.Header() reverseProxyCopyHeader(responseHeader, res.Header) removeHopByHopHeaders(responseHeader) responseHeader.Del("Sec-WebSocket-Accept") c.Writer.WriteHeader(http.StatusOK) if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { backConn.Close() return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller} var closeOnce sync.Once closeTunnel := func() { closeOnce.Do(func() { _ = conn.Close() _ = backConn.Close() }) } go func() { <-req.Context().Done() closeTunnel() }() errc := make(chan error, 2) copyer := switchProtocolCopier{user: conn, backend: backConn} go copyer.copyToBackend(errc) go copyer.copyFromBackend(errc) var firstErr error for i := 0; i < 2; i++ { err := <-errc if 
reverseProxyIsBenignTunnelError(err) { continue } if firstErr == nil { firstErr = err closeTunnel() } } closeTunnel() if reverseProxyIsBenignTunnelError(firstErr) { return nil } return firstErr } func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { if c == nil || c.Request == nil { res.Body.Close() if backWrite != nil { _ = backWrite.Close() } return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")} } if backWrite == nil { res.Body.Close() return &reverseProxyStatusError{ status: http.StatusBadGateway, err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"), } } controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer)) if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { res.Body.Close() _ = backWrite.Close() return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } reverseProxyCopyHeader(c.Writer.Header(), res.Header) c.Writer.WriteHeader(res.StatusCode) if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { res.Body.Close() _ = backWrite.Close() return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } var closeOnce sync.Once closeTunnel := func() { closeOnce.Do(func() { _ = c.Request.Body.Close() _ = backWrite.Close() _ = res.Body.Close() }) } go func() { <-req.Context().Done() closeTunnel() }() errc := make(chan error, 2) go func() { _, err := io.Copy(backWrite, c.Request.Body) closeErr := backWrite.Close() if err != nil && !reverseProxyIsBenignTunnelError(err) { errc <- err return } errc <- closeErr }() go func() { copyErr := p.copyResponse(c.Writer, res.Body, -1) closeErr := res.Body.Close() if copyErr != nil { errc <- copyErr return } errc <- closeErr }() var firstErr error for i := 0; i < 2; i++ { err := <-errc if 
reverseProxyIsBenignTunnelError(err) { continue } if firstErr == nil { firstErr = err closeTunnel() } } closeTunnel() if reverseProxyIsBenignTunnelError(firstErr) { return nil } return firstErr } func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { return -1 } if res.ContentLength == -1 { return -1 } return p.config.FlushInterval } func (p *reverseProxyHandler) copyResponse(dst ResponseWriter, src io.Reader, flushInterval time.Duration) error { var writer io.Writer = dst if flushInterval != 0 { mlw := &maxLatencyWriter{dst: dst, latency: flushInterval} defer mlw.stop() mlw.flushPending = true mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) writer = mlw } var buf []byte if p.config.BufferPool != nil { buf = p.config.BufferPool.Get() defer p.config.BufferPool.Put(buf) } _, err := p.copyBuffer(writer, src, buf) return err } func (p *reverseProxyHandler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { if len(buf) == 0 { buf = make([]byte, 32*1024) } var written int64 for { nr, rerr := src.Read(buf) if rerr != nil && !errors.Is(rerr, io.EOF) && !reverseProxyIsBenignTunnelError(rerr) { p.logf(nil, "reverse proxy read error during body copy: %v", rerr) } if nr > 0 { nw, werr := dst.Write(buf[:nr]) if nw > 0 { written += int64(nw) } if werr != nil { return written, werr } if nr != nw { return written, io.ErrShortWrite } } if rerr != nil { if errors.Is(rerr, io.EOF) { return written, nil } return written, rerr } } } func (p *reverseProxyHandler) logf(c *Context, format string, args ...any) { if c != nil { if logger := c.GetLogger(); logger != nil { logger.Errorf(format, args...) return } } log.Printf(format, args...) 
}

// reverseProxyStatusCode maps a proxy error to the HTTP status sent to the
// client. A positive status carried by a *reverseProxyStatusError wins;
// context deadlines and network timeouts map to 504 Gateway Timeout; any other
// error maps to 502 Bad Gateway.
func reverseProxyStatusCode(err error) int {
	var statusErr *reverseProxyStatusError
	if errors.As(err, &statusErr) && statusErr.status > 0 {
		return statusErr.status
	}

	var netErr net.Error
	if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) {
		return http.StatusGatewayTimeout
	}

	return http.StatusBadGateway
}

// validateReverseProxyTarget reports an error when target is nil or lacks a
// scheme or host.
func validateReverseProxyTarget(target *url.URL) error {
	if target == nil {
		return errReverseProxyNilTarget
	}
	if target.Scheme == "" || target.Host == "" {
		return errReverseProxyInvalidTarget
	}
	return nil
}

// buildReverseProxyUpstreams turns the configured Target or Targets into the
// list of upstreams used by the handler.
//
// Target and Targets are mutually exclusive. Every target is normalized
// (ws/wss become http/https) and validated; entries of Targets are trimmed and
// parsed first. A target with the "h2c" scheme is rewritten to "http" and
// flagged for h2c. When config.Transport is nil each upstream receives its own
// extended-CONNECT and HTTP/1 bridge transports, plus an h2c transport when
// h2c is enabled for it.
func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstream, error) {
	if config.Target != nil && len(config.Targets) > 0 {
		return nil, errors.New("reverse proxy Target and Targets cannot be used together")
	}

	targets := make([]*url.URL, 0, max(1, len(config.Targets)))
	if config.Target != nil {
		// Work on a copy so the caller's URL is never mutated by normalization.
		target := cloneReverseProxyURL(config.Target)
		normalizeReverseProxyTarget(target)
		if err := validateReverseProxyTarget(target); err != nil {
			return nil, err
		}
		targets = append(targets, target)
	}

	for i, rawTarget := range config.Targets {
		trimmed := strings.TrimSpace(rawTarget)
		if trimmed == "" {
			return nil, fmt.Errorf("reverse proxy target at index %d is empty", i)
		}
		target, err := url.Parse(trimmed)
		if err != nil {
			return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
		}
		normalizeReverseProxyTarget(target)
		if err := validateReverseProxyTarget(target); err != nil {
			return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
		}
		targets = append(targets, target)
	}

	if len(targets) == 0 {
		return nil, errReverseProxyNilTarget
	}

	upstreams := make([]*reverseProxyUpstream, 0, len(targets))
	for i, target := range targets {
		useH2C := strings.EqualFold(target.Scheme, "h2c")
		if useH2C {
			// "h2c" is only a marker; requests go out with the "http" scheme.
			target = cloneReverseProxyURL(target)
			target.Scheme = "http"
		}

		upstream := &reverseProxyUpstream{
			key:    fmt.Sprintf("%d:%s", i, target.String()),
			target: target,
			index:  i,
			useH2C: useH2C || config.AllowH2CUpstream,
		}
		if config.Transport == nil {
			upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport()
			upstream.bridgeTransport = newHTTP1BridgeTransport()
			if upstream.useH2C {
				upstream.h2cTransport = newH2CTransport()
			}
		}
		upstreams = append(upstreams, upstream)
	}

	return upstreams, nil
}

// validateReverseProxyForwardedBy checks that a non-blank ForwardedBy value is
// a valid RFC 7239 node identifier. Blank values are accepted.
func validateReverseProxyForwardedBy(value string) error {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return nil
	}
	if !isValidForwardedNodeIdentifier(trimmed) {
		return fmt.Errorf("reverse proxy ForwardedBy must be an RFC 7239 node identifier, got %q", value)
	}
	return nil
}

// normalizeReverseProxyTarget rewrites WebSocket schemes in place: "ws"
// becomes "http" and "wss" becomes "https" (case-insensitive match). The
// target must be non-nil.
func normalizeReverseProxyTarget(target *url.URL) {
	switch strings.ToLower(target.Scheme) {
	case "ws":
		target.Scheme = "http"
	case "wss":
		target.Scheme = "https"
	}
}

// cloneReverseProxyURL returns a shallow copy of target, or nil for nil input.
func cloneReverseProxyURL(target *url.URL) *url.URL {
	if target == nil {
		return nil
	}
	clone := *target
	return &clone
}

// reverseProxyReceivedBy returns the trimmed configured value, falling back
// to "touka-engine" when it is blank.
func reverseProxyReceivedBy(configValue string) string {
	trimmed := strings.TrimSpace(configValue)
	if trimmed != "" {
		return trimmed
	}
	return "touka-engine"
}

// reverseProxyClientIP extracts the address part of remoteAddr.
//
// It accepts "ip:port", "host:port" and bare IP forms. IP addresses are
// returned in their netip canonical text form; a non-IP host is returned as
// is; input that matches none of the forms is returned unchanged.
func reverseProxyClientIP(remoteAddr string) string {
	if remoteAddr == "" {
		return ""
	}

	if addrPort, err := netip.ParseAddrPort(remoteAddr); err == nil {
		return addrPort.Addr().String()
	}

	host, _, err := net.SplitHostPort(remoteAddr)
	if err == nil {
		if addr, parseErr := netip.ParseAddr(host); parseErr == nil {
			return addr.String()
		}
		return host
	}

	if addr, err := netip.ParseAddr(remoteAddr); err == nil {
		return addr.String()
	}

	return remoteAddr
}

// reverseProxyRequestScheme returns the scheme of the inbound request:
// "https" when req.TLS is set, otherwise the lower-cased URL scheme if
// present, otherwise "http". A nil request yields "".
func reverseProxyRequestScheme(req *http.Request) string {
	if req == nil {
		return ""
	}
	if req.TLS != nil {
		return "https"
	}
	if req.URL != nil {
		scheme := strings.ToLower(req.URL.Scheme)
		if scheme != "" {
			return scheme
		}
	}
	return "http"
}

// buildForwardedHeaderValue assembles one Forwarded header element from the
// non-empty parameters, in the order by, for, host, proto, joined with ";".
// It returns "" when every parameter is empty.
func buildForwardedHeaderValue(clientIP, by, host, scheme string) string {
	pairs := make([]string, 0, 4)
	if by != "" {
		pairs = append(pairs, "by="+formatForwardedParameterValue(by))
	}
	if clientIP != "" {
		pairs = append(pairs, "for="+formatForwardedFor(clientIP))
	}
	if host != "" {
		pairs = append(pairs, "host="+formatForwardedParameterValue(host))
	}
	if scheme != "" {
		pairs = append(pairs, "proto="+formatForwardedParameterValue(strings.ToLower(scheme)))
	}
	if len(pairs) == 0 {
		return ""
	}
	return strings.Join(pairs, ";")
}

// reverseProxyUsesForwardedHeader reports whether policy emits the RFC 7239
// Forwarded header.
func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
	return policy == ForwardedBoth || policy == ForwardedRFC7239Only
}

// reverseProxyPrepareExtendedConnectBridge converts an extended CONNECT
// request whose :protocol is "websocket" into an HTTP/1.1 WebSocket upgrade
// request, mutating req in place.
//
// The original request body is stored in a bridge value attached to the
// returned context. The boolean result reports whether the conversion took
// place. Requests that are not websocket extended CONNECT are left untouched
// and their own context is returned; a nil request yields
// context.Background().
//
// The WebSocket key is generated before req is modified, so a failure to
// obtain randomness leaves req unchanged.
func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) {
	if req == nil {
		return context.Background(), false, nil
	}

	// reverseProxyExtendedConnectProtocol returns "" for anything that is not
	// a CONNECT request carrying headers, so this single comparison also
	// rejects non-CONNECT methods and missing :protocol values.
	if !strings.EqualFold(reverseProxyExtendedConnectProtocol(req), "websocket") {
		return req.Context(), false, nil
	}

	key, err := reverseProxyGenerateWebSocketKey()
	if err != nil {
		return nil, false, fmt.Errorf("reverse proxy failed to generate websocket key: %w", err)
	}

	bridge := &reverseProxyExtendedConnectBridge{body: req.Body}
	ctx := context.WithValue(req.Context(), reverseProxyExtendedConnectBridge{}, bridge)

	req.Header.Del(":protocol")
	req.Method = http.MethodGet
	req.Body = http.NoBody
	req.ContentLength = 0
	req.Header.Set("Upgrade", "websocket")
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Sec-WebSocket-Version", "13")
	req.Header.Set("Sec-WebSocket-Key", key)

	return ctx, true, nil
}

// reverseProxyExtendedConnectBridgeFromContext returns the bridge stored in
// ctx by reverseProxyPrepareExtendedConnectBridge, or nil when ctx is nil or
// carries none.
func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge {
	if ctx == nil {
		return nil
	}
	bridge, _ := ctx.Value(reverseProxyExtendedConnectBridge{}).(*reverseProxyExtendedConnectBridge)
	return bridge
}

// reverseProxyGenerateWebSocketKey returns 16 random bytes from crypto/rand,
// base64-encoded, for use as a Sec-WebSocket-Key value.
func reverseProxyGenerateWebSocketKey() (string, error) {
	key := make([]byte, 16)
	if _, err := io.ReadFull(rand.Reader, key); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(key), nil
}

// reverseProxyIsExtendedConnectRequest reports whether req is a CONNECT
// request carrying a non-empty :protocol value.
func reverseProxyIsExtendedConnectRequest(req *http.Request) bool {
	return reverseProxyExtendedConnectProtocol(req) != ""
}

// reverseProxyExtendedConnectProtocol returns the trimmed :protocol value of
// a CONNECT request, or "" when req is nil, is not CONNECT, or has no headers.
func reverseProxyExtendedConnectProtocol(req *http.Request) string {
	if req == nil || req.Method != http.MethodConnect || req.Header == nil {
		return ""
	}
	return textproto.TrimString(req.Header.Get(":protocol"))
}

// isValidForwardedNodeIdentifier reports whether value is an acceptable
// RFC 7239 node: a bracketed IPv6 address, an IPv4 address, "unknown", or an
// obfuscated identifier, each optionally followed by ":" and a node port.
func isValidForwardedNodeIdentifier(value string) bool {
	if value == "" {
		return false
	}

	if strings.HasPrefix(value, "[") {
		closing := strings.IndexByte(value, ']')
		if closing <= 1 {
			// No closing bracket, or nothing between the brackets.
			return false
		}
		addr, err := netip.ParseAddr(value[1:closing])
		if err != nil || !addr.Is6() {
			return false
		}
		if closing == len(value)-1 {
			return true
		}
		if value[closing+1] != ':' {
			return false
		}
		return isValidForwardedNodePort(value[closing+2:])
	}

	host, port, hasPort := strings.Cut(value, ":")
	if hasPort {
		switch {
		case host == "unknown", isValidForwardedObfuscatedIdentifier(host):
			return isValidForwardedNodePort(port)
		default:
			addr, err := netip.ParseAddr(host)
			return err == nil && addr.Is4() && isValidForwardedNodePort(port)
		}
	}

	if value == "unknown" || isValidForwardedObfuscatedIdentifier(value) {
		return true
	}
	addr, err := netip.ParseAddr(value)
	return err == nil && addr.Is4()
}

// isValidForwardedNodePort reports whether value is an RFC 7239 node-port:
// either an obfuscated port ("_" followed by allowed characters) or one to
// five decimal digits denoting a port in the range 1-65535.
func isValidForwardedNodePort(value string) bool {
	if value == "" {
		return false
	}
	if isValidForwardedObfuscatedIdentifier(value) {
		return true
	}
	if len(value) > 5 {
		return false
	}
	// strconv.Atoi accepts a leading sign ("+80"), which the 1*5DIGIT grammar
	// does not, so require plain ASCII digits before converting.
	for i := 0; i < len(value); i++ {
		if value[i] < '0' || value[i] > '9' {
			return false
		}
	}
	port, err := strconv.Atoi(value)
	return err == nil && port > 0 && port <= 65535
}

// isValidForwardedObfuscatedIdentifier reports whether value is "_" followed
// by at least one character, all drawn from ASCII letters, digits, ".", "_"
// and "-".
func isValidForwardedObfuscatedIdentifier(value string) bool {
	if len(value) < 2 || value[0] != '_' {
		return false
	}
	for i := 1; i < len(value); i++ {
		b := value[i]
		if (b >= '0' && b <= '9') || (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') {
			continue
		}
		switch b {
		case '.', '_', '-':
			continue
		default:
			return false
		}
	}
	return true
}

// formatForwardedFor renders clientIP as a Forwarded "for" value. IPv6
// addresses are bracketed and quoted, IPv4 addresses are emitted bare, and
// anything that is not an IP address is formatted as a generic parameter
// value.
func formatForwardedFor(clientIP string) string {
	addr, err := netip.ParseAddr(clientIP)
	if err != nil {
		return formatForwardedParameterValue(clientIP)
	}
	if addr.Is6() {
		return quoteForwardedString("[" + addr.String() + "]")
	}
	return addr.String()
}

// formatForwardedParameterValue returns value unchanged when it is an HTTP
// token and as a quoted-string otherwise.
func formatForwardedParameterValue(value string) string {
	if isToken(value) {
		return value
	}
	return quoteForwardedString(value)
}

// forwardedQuotedStringEscaper backslash-escapes backslashes and double
// quotes. It is built once and shared by every quoteForwardedString call.
var forwardedQuotedStringEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// quoteForwardedString wraps value in double quotes, escaping embedded
// backslashes and double quotes.
func quoteForwardedString(value string) string {
	return `"` + forwardedQuotedStringEscaper.Replace(value) + `"`
}

// isToken reports whether value is non-empty and consists solely of HTTP
// token characters.
func isToken(value string) bool {
	if value == "" {
		return false
	}
	for i := 0; i < len(value); i++ {
		if !isTokenChar(value[i]) {
			return false
		}
	}
	return true
}

// isTokenChar reports whether b is an ASCII letter, a digit, or one of the
// punctuation characters permitted in an HTTP token.
func isTokenChar(b byte) bool {
	if b >= '0' && b <= '9' {
		return true
	}
	if b >= 'A' && b <= 'Z' {
		return true
	}
	if b >= 'a' && b <= 'z' {
		return true
	}
	switch b {
	case '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~':
		return true
	default:
		return false
	}
}

// appendViaHeader adds a "Via: <protocol> <receivedBy>" entry to header. It
// does nothing when header is nil or receivedBy is empty; an empty protocol
// defaults to "1.1".
func appendViaHeader(header http.Header, protocol, receivedBy string) {
	if header == nil || receivedBy == "" {
		return
	}
	if protocol == "" {
		protocol = "1.1"
	}
	header.Add("Via", protocol+" "+receivedBy)
}

// reverseProxyViaProtocol returns the protocol version for a Via entry:
// "major.minor" when major is positive, otherwise raw with any "HTTP/" prefix
// removed.
func reverseProxyViaProtocol(major, minor int, raw string) string {
	if major > 0 {
		return strconv.Itoa(major) + "." + strconv.Itoa(minor)
	}
	if strings.HasPrefix(raw, "HTTP/") {
		return strings.TrimPrefix(raw, "HTTP/")
	}
	return raw
}

// rewriteReverseProxyURL points req.URL at target: scheme and host are taken
// from target, the paths are joined, and the target's query (if any) is
// placed before the request's query, separated by "&" when both are present.
func rewriteReverseProxyURL(req *http.Request, target *url.URL) {
	targetQuery := target.RawQuery
	req.URL.Scheme = target.Scheme
	req.URL.Host = target.Host
	req.URL.Path, req.URL.RawPath = joinReverseProxyURLPath(target, req.URL)
	if targetQuery == "" || req.URL.RawQuery == "" {
		req.URL.RawQuery = targetQuery + req.URL.RawQuery
	} else {
		req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
	}
}

// rewriteReverseProxyConnectRequest rewrites a CONNECT request so that its
// authority is the target's host:port. Path and query are cleared, and both
// URL.Opaque and Host are set to that authority. An unusable target is
// reported as a 400 Bad Request status error.
func rewriteReverseProxyConnectRequest(req *http.Request, target *url.URL) error {
	connectTarget, err := reverseProxyConnectTarget(target)
	if err != nil {
		return &reverseProxyStatusError{status: http.StatusBadRequest, err: err}
	}

	req.URL.Scheme = target.Scheme
	req.URL.Host = target.Host
	req.URL.Path = ""
	req.URL.RawPath = ""
	req.URL.RawQuery = ""
	req.URL.Opaque = connectTarget
	req.Host = connectTarget
	return nil
}

// reverseProxyConnectTarget returns the "host:port" authority for target.
// When the URL has no explicit port, 80 is used for http and 443 for https;
// any other scheme without a port is an error, as is a port outside 1-65535.
func reverseProxyConnectTarget(target *url.URL) (string, error) {
	if target == nil {
		return "", errReverseProxyNilTarget
	}

	host := target.Hostname()
	if host == "" {
		return "", errReverseProxyInvalidTarget
	}

	port := target.Port()
	if port == "" {
		switch strings.ToLower(target.Scheme) {
		case "http":
			port = "80"
		case "https":
			port = "443"
		default:
			return "", fmt.Errorf("reverse proxy CONNECT target requires a supported scheme, got %q", target.Scheme)
		}
	}

	portNum, err := strconv.Atoi(port)
	if err != nil || portNum <= 0 || portNum > 65535 {
		return "", fmt.Errorf("reverse proxy CONNECT target has invalid port %q", port)
	}

	return net.JoinHostPort(host, port), nil
}

// joinReverseProxyURLPath joins the base and incoming paths with exactly one
// slash between them and returns the resulting Path and RawPath. RawPath is
// "" when neither input has a RawPath.
func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) {
	if base.RawPath == "" && incoming.RawPath == "" {
		return reverseProxySingleJoiningSlash(base.Path, incoming.Path), ""
	}

	// Decide on the slash using the escaped forms, then apply the same choice
	// to both the decoded and the escaped path.
	baseEscaped := base.EscapedPath()
	incomingEscaped := incoming.EscapedPath()
	baseSlash := strings.HasSuffix(baseEscaped, "/")
	incomingSlash := strings.HasPrefix(incomingEscaped, "/")

	switch {
	case baseSlash && incomingSlash:
		return base.Path + incoming.Path[1:], baseEscaped + incomingEscaped[1:]
	case !baseSlash && !incomingSlash:
		return base.Path + "/" + incoming.Path, baseEscaped + "/" + incomingEscaped
	default:
		return base.Path + incoming.Path, baseEscaped + incomingEscaped
	}
}

// reverseProxySingleJoiningSlash concatenates a and b with exactly one "/"
// at the seam.
func reverseProxySingleJoiningSlash(a, b string) string {
	aslash := strings.HasSuffix(a, "/")
	bslash := strings.HasPrefix(b, "/")
	switch {
	case aslash && bslash:
		return a + b[1:]
	case !aslash && !bslash:
		return a + "/" + b
	default:
		return a + b
	}
}

// reverseProxyCopyHeader appends every value in src to dst, keeping any
// values dst already holds.
func reverseProxyCopyHeader(dst, src http.Header) {
	for key, values := range src {
		for _, value := range values {
			dst.Add(key, value)
		}
	}
}

// reverseProxyHopHeaders lists the headers that removeHopByHopHeaders always
// deletes.
var reverseProxyHopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}

// removeHopByHopHeaders deletes every header named in a Connection value,
// then the fixed set in reverseProxyHopHeaders.
func removeHopByHopHeaders(header http.Header) {
	// The Connection values must be processed before "Connection" itself is
	// deleted by the loop below.
	for _, connectionValue := range header["Connection"] {
		for _, token := range strings.Split(connectionValue, ",") {
			trimmed := textproto.TrimString(token)
			if trimmed != "" {
				header.Del(trimmed)
			}
		}
	}
	for _, hopHeader := range reverseProxyHopHeaders {
		header.Del(hopHeader)
	}
}

// reverseProxyUpgradeType returns the Upgrade header value when the
// Connection header contains the "Upgrade" token, and "" otherwise.
func reverseProxyUpgradeType(header http.Header) string {
	if !headerValuesContainToken(header["Connection"], "Upgrade") {
		return ""
	}
	return header.Get("Upgrade")
}

// headerValuesContainToken reports whether any comma-separated element of
// values equals token, ignoring case and surrounding whitespace. An empty
// token never matches.
func headerValuesContainToken(values []string, token string) bool {
	if token == "" {
		return false
	}
	for _, value := range values {
		for _, part := range strings.Split(value, ",") {
			if strings.EqualFold(textproto.TrimString(part), token) {
				return true
			}
		}
	}
	return false
}

// cleanReverseProxyQueryParams re-encodes rawQuery through net/url so the
// outgoing query has a single canonical form.
func cleanReverseProxyQueryParams(rawQuery string) string {
	if rawQuery == "" {
		return ""
	}

	// Normalize the outgoing query string so the proxy and upstream do not see
	// different semantics for non-standard separators or malformed pairs.
	// This can change the exact textual form of the original query and may drop
	// parts that net/url rejects, but it keeps proxy-chain parsing behavior more
	// consistent and reduces parameter-smuggling ambiguity.
	// The parse error is ignored on purpose: the pairs that did parse are
	// still forwarded.
	values, _ := url.ParseQuery(rawQuery)
	return values.Encode()
}

// reverseProxyShouldPanicOnCopyError reports whether req's context carries
// http.ServerContextKey.
func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
	return req != nil && req.Context().Value(http.ServerContextKey) != nil
}

// reverseProxyCanRetryRequest reports whether req may be sent again. Only
// safe-method, non-CONNECT, non-upgrade requests qualify, and a request with
// a body additionally needs GetBody so the body can be recreated.
func reverseProxyCanRetryRequest(req *http.Request) bool {
	if req == nil || req.Method == http.MethodConnect || reverseProxyUpgradeType(req.Header) != "" || !reverseProxyMethodIsSafe(req.Method) {
		return false
	}
	if req.Body == nil || req.ContentLength == 0 {
		return true
	}
	return req.GetBody != nil
}

// reverseProxyShouldCountPassiveFailure reports whether err should count
// against the upstream for passive health tracking. Nil and benign tunnel
// errors are ignored, as is any error seen once the request's own context has
// ended.
func reverseProxyShouldCountPassiveFailure(req *http.Request, err error) bool {
	if err == nil || reverseProxyIsBenignTunnelError(err) {
		return false
	}
	if req != nil && req.Context().Err() != nil {
		return false
	}
	// context.Canceled is already treated as benign above; the check is kept
	// as a defensive guard.
	return !errors.Is(err, context.Canceled)
}

// reverseProxyMethodIsSafe reports whether method is GET, HEAD, OPTIONS or
// TRACE.
func reverseProxyMethodIsSafe(method string) bool {
	switch method {
	case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
		return true
	default:
		return false
	}
}

// reverseProxyIsBenignTunnelError reports whether err is nil or one of the
// errors this file treats as a normal end of a tunnel or stream: the
// copy-done sentinel, closed connections and pipes, EOF, cancellation,
// http.ErrAbortHandler, or a closed-body error.
func reverseProxyIsBenignTunnelError(err error) bool {
	return err == nil ||
		errors.Is(err, errReverseProxyCopyDone) ||
		errors.Is(err, net.ErrClosed) ||
		errors.Is(err, io.EOF) ||
		errors.Is(err, io.ErrClosedPipe) ||
		errors.Is(err, context.Canceled) ||
		errors.Is(err, http.ErrAbortHandler) ||
		reverseProxyIsClosedBodyError(err)
}

// reverseProxyIsClosedBodyError reports whether err is an HTTP/2 stream error
// with the CANCEL code, or an error whose text is one of the known
// "body closed" messages.
func reverseProxyIsClosedBodyError(err error) bool {
	if err == nil {
		return false
	}

	var streamErr http2.StreamError
	if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeCancel {
		return true
	}

	// Matched by message text; no sentinel value is compared here.
	switch err.Error() {
	case "body closed by handler", "http2: response body closed", "response body closed":
		return true
	default:
		return false
	}
}

// reverseProxyBaseResponseWriter returns the http.ResponseWriter obtained by
// unwrapping writer with UnwrapResponseWriter.
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
	return UnwrapResponseWriter(writer)
}

// isPrintableASCII reports whether every byte of value lies in the printable
// ASCII range 0x20-0x7e. The empty string is considered printable.
func isPrintableASCII(value string) bool {
	for i := 0; i < len(value); i++ {
		if value[i] < 0x20 || value[i] > 0x7e {
			return false
		}
	}
	return true
}