diff --git a/context.go b/context.go index f06d21e..e73033d 100644 --- a/context.go +++ b/context.go @@ -128,6 +128,19 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } } +func (c *Context) writeResponseBody(data []byte, contextMsg string) { + if len(data) == 0 { + return + } + if _, err := c.Writer.Write(data); err != nil { + wrapped := fmt.Errorf("%s: %w", contextMsg, err) + c.AddError(wrapped) + if c != nil && c.engine != nil && c.engine.LogReco != nil { + c.engine.LogReco.Errorf("%s: %v", contextMsg, err) + } + } +} + // Next 在处理链中执行下一个处理函数 // 这是中间件模式的核心,允许请求依次经过多个处理函数 func (c *Context) Next() { @@ -344,20 +357,20 @@ func (c *Context) Param(key string) string { func (c *Context) Raw(code int, contentType string, data []byte) { c.Writer.Header().Set("Content-Type", contentType) c.Writer.WriteHeader(code) - c.Writer.Write(data) + c.writeResponseBody(data, "failed to write raw response") } // String 向响应写入格式化的字符串 func (c *Context) String(code int, format string, values ...any) { c.Writer.WriteHeader(code) - c.Writer.Write(fmt.Appendf(nil, format, values...)) + c.writeResponseBody(fmt.Appendf(nil, format, values...), "failed to write string response") } // Text 向响应写入无需格式化的string func (c *Context) Text(code int, text string) { c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8") c.Writer.WriteHeader(code) - c.Writer.Write([]byte(text)) + c.writeResponseBody([]byte(text), "failed to write text response") } // FileText @@ -495,7 +508,7 @@ func (c *Context) JSONBuf(code int, obj any) { c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") c.Writer.WriteHeader(code) - c.Writer.Write(buf.Bytes()) + c.writeResponseBody(buf.Bytes(), "failed to write buffered JSON response") } // GOB 向响应写入GOB数据 @@ -524,7 +537,7 @@ func (c *Context) GOBBuf(code int, obj any) { } c.Writer.Header().Set("Content-Type", "application/octet-stream") c.Writer.WriteHeader(code) - c.Writer.Write(buf.Bytes()) + c.writeResponseBody(buf.Bytes(), "failed to write buffered GOB response") } // WANF向响应写入WANF数据 @@ -553,7 +566,7 @@ func (c *Context) WANFBuf(code int, obj any) { } c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8") c.Writer.WriteHeader(code) - c.Writer.Write(buf.Bytes()) + c.writeResponseBody(buf.Bytes(), "failed to write buffered WANF response") } // HTML 渲染 HTML 模板 @@ -577,7 +590,7 @@ func (c *Context) HTML(code int, name string, obj any) { // 可以扩展支持其他渲染器接口 } // 默认简单输出,用于未配置 HTMLRender 的情况 - c.Writer.Write(fmt.Appendf(nil, "\n
%v", name, obj)) + c.writeResponseBody(fmt.Appendf(nil, "\n
%v", name, obj), "failed to write HTML response") } // HTMLBuf 先将 HTML 模板渲染到 buffer, 成功后再写入状态码和响应体. @@ -602,7 +615,7 @@ func (c *Context) HTMLBuf(code int, name string, obj any) { // 渲染成功,写入响应 c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8") c.Writer.WriteHeader(code) - c.Writer.Write(buf.Bytes()) + c.writeResponseBody(buf.Bytes(), "failed to write buffered HTML response") return } @@ -938,7 +951,7 @@ func (c *Context) GetReqBodyFull() ([]byte, error) { } }() - data, err := iox.ReadAll(body) + data, err := io.ReadAll(body) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) @@ -959,7 +972,7 @@ func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { } }() - data, err := iox.ReadAll(body) + data, err := io.ReadAll(body) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) diff --git a/engine.go b/engine.go index b0723e7..d712064 100644 --- a/engine.go +++ b/engine.go @@ -19,6 +19,7 @@ import ( "github.com/WJQSERVER-STUDIO/httpc" "github.com/fenthope/reco" + "github.com/go-json-experiment/json" ) // Last 返回链中的最后一个处理函数 @@ -132,6 +133,32 @@ type defaultErrorResponse struct { Error string `json:"error"` } +var defaultNotFoundBody = mustMarshalDefaultErrorBody(http.StatusNotFound, errNotFound.Error()) +var defaultMethodNotAllowedBody = mustMarshalDefaultErrorBody(http.StatusMethodNotAllowed, errMethodNotAllowed.Error()) + +func mustMarshalDefaultErrorBody(code int, errMsg string) []byte { + body, err := json.Marshal(defaultErrorResponse{ + Code: code, + Message: http.StatusText(code), + Error: errMsg, + }) + if err != nil { + panic(err) + } + return body +} + +func writeDefaultErrorJSON(c *Context, code int, body []byte) { + if c == nil || c.Writer == nil { + return + } + c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") + c.Writer.WriteHeader(code) + c.writeResponseBody(body, "failed to write default error response") + c.Writer.Flush() + c.Abort() +} + var methodNotAllowedHandler HandlerFunc = func(c *Context) { httpMethod := c.Request.Method requestPath := routeLookupPath(c.Request) @@ -191,6 +218,16 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是 if c.Writer.Written() { return } + if len(c.Errors) == 0 { + switch { + case code == http.StatusNotFound && errors.Is(err, errNotFound): + writeDefaultErrorJSON(c, code, defaultNotFoundBody) + return + case code == http.StatusMethodNotAllowed && errors.Is(err, errMethodNotAllowed): + writeDefaultErrorJSON(c, code, defaultMethodNotAllowedBody) + return + } + } // 输出json 状态码与状态码对应描述 var errMsg string if err != nil { diff --git a/engine_test.go b/engine_test.go index 571f4b7..4772810 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1,11 +1,66 @@ package touka import ( + "bufio" "encoding/json" + "errors" + "html/template" + "net" "net/http" "testing" ) +type failingResponseWriter struct { + header http.Header + status int + err error +} + +func (w *failingResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *failingResponseWriter) WriteHeader(statusCode int) { + if w.status == 0 { + w.status = statusCode + } +} + +func (w *failingResponseWriter) Write(p []byte) (int, error) { + if w.status == 0 { + w.status = http.StatusOK + } + if w.err != nil { + return 0, w.err + } + return len(p), nil +} + +func (w *failingResponseWriter) Flush() {} + +func (w *failingResponseWriter) Status() int { + return w.status +} + +func (w *failingResponseWriter) Size() int { + return 0 +} + +func (w *failingResponseWriter) Written() bool { + return w.status != 0 +} + +func (w *failingResponseWriter) IsHijacked() bool { + return false +} + +func (w *failingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, http.ErrNotSupported +} + func TestHandleRequestRedirectFixedPath(t *testing.T) { engine := New() engine.GET("/api/v1/users/:id/settings", func(c *Context) { @@ -139,3 +194,113 @@ func TestDefaultErrorHandleJSONShape(t *testing.T) { t.Fatalf("unexpected error payload: %+v", body) } } + +func TestDefaultMethodNotAllowedJSONShape(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + Error string `json:"error"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { + t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err) + } + if body.Code != http.StatusMethodNotAllowed || body.Message != http.StatusText(http.StatusMethodNotAllowed) || body.Error != "method not allowed" { + t.Fatalf("unexpected error payload: %+v", body) + } +} + +func TestCustomErrorHandlerStillOverridesDefaultFastPath(t *testing.T) { + engine := New() + engine.SetErrorHandler(func(c *Context, code int, err error) { + c.Writer.Header().Set("X-Custom-Error", "1") + c.String(code, "custom:%v", err) + }) + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + if got := rr.Header().Get("X-Custom-Error"); got != "1" { + t.Fatalf("expected custom error header, got %q", got) + } + if rr.Body.String() != "custom:method not allowed" { + t.Fatalf("expected custom error body, got %q", rr.Body.String()) + } +} + +func TestResponseHelpersCaptureWriteErrors(t *testing.T) { + testCases := []struct { + name string + run func(*Context) + }{ + {name: "Raw", run: func(c *Context) { c.Raw(http.StatusOK, "application/octet-stream", []byte("payload")) }}, + {name: "String", run: func(c *Context) { c.String(http.StatusOK, "value=%d", 1) }}, + {name: "Text", run: func(c *Context) { c.Text(http.StatusOK, "payload") }}, + {name: "JSONBuf", run: func(c *Context) { c.JSONBuf(http.StatusOK, map[string]string{"a": "b"}) }}, + {name: "GOBBuf", run: func(c *Context) { c.GOBBuf(http.StatusOK, struct{ A string }{A: "b"}) }}, + {name: "WANFBuf", run: func(c *Context) { c.WANFBuf(http.StatusOK, map[string]string{"a": "b"}) }}, + {name: "HTMLFallback", run: func(c *Context) { c.HTML(http.StatusOK, "page", map[string]string{"a": "b"}) }}, + {name: "HTMLBuf", run: func(c *Context) { + c.engine.HTMLRender = template.Must(template.New("page").Parse(`{{.a}}`)) + c.HTMLBuf(http.StatusOK, "page", map[string]string{"a": "b"}) + }}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + writerErr := errors.New("write failed") + w := &failingResponseWriter{err: writerErr} + c, _ := CreateTestContext(w) + + tc.run(c) + + if got := len(c.Errors); got != 1 { + t.Fatalf("expected exactly one captured error, got %d", got) + } + if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) { + t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1]) + } + }) + } +} + +func TestDefaultErrorFastPathCapturesWriteErrors(t *testing.T) { + writerErr := errors.New("write failed") + w := &failingResponseWriter{err: writerErr} + engine := New() + c, _ := CreateTestContext(w) + c.engine = engine + req, err := http.NewRequest(http.MethodGet, "/missing", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + c.reset(w, req) + + defaultErrorHandle(c, http.StatusNotFound, errNotFound) + + if len(c.Errors) == 0 { + t.Fatal("expected write error to be captured") + } + if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) { + t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1]) + } + if c.Writer.Status() != http.StatusNotFound { + t.Fatalf("expected status %d, got %d", http.StatusNotFound, c.Writer.Status()) + } + if !c.IsAborted() { + t.Fatal("expected fast path to abort context") + } +} diff --git a/go.mod b/go.mod index bd0c046..dee187d 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/infinite-iroha/touka go 1.26 require ( - github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 + github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 github.com/WJQSERVER-STUDIO/httpc v0.9.0 github.com/WJQSERVER/wanf v0.0.8 github.com/fenthope/reco v0.0.5 diff --git a/go.sum b/go.sum index 6a8d0c6..4b9dbd9 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbekZgxOuvVwiMfyM= github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok= +github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 h1:Hc1O6D50U3URkdSzfQ/SgeUU750wUBCYhefdvAbE2Ck= +github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3/go.mod h1:nFQzepAwwdj5Hp5U+X19l4FVvsaOSBTW41BzfI/CkMA= github.com/WJQSERVER-STUDIO/httpc v0.9.0 h1:MpXcQQqukrSLHH/2tTfnXrhqD6nEDHB/gbzehXaS8o4= github.com/WJQSERVER-STUDIO/httpc v0.9.0/go.mod h1:filzryrl4eAtFVyl4oVHcJqx1SpNFbrCn+ddQPLlCSg= github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8= diff --git a/iox_benchmark_test.go b/iox_benchmark_test.go new file mode 100644 index 0000000..9b43590 --- /dev/null +++ b/iox_benchmark_test.go @@ -0,0 +1,150 @@ +package touka + +import ( + "bytes" + "io" + "testing" + + "github.com/WJQSERVER-STUDIO/go-utils/iox" +) + +type benchmarkResetReader struct { + data []byte + off int +} + +func (r *benchmarkResetReader) Read(p []byte) (int, error) { + if r.off >= len(r.data) { + return 0, io.EOF + } + n := copy(p, r.data[r.off:]) + r.off += n + return n, nil +} + +func (r *benchmarkResetReader) Reset() { + r.off = 0 +} + +type benchmarkDiscardWriter struct{} + +func (benchmarkDiscardWriter) Write(p []byte) (int, error) { + return len(p), nil +} + +var benchmarkIOXResult int64 +var benchmarkIOXBytes []byte + +func BenchmarkIOXCopyComparison(b *testing.B) { + payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) + + b.Run("io.Copy", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + w := benchmarkDiscardWriter{} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + n, err := io.Copy(w, r) + if err != nil { + b.Fatalf("io.Copy failed: %v", err) + } + benchmarkIOXResult = n + } + }) + + b.Run("iox.Copy", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + w := benchmarkDiscardWriter{} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + n, err := iox.Copy(w, r) + if err != nil { + b.Fatalf("iox.Copy failed: %v", err) + } + benchmarkIOXResult = n + } + }) +} + +func BenchmarkIOXCopyBufferComparison(b *testing.B) { + payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) + + b.Run("io.CopyBuffer", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + w := benchmarkDiscardWriter{} + buf := make([]byte, 32*1024) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + n, err := io.CopyBuffer(w, r, buf) + if err != nil { + b.Fatalf("io.CopyBuffer failed: %v", err) + } + benchmarkIOXResult = n + } + }) + + b.Run("iox.CopyBuffer", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + w := benchmarkDiscardWriter{} + buf := make([]byte, 32*1024) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + n, err := iox.CopyBuffer(w, r, buf) + if err != nil { + b.Fatalf("iox.CopyBuffer failed: %v", err) + } + benchmarkIOXResult = n + } + }) +} + +func BenchmarkIOXReadAllComparison(b *testing.B) { + payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) + + b.Run("io.ReadAll", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + data, err := io.ReadAll(r) + if err != nil { + b.Fatalf("io.ReadAll failed: %v", err) + } + benchmarkIOXBytes = data + } + }) + + b.Run("iox.ReadAll", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + data, err := io.ReadAll(r) + if err != nil { + b.Fatalf("iox.ReadAll failed: %v", err) + } + benchmarkIOXBytes = data + } + }) +} diff --git a/reverseproxy.go b/reverseproxy.go index 5b178d5..5ec3693 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -1041,7 +1041,7 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r go copyer.copyFromBackend(errc) var firstErr error - for i := 0; i < 2; i++ { + for range 2 { err := <-errc if reverseProxyIsBenignTunnelError(err) { continue @@ -1123,7 +1123,7 @@ func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *htt }() var firstErr error - for i := 0; i < 2; i++ { + for range 2 { err := <-errc if reverseProxyIsBenignTunnelError(err) { continue @@ -1587,8 +1587,8 @@ func reverseProxyViaProtocol(major, minor int, raw string) string { if major > 0 { return strconv.Itoa(major) + "." + strconv.Itoa(minor) } - if strings.HasPrefix(raw, "HTTP/") { - return strings.TrimPrefix(raw, "HTTP/") + if after, ok := strings.CutPrefix(raw, "HTTP/"); ok { + return after } return raw } @@ -1702,7 +1702,7 @@ var reverseProxyHopHeaders = []string{ func removeHopByHopHeaders(header http.Header) { for _, connectionValue := range header["Connection"] { - for _, token := range strings.Split(connectionValue, ",") { + for token := range strings.SplitSeq(connectionValue, ",") { trimmed := textproto.TrimString(token) if trimmed != "" { header.Del(trimmed) @@ -1726,7 +1726,7 @@ func headerValuesContainToken(values []string, token string) bool { return false } for _, value := range values { - for _, part := range strings.Split(value, ",") { + for part := range strings.SplitSeq(value, ",") { if strings.EqualFold(textproto.TrimString(part), token) { return true } diff --git a/reverseproxy_benchmark_test.go b/reverseproxy_benchmark_test.go index f55f5f0..b496f5c 100644 --- a/reverseproxy_benchmark_test.go +++ b/reverseproxy_benchmark_test.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/http" + "strings" "testing" "time" ) @@ -167,6 +168,33 @@ func BenchmarkReverseProxySelectUpstream(b *testing.B) { } } +func BenchmarkReverseProxySelectUpstreamHeaderPolicy(b *testing.B) { + proxy := &reverseProxyHandler{ + upstreams: []*reverseProxyUpstream{ + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + {key: "d", index: 3}, + }, + config: ReverseProxyConfig{ + LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())}, + }, + } + c, _ := CreateTestContext(nil) + c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b", "tenant-c"} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + selected, err := proxy.selectUpstream(c, nil) + if err != nil { + b.Fatalf("selectUpstream failed: %v", err) + } + benchmarkReverseProxySink = selected.index + } +} + func TestReverseProxyCopyResponseWithoutBufferPool(t *testing.T) { proxy := newReverseProxyHandler(ReverseProxyConfig{}) dst := newBenchmarkResponseWriter() @@ -207,7 +235,7 @@ func (r *recordingReader) Read(p []byte) (int, error) { if n == 0 { return 0, errors.New("reader received zero-length buffer") } - for i := 0; i < n; i++ { + for i := range n { p[i] = 'x' } r.left -= n @@ -260,4 +288,68 @@ func TestReverseProxyAvailableUpstreamsFiltersExcludedAndUnhealthy(t *testing.T) } } +func TestReverseProxyHeaderPolicyUsesAllHeaderValues(t *testing.T) { + proxy := &reverseProxyHandler{ + upstreams: []*reverseProxyUpstream{ + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + }, + config: ReverseProxyConfig{ + LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())}, + }, + } + + c, _ := CreateTestContext(nil) + c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b"} + + selectedA, err := proxy.selectUpstream(c, nil) + if err != nil { + t.Fatalf("selectUpstream failed: %v", err) + } + selectedB, err := proxy.selectUpstream(c, nil) + if err != nil { + t.Fatalf("selectUpstream failed: %v", err) + } + if selectedA.key != selectedB.key { + t.Fatalf("expected stable selection for identical multi-value header, got %q and %q", selectedA.key, selectedB.key) + } + + c.Request.Header["X-Tenant"] = []string{"tenant-b", "tenant-a"} + selectedC, err := proxy.selectUpstream(c, nil) + if err != nil { + t.Fatalf("selectUpstream failed: %v", err) + } + if selectedC == nil { + t.Fatal("expected upstream for reordered multi-value header") + } +} + +func TestReverseProxyHeaderPolicyMatchesJoinCompatibility(t *testing.T) { + candidates := []*reverseProxyUpstream{ + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + } + + testCases := [][]string{ + {"tenant-a"}, + {"tenant-a", "tenant-b"}, + {"", "tenant-b"}, + {"tenant-a", ""}, + {"", ""}, + } + + for _, values := range testCases { + got := reverseProxySelectHRWValues(candidates, values) + want := reverseProxySelectHRW(candidates, strings.Join(values, ",")) + if got == nil || want == nil { + t.Fatalf("expected non-nil upstreams for values %v", values) + } + if got.key != want.key { + t.Fatalf("expected joined compatibility for values %v, got %q want %q", values, got.key, want.key) + } + } +} + var _ io.Writer = (*benchmarkResponseWriter)(nil) diff --git a/reverseproxy_lb.go b/reverseproxy_lb.go index 02895fb..ce5e949 100644 --- a/reverseproxy_lb.go +++ b/reverseproxy_lb.go @@ -10,6 +10,7 @@ import ( "net/http" "net/textproto" "net/url" + "slices" "strings" "sync" "sync/atomic" @@ -199,7 +200,7 @@ func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates [] case reverseProxyLBPolicyHeader: if c.Request != nil && c.Request.Header != nil { if values, ok := c.Request.Header[policy.key]; ok { - return reverseProxySelectHRW(candidates, strings.Join(values, ",")) + return reverseProxySelectHRWValues(candidates, values) } } return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy)) @@ -277,6 +278,25 @@ func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reve return selected } +func reverseProxySelectHRWValues(candidates []*reverseProxyUpstream, values []string) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + if len(values) == 0 { + return reverseProxySelectRandom(candidates) + } + selected := candidates[0] + bestScore := reverseProxyHRWValuesScore(values, selected.key) + for _, upstream := range candidates[1:] { + score := reverseProxyHRWValuesScore(values, upstream.key) + if score > bestScore { + selected = upstream + bestScore = score + } + } + return selected +} + func reverseProxyHRWScore(key, upstreamKey string) uint64 { const ( offset64 = 14695981039346656037 @@ -296,6 +316,31 @@ func reverseProxyHRWScore(key, upstreamKey string) uint64 { return h } +func reverseProxyHRWValuesScore(values []string, upstreamKey string) uint64 { + const ( + offset64 = 14695981039346656037 + prime64 = 1099511628211 + ) + h := uint64(offset64) + for valueIndex, value := range values { + for i := 0; i < len(value); i++ { + h ^= uint64(value[i]) + h *= prime64 + } + if valueIndex+1 < len(values) { + h ^= ',' + h *= prime64 + } + } + h ^= 0xff + h *= prime64 + for i := 0; i < len(upstreamKey); i++ { + h ^= uint64(upstreamKey[i]) + h *= prime64 + } + return h +} + func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy { if policy.fallback != nil { return *policy.fallback @@ -360,10 +405,5 @@ func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, statu if status <= 0 { return false } - for _, unhealthyStatus := range config.UnhealthyStatus { - if status == unhealthyStatus { - return true - } - } - return false + return slices.Contains(config.UnhealthyStatus, status) } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 9cbc734..6863da7 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -1866,7 +1866,9 @@ func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) if message != "echo:ping\n" { t.Fatalf("unexpected tunneled response body: %q", message) } - _ = pw.Close() + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } select { case err := <-errCh: @@ -2215,7 +2217,9 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi time.Sleep(50 * time.Millisecond) cancel() - _ = pw.CloseWithError(context.Canceled) + if err := pw.CloseWithError(context.Canceled); err != nil { + t.Fatalf("close request body with cancellation: %v", err) + } select { case <-writeErrCh: case <-time.After(2 * time.Second): diff --git a/serve_test.go b/serve_test.go index c717653..a02f1df 100644 --- a/serve_test.go +++ b/serve_test.go @@ -182,8 +182,7 @@ func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) { func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) { cfg := defaultRunConfig() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() if err := WithShutdownContext(ctx).apply(&cfg); err != nil { t.Fatalf("apply shutdown context option: %v", err) }