Compare commits

..

92 commits

Author SHA1 Message Date
WJQSERVER
d439662adf
Merge pull request #94 from infinite-iroha/dependabot/go_modules/github.com/WJQSERVER-STUDIO/httpc-0.9.3
Some checks are pending
Go Test / test (push) Waiting to run
build(deps): bump github.com/WJQSERVER-STUDIO/httpc from 0.9.2 to 0.9.3
2026-05-04 23:13:25 +08:00
dependabot[bot]
810ba788ae
build(deps): bump github.com/WJQSERVER-STUDIO/httpc from 0.9.2 to 0.9.3
Bumps [github.com/WJQSERVER-STUDIO/httpc](https://github.com/WJQSERVER-STUDIO/httpc) from 0.9.2 to 0.9.3.
- [Release notes](https://github.com/WJQSERVER-STUDIO/httpc/releases)
- [Commits](https://github.com/WJQSERVER-STUDIO/httpc/compare/v0.9.2...v0.9.3)

---
updated-dependencies:
- dependency-name: github.com/WJQSERVER-STUDIO/httpc
  dependency-version: 0.9.3
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-05-04 00:41:08 +00:00
WJQSERVER
de0e16852f
Merge pull request #92 from infinite-iroha/dependabot/go_modules/golang.org/x/net-0.53.0
build(deps): bump golang.org/x/net from 0.52.0 to 0.53.0
2026-04-24 10:15:30 +08:00
WJQSERVER
8ec77ecc9f
Merge pull request #93 from infinite-iroha/dependabot/go_modules/github.com/WJQSERVER-STUDIO/httpc-0.9.2
build(deps): bump github.com/WJQSERVER-STUDIO/httpc from 0.9.0 to 0.9.2
2026-04-24 10:15:18 +08:00
dependabot[bot]
b3b82b3c61
build(deps): bump github.com/WJQSERVER-STUDIO/httpc from 0.9.0 to 0.9.2
Bumps [github.com/WJQSERVER-STUDIO/httpc](https://github.com/WJQSERVER-STUDIO/httpc) from 0.9.0 to 0.9.2.
- [Release notes](https://github.com/WJQSERVER-STUDIO/httpc/releases)
- [Commits](https://github.com/WJQSERVER-STUDIO/httpc/compare/v0.9.0...v0.9.2)

---
updated-dependencies:
- dependency-name: github.com/WJQSERVER-STUDIO/httpc
  dependency-version: 0.9.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-24 00:34:29 +00:00
dependabot[bot]
52db699db9
build(deps): bump golang.org/x/net from 0.52.0 to 0.53.0
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.52.0 to 0.53.0.
- [Commits](https://github.com/golang/net/compare/v0.52.0...v0.53.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-version: 0.53.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-23 00:34:06 +00:00
WJQSERVER
43fede96d5
Merge pull request #72 from infinite-iroha/break/v1
v1 break
2026-04-22 14:03:05 +08:00
WJQSERVER
01395dc942
Merge pull request #91 from infinite-iroha/feat/httpc-context-integration
Some checks are pending
Go Test / test (push) Waiting to run
feat: httpc 集成、MergeCtx cause 传播
2026-04-22 09:43:18 +08:00
wjqserver
3c40a3d6b5 fix: 修正 GetHTTPC 注释中方法名 typo
HTTPClient() → HTTPC()

Alina Agent生成
2026-04-22 09:37:45 +08:00
wjqserver
9dcab4b1ae fix: orDone 使用 sync.Once 修复 close(done) 竞态条件
修复 Gemini 审查意见:多 goroutine 同时 close(done) 可能导致 panic。
恢复 sync.Once 保证 channel 只被关闭一次。

Alina Agent生成
2026-04-22 09:37:19 +08:00
wjqserver
2d693e3b13 refactor: mergectx 简化结构,修复 Gemini 审查意见
- deadlineCtx 改为 cancelCtx 的子 context,建立父子层级关系
- 嵌入 cancelCtx/context.Context 直接提供 Done()/Err()/Deadline(),移除冗余方法
- orDone 中加入 cancelCtx,防止手动 cancel() 时 goroutine 泄漏
- 移除 cancelCtx/deadlineCtx/done/doneOnce 字段,struct 简化为 Context + parents
- 移除冗余 Cause() 方法(context.Cause 用 Value(&cancelCtxKey) 机制)
- 移除 Done()/Err() 显式实现,由嵌入 context 自动提供

Alina Agent生成
2026-04-22 09:27:53 +08:00
wjqserver
d8a5f200c1 fix: Client()/HTTPC() 优先使用 per-request HTTPClient 字段
修复 Gemini 审查意见:中间件设置的自定义 HTTPClient 不再被绕过。
Client() 和 HTTPC() 现在优先使用 Context.HTTPClient,
仅在未设置时回退到 Engine 默认实例。

Alina Agent生成
2026-04-22 09:17:02 +08:00
wjqserver
6006267d25 fix: Done() 使用 sync.Once 缓存 channel,避免重复创建 orDone goroutine
修复 Gemini 审查意见:多次调用 Done() 时不再重复创建 goroutine,
每个 mergedContext 最多产生 2 个 orDone goroutine。

Alina Agent生成
2026-04-22 09:00:01 +08:00
wjqserver
390190695f fix: 修复 examples/httpc 中 c.String 非常量 format string 编译错误 2026-04-22 08:51:42 +08:00
wjqserver
7487369125 improve: MergeCtx 支持 cause 传播, 使用 WithCancelCause/WithDeadlineCause
- 内部改用 context.WithCancelCause 和 WithDeadlineCause, 父 context 取消原因自动传播
- Value() 先检查嵌入 context 再查 parents, 确保 context.Cause() 正确工作
- Done()/Err() 同时监听 cancelCtx 和 deadlineCtx, 支持 deadline 到期 cause
- 新增 Cause() 便捷方法
- 单 parent 短路径改用 WithCancelCause 保留 cause
- 新增 mergectx_test.go, 覆盖 cause 传播、deadline、Value 查找等场景
- API 兼容: 返回类型保持 CancelFunc 不变

Alina Agent生成
2026-04-22 08:43:36 +08:00
wjqserver
e7c7d5e41f fix: 修复 Client() 返回过时 HTTPClient 的问题
- 将 Client() 从返回 c.HTTPClient 改为返回 c.engine.HTTPClient
- 与 HTTPC() 方法保持一致
2026-04-22 07:30:40 +08:00
wjqserver
4f262b2497 docs: 添加 httpc 集成文档和示例
- 新增 examples/httpc 示例代码
- 新增 docs/httpc.md 文档说明
2026-04-22 07:13:55 +08:00
wjqserver
f2295c3084 feat: httpc 集成改进,自动关联请求 Context
- 新增 contextHTTPClient 包装器,自动关联请求 Context
- 新增 Context.HTTPC() 方法返回 contextHTTPClient
- Client() 标记为 Deprecated
- 添加 GetHTTPC() go:fix inline 兼容函数

当请求被取消时,出站 HTTP 请求也会自动取消。
2026-04-21 22:55:26 +08:00
WJQSERVER
b83e536def
Merge pull request #90 from infinite-iroha/feat/logger-interface
feat: 引入 Logger 接口抽象
2026-04-21 22:28:31 +08:00
wjqserver
10033f4a17 docs: 修复审查意见,修正设计文档与实现的不一致
- 将设计文档中 logReco 改为 LogReco,与实际实现保持一致
- LogReco 字段保持公开但标记为 Deprecated
2026-04-21 21:49:42 +08:00
wjqserver
c8b14ef43a feat: 引入 Logger 接口抽象,支持自定义日志实现
- 新增 Logger 接口定义,支持 zap/slog 等自定义实现
- 新增 CloserLogger 接口用于支持关闭操作
- Engine 新增 SetLogger/GetLogger 方法使用接口
- 新增 compat.go 兼容层,保留 reco 兼容方法
- 新增 slog 适配器示例
- 删除 zap 示例
- Context.GetLogger() 返回接口类型
2026-04-21 19:43:56 +08:00
WJQSERVER
2581697771
Merge pull request #89 from infinite-iroha/docs/add-middleware-examples
docs: 补充中间件文档
2026-04-21 18:34:34 +08:00
wjqserver
58fd877ae2 docs: 修复审查意见,统一术语并补充注册顺序说明
- 补充中间件注册顺序说明(必须在路由定义之前)
- 统一术语:'组中间件' → '路由组中间件'
- 统一流程图术语
2026-04-21 18:32:10 +08:00
wjqserver
fce12ee7e7 docs: 补充中间件文档,添加路由级中间件和执行顺序说明
- 添加路由级中间件使用示例
- 说明在创建组时直接传入中间件的方法
- 添加中间件执行顺序章节,清晰展示全局/组/路由中间件的执行流程
2026-04-21 18:19:44 +08:00
WJQSERVER
d9328c3176
Merge pull request #87 from infinite-iroha/feat/headers-ops-v1
feat: 反向代理头部操作功能 (Headers Operations)
2026-04-21 18:16:50 +08:00
WJQSERVER
8fdb16ae1e
Merge pull request #88 from infinite-iroha/feat/replacer-dynamic-vars
feat: 实现动态请求变量替换
2026-04-21 18:14:38 +08:00
wjqserver
1243d2d37a fix: address PR review for replacer — nil check, EscapedPath, scheme reuse, perf
- add req.URL nil guard
- use EscapedPath for {path} to avoid illegal header characters
- reuse reverseProxyRequestScheme for {scheme} consistency
- replace strings.NewReplacer with struct fields + strings.ReplaceAll
2026-04-21 18:02:57 +08:00
wjqserver
fa925582d7 feat: implement dynamic request variable replacement in replacer
Replace the no-op reverseProxyReplacer.Replace with strings.NewReplacer
supporting {method}, {host}, {path}, {query}, {scheme}, {uri}, {proto}
2026-04-21 17:36:38 +08:00
wjqserver
5d9bb3187d perf: optimize wildcard header deletion; test: assert invalid regex returns 500
- refactor Delete logic to iterate headers once, reducing ToLower calls
  from O(patterns * headers) to O(headers)
- rewrite invalid regex test to verify runtime 500 response
2026-04-21 17:20:30 +08:00
wjqserver
c0e31c449e fix: address PR review comments for header ops
- fix Deferred response header logic: apply headers after ModifyResponse callback
- refactor applyToRequest to eliminate code duplication via applyTo
- remove redundant Sec-WebSocket-Accept condition check
2026-04-21 16:58:14 +08:00
wjqserver
93f5edc6eb feat: add Replace support for reverse proxy header ops
- Support substring replacement via Search field
- Support regex replacement via SearchRegexp field (precompiled at Provision)
- Support wildcard field name '*' to apply replacement to all headers
- Validate that Search and SearchRegexp are mutually exclusive
- Add 5 functional tests and 9 benchmark tests covering all operations

Benchmark results (no external allocs in hot paths):
  Add:              527 ns/op, 448 B/op,  5 allocs/op
  Set:              891 ns/op, 480 B/op,  7 allocs/op
  Delete(single):   476 ns/op,  48 B/op,  3 allocs/op
  Delete(wildcard): 1073 ns/op, 104 B/op,  7 allocs/op
  Replace(sub):     303 ns/op,  64 B/op,  2 allocs/op
  Replace(regex):  1503 ns/op, 224 B/op,  6 allocs/op
  Replace(wild):    731 ns/op,  80 B/op,  4 allocs/op
  Mixed:           1527 ns/op, 128 B/op,  7 allocs/op
2026-04-21 16:34:25 +08:00
wjqserver
06a6d42de1 feat: add headers operations for reverse proxy
- Add HeaderOps struct for Add/Set/Delete header operations
- Add RespHeaderOps for response header manipulation with deferred support
- Support wildcard patterns for header deletion (prefix-*, *suffix, *substring*)
- Apply request headers before forwarding to upstream
- Apply response headers before sending to client
- Add comprehensive test coverage for header operations

Usage example:
  engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
    Target: target,
    RequestHeaders: &HeaderOps{
      Add: map[string][]string{"X-Custom": {"value"}},
      Delete: []string{"X-Sensitive-*"},
    },
    ResponseHeaders: &RespHeaderOps{
      HeaderOps: &HeaderOps{
        Set: map[string][]string{"X-Frame-Options": {"DENY"}},
      },
    },
  }))
2026-04-21 16:34:25 +08:00
wjqserver
3b5f2c81af fix: optimize Sec-WebSocket-Accept header check
- Remove unused variable assignment in condition
- Direct comparison is more efficient (no extra variable allocation)
- Maintains same defensive check behavior
2026-04-21 16:34:25 +08:00
wjqserver
b008fc8e61 fix: only remove Sec-WebSocket-Accept if present in HTTP/2 Extended CONNECT
- Check if Sec-WebSocket-Accept header exists before deleting
- This prevents unnecessary header manipulation when backend doesn't send it
- Maintains compatibility with backends that may or may not include this header
2026-04-21 16:34:25 +08:00
WJQSERVER
0f7cf23abb
Merge pull request #86 from infinite-iroha/perf/go126-memory-pass
Perf/go126 memory pass
2026-04-21 16:29:12 +08:00
wjqserver
54f7de0c60 perf: modernize io paths and reduce proxy allocations 2026-04-11 01:43:34 +08:00
wjqserver
02861b5537 perf: avoid header policy join allocations 2026-04-10 21:55:21 +08:00
wjqserver
7c37d4c38c perf: fast-path default 404 and 405 responses 2026-04-10 21:44:31 +08:00
WJQSERVER
271e54eb4d
Merge pull request #84 from infinite-iroha/perf/go126-memory-pass
Perf/go126 memory pass
2026-04-10 07:21:40 +08:00
wjqserver
017bb13295 perf: reuse reverse proxy candidate slices 2026-04-10 06:18:52 +08:00
wjqserver
71a344a3de perf: reuse reverse proxy copy buffers 2026-04-10 06:08:55 +08:00
里見 灯花
efa1e3fb3f
Merge pull request #82 from infinite-iroha/break/v1-redesign-run-api
feat: redesign server startup around Run options
2026-04-07 20:54:47 +08:00
WJQSERVER
7cb777225f
Merge pull request #83 from infinite-iroha/break/v1-redirect-host-strategy
feat: add redirect host selection options
2026-04-07 20:50:09 +08:00
wjqserver
121679b44e fix: preserve IPv6 brackets in redirects
Re-wrap bare IPv6 hosts after stripping ports so HTTPS redirect URLs stay valid. Add a regression test covering bracketed IPv6 hosts in redirect responses.
2026-04-07 20:31:10 +08:00
wjqserver
9e57f5a5f5 fix: stop redirect siblings on shutdown
Make the non-graceful HTTPS redirect path shut down all sibling servers after any server returns, so cleanup stays consistent with the graceful path and partial shutdowns do not leave the redirect listener running.
2026-04-07 20:00:58 +08:00
wjqserver
e2cf08d5dd feat: add redirect host selection options
Support explicit redirect host source selection for HTTP-to-HTTPS redirects with ordered header lookup, fixed host mode, and strict validation. Document the new redirect option relationships and add focused tests for 426 fallback, conflict checks, and non-graceful startup errors.
2026-04-07 19:49:13 +08:00
wjqserver
e4d3eed379 feat: redesign server startup around Run options
Replace the old RunShutdown and RunTLS style entry points with a single Run(opts...) API for v1. Add focused startup semantics tests, keep TLS and graceful shutdown independent, ensure sibling servers are cleaned up on startup failure, and update docs to match the new option-based startup model.
2026-04-07 17:44:55 +08:00
WJQSERVER
fca9bbd3ef
Merge pull request #81 from infinite-iroha/feat/optimize-route-match-hotpath
Feat/optimize route match hotpath
2026-04-07 09:58:10 +08:00
wjqserver
987ea81329 fix: avoid fixed-path miss panic and trim 405 fallback work 2026-04-07 09:57:16 +08:00
wjqserver
fa027347d3 fix: reduce default error response overhead
Encode the built-in 404 and 405 payload with a fixed struct instead of a map so default error pages allocate less on the hot miss path. Add a regression test to keep the JSON shape stable.
2026-04-07 09:35:39 +08:00
wjqserver
57847fa446 fix: avoid unsafe header buffer reuse
Use safe string copies for pooled header buffers and simplify case-insensitive lookup buffering now that the pseudo stack path was ineffective. This addresses review concerns without changing the routing semantics.
2026-04-07 09:32:14 +08:00
wjqserver
2d4aefc86e fix: cut redirect and allow-path routing overhead
Reuse fixed-path and Allow-header buffers so redirect and OPTIONS handling stop rebuilding temporary data on every request. Cache fallback chains and add regression coverage for redirect, 404, 405, and Allow behavior to keep the faster miss paths stable.
2026-04-07 09:06:56 +08:00
wjqserver
5d979e5670 fix: reduce per-request context and fallback overhead
Make Context keys lazy so requests that never call Set stop allocating on reset. Reuse stable 404 and 405 handlers and add focused benchmarks so ServeHTTP miss paths stay measurable.
2026-04-07 08:39:10 +08:00
wjqserver
6acac9edce fix: streamline route matcher backtracking
Avoid rebuilding skipped-node state during wildcard fallback so the matcher no longer loops on the same static branch and stops allocating on the hot path. Add focused route benchmarks and regression coverage to keep the optimized path stable.
2026-04-07 08:27:00 +08:00
WJQSERVER
b1ce4d584e
Merge pull request #80 from infinite-iroha/fix/v1-runshutdown-http-only
fix: keep RunShutdown on HTTP path
2026-04-07 07:53:36 +08:00
wjqserver
7db3d32d7b test: improve serve startup failure diagnostics 2026-04-07 07:51:39 +08:00
wjqserver
d12e887858 fix: keep RunShutdown on HTTP path 2026-04-07 07:46:06 +08:00
WJQSERVER
7f69d5668e
Merge pull request #79 from infinite-iroha/fix/v1-findcaseinsensitivepath-wildchild-order
fix: avoid panic in case-insensitive wildcard lookup
2026-04-07 07:25:28 +08:00
wjqserver
70f8cc6159 fix: avoid panic in case-insensitive wildcard lookup 2026-04-07 07:19:33 +08:00
WJQSERVER
863f984990
Merge pull request #78 from infinite-iroha/break/v1-enhance-reverse-proxy
fix(reverseproxy): bridge websocket extended connect upstreams
2026-04-03 00:42:01 +08:00
wjqserver
1a6325d461 feat: improve reverse proxy tunnel management with sync.Once and better error handling 2026-04-03 00:29:15 +08:00
wjqserver
d53693952a refactor: improve TLS config handling and add bridge connection tests 2026-04-02 22:13:50 +08:00
wjqserver
dcdb1504a3 feat: add robust transport cloning and improve header handling in reverse proxy 2026-04-02 19:58:34 +08:00
wjqserver
20dc6e4047 refactor: cache ResponseController in H2ReadWriteCloser for better performance 2026-04-02 19:44:02 +08:00
wjqserver
7abedc1ace enhance: improve reverse proxy error handling and add tests 2026-04-02 19:33:18 +08:00
wjqserver
50c6a23614 refactor: simplify reverse proxy bridged connection handling by removing unused bufio 2026-04-02 18:50:27 +08:00
wjqserver
a9c1662333 fix(reverseproxy): bridge websocket extended connect upstreams 2026-04-02 18:19:41 +08:00
WJQSERVER
0d7721a24c
Merge pull request #77 from infinite-iroha/break/v1-enhance-reverse-proxy
feat(reverseproxy): add upstream balancing and protocol improvements
2026-04-02 15:32:41 +08:00
wjqserver
919236665b feat(reverseproxy): add upstream balancing and failover 2026-04-02 14:40:56 +08:00
wjqserver
59f190ce3a fix(http2): preserve extended CONNECT tunnel shutdown semantics 2026-04-02 04:09:43 +08:00
wjqserver
2165cc4114 feat(http2): support OPTIONS * and extended CONNECT 2026-04-02 03:53:17 +08:00
wjqserver
ed44c592d3 fix(reverseproxy): align forwarding and tunnel semantics 2026-04-02 03:18:49 +08:00
WJQSERVER
c019f24e99
Merge pull request #76 from infinite-iroha/break/v1-fix-filetext-bodylimit
Break/v1 fix filetext bodylimit
2026-04-01 00:09:30 +08:00
wjqserver
e6ff0fa6b9 fix(maxreader): treat non-positive limits as unlimited 2026-04-01 00:03:23 +08:00
wjqserver
91c50536c4 fix(maxreader): avoid hangs after reaching body limit 2026-03-31 23:37:02 +08:00
wjqserver
85cc9b5cf6 fix(form): align PostForm parsing with body limit handling 2026-03-31 18:59:32 +08:00
wjqserver
64e2ad9e7b Fix FileText status code and unify request body size limits
- FileText: now respects the provided status code instead of defaulting to 200 OK
- Request body limits: prepareRequestBody() is now only called when MaxRequestBodySize > 0
  - ShouldBindJSON, ShouldBindWANF, ShouldBindGOB, ShouldBindForm, GetReqBody, PostForm
    all now use the original c.Request.Body path when no limit is configured
- maxBytesReader: fixed exact-limit boundary case where body size == limit was
  incorrectly rejected
- Added regression tests for FileText status codes and body limit behavior

All existing tests pass, and new tests verify the corrected behavior.
2026-03-31 16:38:04 +08:00
WJQSERVER
ef965f4a6a
Merge pull request #75 from infinite-iroha/break/v1-fix-mergectx
Some checks are pending
Go Test / test (push) Waiting to run
fix: mergedContext.Value 遍历父 contexts 查找值
2026-03-30 16:45:09 +08:00
wjqserver
d90d043811 fix: mergedContext.Value 遍历父 contexts 查找值 2026-03-30 02:21:11 +08:00
WJQSERVER
8dc7d8c136
Merge pull request #74 from infinite-iroha/break/v1-feat-add-samesite
Some checks are pending
Go Test / test (push) Waiting to run
feat(cookie): add SameSite support to SetCookie method
2026-03-30 01:50:43 +08:00
wjqserver
9f210deadf fix(cookie): add warning log when multiple SameSite values provided 2026-03-30 01:42:10 +08:00
wjqserver
7be49b96c8 feat(cookie): add SameSite support to SetCookie method 2026-03-30 01:33:00 +08:00
WJQSERVER
3aa84f5dcf
Merge pull request #73 from infinite-iroha/break/v1-feat-add-buf-methods
feat(render): add buffered variants for JSON/GOB/WANF/HTML
2026-03-30 01:22:21 +08:00
WJQSERVER
fba6fedfc5
Update context.go
Some checks failed
Go Test / test (push) Has been cancelled
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-30 01:17:59 +08:00
WJQSERVER
d0fa14c3c5
Update context.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-30 01:17:32 +08:00
wjqserver
45c6d36748 fix(HTMLBuf): return 500 on template error, no content
- Remove fallback to HTML() on template rendering failure
- Return 500 error without writing any content on error
- Only fallback to HTML() when renderer is nil or unsupported type
- Prevents multiple response writes
2026-03-30 01:02:37 +08:00
wjqserver
b4e45610b2 refactor(HTMLBuf): delegate fallback to HTML() method
Reduce code duplication by calling c.HTML() for fallback cases:
- When template rendering fails
- When HTMLRender is not configured
- When HTMLRender is not a *template.Template

This ensures consistent behavior between HTMLBuf and HTML methods.
2026-03-30 00:29:50 +08:00
wjqserver
b09595e745 fix: address PR #73 review feedback
- Remove redundant c.Errorf call in JSONBuf
- Consolidate error wrapping in HTMLBuf to avoid duplicate fmt.Errorf calls
- Keep error handling consistent across all Buf methods
2026-03-29 23:43:29 +08:00
wjqserver
6e33bc48aa fix: simplify error handling in Buf methods
Consolidate error wrapping to avoid redundant fmt.Errorf calls.
Follows PR #73 review feedback.
2026-03-29 18:45:08 +08:00
wjqserver
7e15181c0b feat(render): add Buf variants for JSON/GOB/WANF/HTML
Add buffered rendering methods that encode to a buffer first, then
write the response. This allows returning a proper 500 status code
if encoding fails, unlike the streaming variants which must write
the status code before encoding (an inherent HTTP constraint).

New methods:
- JSONBuf(code int, obj any)
- GOBBuf(code int, obj any)
- WANFBuf(code int, obj any)
- HTMLBuf(code int, name string, obj any)

Trade-off: one extra memory allocation per call in exchange for
correct error status codes on encoding failure.
2026-03-29 17:03:57 +08:00
wjqserver
559aefeb85 fix(SSE): capture Writer before goroutine, use select for channel send
Address PR review feedback:
- Capture w := c.Writer before goroutine start, use w (not c.Writer)
  inside the goroutine to avoid holding *Context reference
- Move channel send into select alongside context cancellation in all
  examples and tests, preventing goroutine leak when client disconnects
  while blocked on unbuffered send
2026-03-29 16:50:37 +08:00
wjqserver
2f94763c65 fix(SSE)!: redesign EventStreamChan to prevent context pool recycling
BREAKING CHANGE: EventStreamChan signature changed from
  (chan<- Event, <-chan error)
to
  (eventChan <-chan Event)
The caller now creates and passes the channel instead of receiving it.
The errChan return value is removed.

The old non-blocking design allowed the handler to return before the SSE
stream ended, causing ServeHTTP to return the Context to the pool while
the internal goroutine was still writing to the pooled writer — a data
race across requests. The new blocking design keeps the handler inside
EventStreamChan until the event channel is closed or the client
disconnects, ensuring the Context remains bound throughout the stream.

- Caller creates channel, producer goroutine sends events
- EventStreamChan blocks handler until stream ends
- Internal goroutine captures stable references (Flusher, context.Context)
  instead of holding *Context pointer
- Nil guard on Flusher type assertion
- Add sse_test.go covering blocking, disconnect, and event format
- Update docs/sse.md for new API
2026-03-29 15:42:01 +08:00
50 changed files with 9257 additions and 798 deletions

1
.gitignore vendored
View file

@ -1 +1,2 @@
test test
/bench_route_match_baseline.txt

View file

@ -59,9 +59,9 @@ func main() {
c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query) c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query)
}) })
// 启动服务器 (支持优雅关闭) // 启动服务器(通过 WithGracefulShutdown 启用优雅关闭)
log.Println("Touka Server starting on :8080...") log.Println("Touka Server starting on :8080...")
if err := r.RunShutdown(":8080", 10*time.Second); err != nil { if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
log.Fatalf("Touka server failed to start: %v", err) log.Fatalf("Touka server failed to start: %v", err)
} }
} }

View file

@ -70,13 +70,13 @@ func main() {
r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB
// ... 其他配置 // ... 其他配置
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
``` ```
#### 1.3. 服务器生命周期管理 #### 1.3. 服务器生命周期管理
Touka 提供了对底层 `*http.Server` 的完全控制,并内置了优雅关闭的逻辑。 Touka 提供了对底层 `*http.Server` 的完全控制,并可通过 `Run(...)` 的启动选项启用优雅关闭逻辑。
```go ```go
func main() { func main() {
@ -90,11 +90,11 @@ func main() {
fmt.Println("自定义的 HTTP 服务器配置已应用") fmt.Println("自定义的 HTTP 服务器配置已应用")
}) })
// 启动服务器,并支持优雅关闭 // 启动服务器,并通过 Run 选项启用优雅关闭
// RunShutdown 会阻塞,直到收到 SIGINT 或 SIGTERM 信号 // Run(...) 会阻塞当前 goroutine
// 第二个参数是优雅关闭的超时时间 // WithGracefulShutdown(10*time.Second) 表示在关闭时最多等待 10 秒
fmt.Println("服务器启动于 :8080") fmt.Println("服务器启动于 :8080")
if err := r.RunShutdown(":8080", 10*time.Second); err != nil { if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
log.Fatalf("服务器启动失败: %v", err) log.Fatalf("服务器启动失败: %v", err)
} }
} }
@ -187,7 +187,7 @@ func main() {
} }
} }
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
func AuthMiddleware() touka.HandlerFunc { func AuthMiddleware() touka.HandlerFunc {
@ -313,7 +313,7 @@ func main() {
}) })
}) })
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
// templates/index.html // templates/index.html
@ -400,7 +400,7 @@ func main() {
c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID}) c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID})
}) })
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
``` ```
@ -483,7 +483,7 @@ func main() {
// 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获 // 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获
r.StaticDir("/files", "./non-existent-dir") r.StaticDir("/files", "./non-existent-dir")
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
``` ```
@ -546,7 +546,7 @@ func main() {
// 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录 // 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录
r.StaticFS("/", http.FS(subFS)) r.StaticFS("/", http.FS(subFS))
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
``` ```

52
compat.go Normal file
View file

@ -0,0 +1,52 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2024 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"github.com/WJQSERVER-STUDIO/httpc"
"github.com/fenthope/reco"
)
// --- reco 兼容函数 ---
// GetLogReco 返回底层的 reco.Logger 实例
// 用于需要访问 reco 特定功能的场景
// 如果当前 logger 不是 *reco.Logger 类型,返回 nil
//
//go:fix inline
func (engine *Engine) GetLogReco() *reco.Logger {
return engine.LogReco
}
// SetLogReco 设置 reco.Logger 实例
// 用于向后兼容,等价于 SetLogger(l)
//
//go:fix inline
func (engine *Engine) SetLogReco(l *reco.Logger) {
engine.LogReco = l
engine.logger = l
}
// GetLoggerReco 返回底层的 reco.Logger 实例
// 用于需要访问 reco 特定功能的场景
// 如果当前 logger 不是 *reco.Logger 类型,返回 nil
//
//go:fix inline
func (c *Context) GetLoggerReco() *reco.Logger {
if rl, ok := c.engine.logger.(*reco.Logger); ok {
return rl
}
return c.engine.LogReco
}
// --- httpc 兼容函数 ---
// GetHTTPC 返回底层的 httpc.Client 实例
// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context
//
//go:fix inline
func (c *Context) GetHTTPC() *httpc.Client {
return c.Client()
}

View file

@ -26,7 +26,6 @@ import (
"time" "time"
"github.com/WJQSERVER/wanf" "github.com/WJQSERVER/wanf"
"github.com/fenthope/reco"
"github.com/go-json-experiment/json" "github.com/go-json-experiment/json"
"github.com/WJQSERVER-STUDIO/go-utils/iox" "github.com/WJQSERVER-STUDIO/go-utils/iox"
@ -44,6 +43,8 @@ type Context struct {
handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler) handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler)
index int8 // 当前执行到处理链的哪个位置 index int8 // 当前执行到处理链的哪个位置
requestBodyPrepared bool
mu sync.RWMutex mu sync.RWMutex
Keys map[string]any // 用于在中间件之间传递数据 Keys map[string]any // 用于在中间件之间传递数据
@ -71,6 +72,12 @@ type Context struct {
// skippedNodes 用于记录跳过的节点信息,以便回溯 // skippedNodes 用于记录跳过的节点信息,以便回溯
// 通常在处理嵌套路由时使用 // 通常在处理嵌套路由时使用
SkippedNodes []skippedNode SkippedNodes []skippedNode
// fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲.
fixedPathBuf []byte
allowedMethodsBuf []string
allowHeaderBuf []byte
} }
// --- Context 相关方法实现 --- // --- Context 相关方法实现 ---
@ -95,19 +102,42 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
} }
c.handlers = nil c.handlers = nil
c.index = -1 // 初始为 -1`Next()` 将其设置为 0 c.index = -1 // 初始为 -1`Next()` 将其设置为 0
c.Keys = make(map[string]any) // 每次请求重新创建 map避免数据污染 c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map
c.Errors = c.Errors[:0] // 清空 Errors 切片 c.Errors = c.Errors[:0] // 清空 Errors 切片
c.queryCache = nil // 清空查询参数缓存 c.queryCache = nil // 清空查询参数缓存
c.formCache = nil // 清空表单数据缓存 c.formCache = nil // 清空表单数据缓存
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
c.requestBodyPrepared = false
if cap(c.SkippedNodes) > 0 { if cap(c.SkippedNodes) > 0 {
c.SkippedNodes = c.SkippedNodes[:0] c.SkippedNodes = c.SkippedNodes[:0]
} else { } else {
c.SkippedNodes = make([]skippedNode, 0, 256) c.SkippedNodes = make([]skippedNode, 0, 256)
} }
if cap(c.fixedPathBuf) > 0 {
c.fixedPathBuf = c.fixedPathBuf[:0]
}
if cap(c.allowedMethodsBuf) > 0 {
c.allowedMethodsBuf = c.allowedMethodsBuf[:0]
}
if cap(c.allowHeaderBuf) > 0 {
c.allowHeaderBuf = c.allowHeaderBuf[:0]
}
}
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
if len(data) == 0 {
return
}
if _, err := c.Writer.Write(data); err != nil {
wrapped := fmt.Errorf("%s: %w", contextMsg, err)
c.AddError(wrapped)
if c.engine != nil && c.engine.logger != nil {
c.engine.logger.Errorf("%s: %v", contextMsg, err)
}
}
} }
// Next 在处理链中执行下一个处理函数 // Next 在处理链中执行下一个处理函数
@ -237,6 +267,18 @@ func (c *Context) SetMaxRequestBodySize(size int64) {
c.MaxRequestBodySize = size c.MaxRequestBodySize = size
} }
func (c *Context) prepareRequestBody() io.ReadCloser {
if c.Request == nil || c.Request.Body == nil {
return nil
}
if c.requestBodyPrepared || c.MaxRequestBodySize <= 0 {
return c.Request.Body
}
c.Request.Body = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
c.requestBodyPrepared = true
return c.Request.Body
}
// Query 从 URL 查询参数中获取值 // Query 从 URL 查询参数中获取值
// 懒加载解析查询参数,并进行缓存 // 懒加载解析查询参数,并进行缓存
func (c *Context) Query(key string) string { func (c *Context) Query(key string) string {
@ -258,7 +300,39 @@ func (c *Context) DefaultQuery(key, defaultValue string) string {
// 懒加载解析表单数据,并进行缓存 // 懒加载解析表单数据,并进行缓存
func (c *Context) PostForm(key string) string { func (c *Context) PostForm(key string) string {
if c.formCache == nil { if c.formCache == nil {
c.Request.ParseMultipartForm(defaultMemory) // 解析 multipart/form-data 或 application/x-www-form-urlencoded if c.MaxRequestBodySize > 0 {
c.prepareRequestBody()
}
contentType := c.Request.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil {
c.AddError(fmt.Errorf("parse form error: %w", err))
c.formCache = make(url.Values)
return ""
}
switch mediaType {
case "multipart/form-data":
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
c.AddError(fmt.Errorf("parse form error: %w", err))
c.formCache = make(url.Values)
return ""
}
case "application/x-www-form-urlencoded":
if err := c.Request.ParseForm(); err != nil {
c.AddError(fmt.Errorf("parse form error: %w", err))
c.formCache = make(url.Values)
return ""
}
default:
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
if !errors.Is(err, http.ErrNotMultipart) {
c.AddError(fmt.Errorf("parse form error: %w", err))
c.formCache = make(url.Values)
return ""
}
}
}
c.formCache = c.Request.PostForm c.formCache = c.Request.PostForm
} }
return c.formCache.Get(key) return c.formCache.Get(key)
@ -282,20 +356,20 @@ func (c *Context) Param(key string) string {
func (c *Context) Raw(code int, contentType string, data []byte) { func (c *Context) Raw(code int, contentType string, data []byte) {
c.Writer.Header().Set("Content-Type", contentType) c.Writer.Header().Set("Content-Type", contentType)
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.Writer.Write(data) c.writeResponseBody(data, "failed to write raw response")
} }
// String 向响应写入格式化的字符串 // String 向响应写入格式化的字符串
func (c *Context) String(code int, format string, values ...any) { func (c *Context) String(code int, format string, values ...any) {
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.Writer.Write(fmt.Appendf(nil, format, values...)) c.writeResponseBody(fmt.Appendf(nil, format, values...), "failed to write string response")
} }
// Text 向响应写入无需格式化的string // Text 向响应写入无需格式化的string
func (c *Context) Text(code int, text string) { func (c *Context) Text(code int, text string) {
c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8") c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.Writer.Write([]byte(text)) c.writeResponseBody([]byte(text), "failed to write text response")
} }
// FileText // FileText
@ -338,8 +412,11 @@ func (c *Context) FileText(code int, filePath string) {
} }
c.SetHeader("Content-Type", "text/plain; charset=utf-8") c.SetHeader("Content-Type", "text/plain; charset=utf-8")
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size()))
c.SetBodyStream(file, int(fileInfo.Size())) c.Writer.WriteHeader(code)
if _, err := iox.Copy(c.Writer, file); err != nil {
c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err))
}
} }
/* /*
@ -417,6 +494,22 @@ func (c *Context) JSON(code int, obj any) {
} }
} }
// JSONBuf 先将 JSON 编码到 buffer, 成功后再写入状态码和响应体.
// 与 JSON 相比,编码失败时可以正确返回 500 状态码,代价是多一次内存分配.
func (c *Context) JSONBuf(code int, obj any) {
var buf bytes.Buffer
if err := json.MarshalWrite(&buf, obj); err != nil {
errMsg := fmt.Errorf("failed to marshal JSON: %w", err)
c.AddError(errMsg)
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
return
}
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
c.Writer.WriteHeader(code)
c.writeResponseBody(buf.Bytes(), "failed to write buffered JSON response")
}
// GOB 向响应写入GOB数据 // GOB 向响应写入GOB数据
// 设置 Content-Type 为 application/octet-stream // 设置 Content-Type 为 application/octet-stream
func (c *Context) GOB(code int, obj any) { func (c *Context) GOB(code int, obj any) {
@ -431,6 +524,21 @@ func (c *Context) GOB(code int, obj any) {
} }
} }
// GOBBuf 先将 GOB 编码到 buffer, 成功后再写入状态码和响应体.
func (c *Context) GOBBuf(code int, obj any) {
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
if err := encoder.Encode(obj); err != nil {
errMsg := fmt.Errorf("failed to encode GOB: %w", err)
c.AddError(errMsg)
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
return
}
c.Writer.Header().Set("Content-Type", "application/octet-stream")
c.Writer.WriteHeader(code)
c.writeResponseBody(buf.Bytes(), "failed to write buffered GOB response")
}
// WANF向响应写入WANF数据 // WANF向响应写入WANF数据
// 设置 application/vnd.wjqserver.wanf; charset=utf-8 // 设置 application/vnd.wjqserver.wanf; charset=utf-8
func (c *Context) WANF(code int, obj any) { func (c *Context) WANF(code int, obj any) {
@ -445,6 +553,21 @@ func (c *Context) WANF(code int, obj any) {
} }
} }
// WANFBuf 先将 WANF 编码到 buffer, 成功后再写入状态码和响应体.
func (c *Context) WANFBuf(code int, obj any) {
var buf bytes.Buffer
encoder := wanf.NewStreamEncoder(&buf)
if err := encoder.Encode(obj); err != nil {
errMsg := fmt.Errorf("failed to encode WANF: %w", err)
c.AddError(errMsg)
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
return
}
c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8")
c.Writer.WriteHeader(code)
c.writeResponseBody(buf.Bytes(), "failed to write buffered WANF response")
}
// HTML 渲染 HTML 模板 // HTML 渲染 HTML 模板
// 如果 Engine 配置了 HTMLRender则使用它进行渲染 // 如果 Engine 配置了 HTMLRender则使用它进行渲染
// 否则,会进行简单的字符串输出 // 否则,会进行简单的字符串输出
@ -466,7 +589,37 @@ func (c *Context) HTML(code int, name string, obj any) {
// 可以扩展支持其他渲染器接口 // 可以扩展支持其他渲染器接口
} }
// 默认简单输出,用于未配置 HTMLRender 的情况 // 默认简单输出,用于未配置 HTMLRender 的情况
c.Writer.Write(fmt.Appendf(nil, "<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj)) c.writeResponseBody(fmt.Appendf(nil, "<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj), "failed to write HTML response")
}
// HTMLBuf 先将 HTML 模板渲染到 buffer, 成功后再写入状态码和响应体.
// 如果模板渲染失败,则返回 500 错误且不写入任何内容.
func (c *Context) HTMLBuf(code int, name string, obj any) {
if c.engine == nil || c.engine.HTMLRender == nil {
// 没有渲染器,回退到简单输出
c.HTML(code, name, obj)
return
}
if tpl, ok := c.engine.HTMLRender.(*template.Template); ok {
var buf bytes.Buffer
err := tpl.ExecuteTemplate(&buf, name, obj)
if err != nil {
// 渲染失败,记录错误并返回 500不写入任何内容
errMsg := fmt.Errorf("failed to render HTML template '%s': %w", name, err)
c.AddError(errMsg)
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
return
}
// 渲染成功,写入响应
c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8")
c.Writer.WriteHeader(code)
c.writeResponseBody(buf.Bytes(), "failed to write buffered HTML response")
return
}
// 不支持的渲染器类型,回退到简单输出
c.HTML(code, name, obj)
} }
// Redirect 执行 HTTP 重定向 // Redirect 执行 HTTP 重定向
@ -481,10 +634,16 @@ func (c *Context) Redirect(code int, location string) {
// ShouldBindJSON 尝试将请求体绑定到 JSON 对象 // ShouldBindJSON 尝试将请求体绑定到 JSON 对象
func (c *Context) ShouldBindJSON(obj any) error { func (c *Context) ShouldBindJSON(obj any) error {
if c.Request.Body == nil { var body io.ReadCloser
if c.MaxRequestBodySize > 0 {
body = c.prepareRequestBody()
} else {
body = c.Request.Body
}
if body == nil {
return errors.New("request body is empty") return errors.New("request body is empty")
} }
err := json.UnmarshalRead(c.Request.Body, obj) err := json.UnmarshalRead(body, obj)
if err != nil { if err != nil {
return fmt.Errorf("json binding error: %w", err) return fmt.Errorf("json binding error: %w", err)
} }
@ -493,10 +652,16 @@ func (c *Context) ShouldBindJSON(obj any) error {
// ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象 // ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象
func (c *Context) ShouldBindWANF(obj any) error { func (c *Context) ShouldBindWANF(obj any) error {
if c.Request.Body == nil { var body io.ReadCloser
if c.MaxRequestBodySize > 0 {
body = c.prepareRequestBody()
} else {
body = c.Request.Body
}
if body == nil {
return errors.New("request body is empty") return errors.New("request body is empty")
} }
decoder, err := wanf.NewStreamDecoder(c.Request.Body) decoder, err := wanf.NewStreamDecoder(body)
if err != nil { if err != nil {
return fmt.Errorf("failed to create WANF decoder: %w", err) return fmt.Errorf("failed to create WANF decoder: %w", err)
} }
@ -509,10 +674,16 @@ func (c *Context) ShouldBindWANF(obj any) error {
// ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象 // ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象
func (c *Context) ShouldBindGOB(obj any) error { func (c *Context) ShouldBindGOB(obj any) error {
if c.Request.Body == nil { var body io.ReadCloser
if c.MaxRequestBodySize > 0 {
body = c.prepareRequestBody()
} else {
body = c.Request.Body
}
if body == nil {
return errors.New("request body is empty") return errors.New("request body is empty")
} }
decoder := gob.NewDecoder(c.Request.Body) decoder := gob.NewDecoder(body)
if err := decoder.Decode(obj); err != nil { if err := decoder.Decode(obj); err != nil {
return fmt.Errorf("GOB binding error: %w", err) return fmt.Errorf("GOB binding error: %w", err)
} }
@ -629,6 +800,10 @@ func setFieldValue(field reflect.Value, values []string) error {
// ShouldBindForm 尝试将表单数据绑定到结构体 // ShouldBindForm 尝试将表单数据绑定到结构体
// 支持 application/x-www-form-urlencoded 和 multipart/form-data // 支持 application/x-www-form-urlencoded 和 multipart/form-data
func (c *Context) ShouldBindForm(obj any) error { func (c *Context) ShouldBindForm(obj any) error {
if c.MaxRequestBodySize > 0 {
c.prepareRequestBody()
}
contentType := c.Request.Header.Get("Content-Type") contentType := c.Request.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType) mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil { if err != nil {
@ -637,7 +812,7 @@ func (c *Context) ShouldBindForm(obj any) error {
switch mediaType { switch mediaType {
case "multipart/form-data": case "multipart/form-data":
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
return fmt.Errorf("parse multipart form error: %w", err) return fmt.Errorf("parse multipart form error: %w", err)
} }
case "application/x-www-form-urlencoded": case "application/x-www-form-urlencoded":
@ -651,6 +826,7 @@ func (c *Context) ShouldBindForm(obj any) error {
if err := bindForm(c.Request.Form, obj); err != nil { if err := bindForm(c.Request.Form, obj); err != nil {
return fmt.Errorf("form binding error: %w", err) return fmt.Errorf("form binding error: %w", err)
} }
c.formCache = c.Request.PostForm
return nil return nil
} }
@ -688,11 +864,30 @@ func (c *Context) GetErrors() []error {
return c.Errors return c.Errors
} }
// Client 返回 Engine 提供的 HTTPClient // Client 返回当前请求的 HTTPClient
// 方便在请求处理函数中进行出站 HTTP 请求 // 如果请求处理函数或中间件设置了自定义 HTTPClient返回该实例
// 否则返回 Engine 提供的默认实例
//
// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context
func (c *Context) Client() *httpc.Client { func (c *Context) Client() *httpc.Client {
if c.HTTPClient != nil {
return c.HTTPClient return c.HTTPClient
} }
return c.engine.HTTPClient
}
// HTTPC 返回自动关联请求 Context 的 HTTP 客户端
// 当请求被取消时,通过此客户端发起的出站请求也会自动取消
func (c *Context) HTTPC() *contextHTTPClient {
client := c.HTTPClient
if client == nil {
client = c.engine.HTTPClient
}
return &contextHTTPClient{
client: client,
ctx: c.ctx,
}
}
// Context() 返回请求的上下文,用于取消操作 // Context() 返回请求的上下文,用于取消操作
// 这是 Go 标准库的 `context.Context`,用于请求的取消和超时管理 // 这是 Go 标准库的 `context.Context`,用于请求的取消和超时管理
@ -751,37 +946,30 @@ func (c *Context) WriteStream(reader io.Reader) (written int64, err error) {
// GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体 // GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体
// 注意:请求体只能读取一次 // 注意:请求体只能读取一次
func (c *Context) GetReqBody() io.ReadCloser { func (c *Context) GetReqBody() io.ReadCloser {
if c.MaxRequestBodySize > 0 {
return c.prepareRequestBody()
}
if c.Request == nil || c.Request.Body == nil {
return nil
}
return c.Request.Body return c.Request.Body
} }
// GetReqBodyFull 读取并返回请求体的所有内容 // GetReqBodyFull 读取并返回请求体的所有内容
// 注意:请求体只能读取一次 // 注意:请求体只能读取一次
func (c *Context) GetReqBodyFull() ([]byte, error) { func (c *Context) GetReqBodyFull() ([]byte, error) {
if c.Request.Body == nil { body := c.GetReqBody()
if body == nil {
return nil, nil return nil, nil
} }
var limitBytesReader io.ReadCloser
if c.MaxRequestBodySize > 0 {
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
defer func() { defer func() {
err := limitBytesReader.Close() err := body.Close()
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err)) c.AddError(fmt.Errorf("failed to close request body: %w", err))
} }
}() }()
} else {
limitBytesReader = c.Request.Body
defer func() {
err := limitBytesReader.Close()
if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err))
}
}()
}
data, err := iox.ReadAll(limitBytesReader) data, err := io.ReadAll(body)
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to read request body: %w", err)) c.AddError(fmt.Errorf("failed to read request body: %w", err))
return nil, fmt.Errorf("failed to read request body: %w", err) return nil, fmt.Errorf("failed to read request body: %w", err)
@ -791,31 +979,18 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
// 类似 GetReqBodyFull, 返回 *bytes.Buffer // 类似 GetReqBodyFull, 返回 *bytes.Buffer
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
if c.Request.Body == nil { body := c.GetReqBody()
if body == nil {
return nil, nil return nil, nil
} }
var limitBytesReader io.ReadCloser
if c.MaxRequestBodySize > 0 {
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
defer func() { defer func() {
err := limitBytesReader.Close() err := body.Close()
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err)) c.AddError(fmt.Errorf("failed to close request body: %w", err))
} }
}() }()
} else {
limitBytesReader = c.Request.Body
defer func() {
err := limitBytesReader.Close()
if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err))
}
}()
}
data, err := iox.ReadAll(limitBytesReader) data, err := io.ReadAll(body)
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to read request body: %w", err)) c.AddError(fmt.Errorf("failed to read request body: %w", err))
return nil, fmt.Errorf("failed to read request body: %w", err) return nil, fmt.Errorf("failed to read request body: %w", err)
@ -974,14 +1149,9 @@ func (c *Context) GetProtocol() string {
return c.Request.Proto return c.Request.Proto
} }
// GetHTTPC 获取框架自带传递的httpc // GetLogger 获取engine的Logger接口
func (c *Context) GetHTTPC() *httpc.Client { func (c *Context) GetLogger() Logger {
return c.HTTPClient return c.engine.logger
}
// GetLogger 获取engine的Logger
func (c *Context) GetLogger() *reco.Logger {
return c.engine.LogReco
} }
// GetReqQueryString // GetReqQueryString
@ -1084,17 +1254,25 @@ func (c *Context) SetSameSite(samesite http.SameSite) {
} }
// SetCookie 设置一个 HTTP cookie // SetCookie 设置一个 HTTP cookie
func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool) { // sameSite 参数是可选的,如果不提供则使用通过 SetSameSite 设置的值
func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool, sameSite ...http.SameSite) {
if path == "" { if path == "" {
path = "/" path = "/"
} }
site := c.sameSite
if len(sameSite) > 0 {
if len(sameSite) > 1 {
c.Warnf("SetCookie: only the first SameSite value will be used, got %d values", len(sameSite))
}
site = sameSite[0]
}
http.SetCookie(c.Writer, &http.Cookie{ http.SetCookie(c.Writer, &http.Cookie{
Name: name, Name: name,
Value: url.QueryEscape(value), Value: url.QueryEscape(value),
MaxAge: maxAge, MaxAge: maxAge,
Path: path, Path: path,
Domain: domain, Domain: domain,
SameSite: c.sameSite, SameSite: site,
Secure: secure, Secure: secure,
HttpOnly: httpOnly, HttpOnly: httpOnly,
}) })
@ -1132,25 +1310,25 @@ func (c *Context) DeleteCookie(name string) {
// === 日志记录 === // === 日志记录 ===
func (c *Context) Debugf(format string, args ...any) { func (c *Context) Debugf(format string, args ...any) {
c.engine.LogReco.Debugf(format, args...) c.engine.logger.Debugf(format, args...)
} }
func (c *Context) Infof(format string, args ...any) { func (c *Context) Infof(format string, args ...any) {
c.engine.LogReco.Infof(format, args...) c.engine.logger.Infof(format, args...)
} }
func (c *Context) Warnf(format string, args ...any) { func (c *Context) Warnf(format string, args ...any) {
c.engine.LogReco.Warnf(format, args...) c.engine.logger.Warnf(format, args...)
} }
func (c *Context) Errorf(format string, args ...any) { func (c *Context) Errorf(format string, args ...any) {
c.engine.LogReco.Errorf(format, args...) c.engine.logger.Errorf(format, args...)
} }
func (c *Context) Fatalf(format string, args ...any) { func (c *Context) Fatalf(format string, args ...any) {
c.engine.LogReco.Fatalf(format, args...) c.engine.logger.Fatalf(format, args...)
} }
func (c *Context) Panicf(format string, args ...any) { func (c *Context) Panicf(format string, args ...any) {
c.engine.LogReco.Panicf(format, args...) c.engine.logger.Panicf(format, args...)
} }

81
context_benchmark_test.go Normal file
View file

@ -0,0 +1,81 @@
package touka
import (
"net/http"
"testing"
)
func TestContextResetKeepsKeysNilUntilSet(t *testing.T) {
c, _ := CreateTestContext(nil)
if c.Keys != nil {
t.Fatalf("expected fresh test context Keys to be nil before first Set")
}
c.Set("answer", 42)
if c.Keys == nil {
t.Fatalf("expected Set to allocate Keys map")
}
if value, exists := c.Get("answer"); !exists || value != 42 {
t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists)
}
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatalf("failed to build request: %v", err)
}
c.reset(UnwrapResponseWriter(c.Writer), req)
if c.Keys != nil {
t.Fatalf("expected reset to clear Keys without allocating a new map")
}
if value, exists := c.Get("answer"); exists || value != nil {
t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists)
}
ctxValue := c.Value("missing")
if ctxValue != nil {
t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue)
}
defer func() {
if r := recover(); r == nil {
t.Fatalf("expected MustGet to panic for missing key after reset")
}
}()
_ = c.MustGet("answer")
}
func BenchmarkContextReset(b *testing.B) {
b.Run("NoKeysUse", func(b *testing.B) {
c, _ := CreateTestContext(nil)
rawWriter := UnwrapResponseWriter(c.Writer)
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
b.Fatalf("failed to build request: %v", err)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.reset(rawWriter, req)
}
})
b.Run("WithKeysUse", func(b *testing.B) {
c, _ := CreateTestContext(nil)
rawWriter := UnwrapResponseWriter(c.Writer)
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
b.Fatalf("failed to build request: %v", err)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.reset(rawWriter, req)
c.Set("request-id", i)
}
})
}

174
context_bodylimit_test.go Normal file
View file

@ -0,0 +1,174 @@
package touka
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
type zeroNilThenEOFReader struct {
readCalls int
}
func (r *zeroNilThenEOFReader) Read(_ []byte) (int, error) {
r.readCalls++
if r.readCalls == 1 {
return 0, nil
}
return 0, io.EOF
}
func (r *zeroNilThenEOFReader) Close() error {
return nil
}
func TestFileTextUsesProvidedStatusCode(t *testing.T) {
t.Helper()
dir := t.TempDir()
filePath := filepath.Join(dir, "hello.txt")
if err := os.WriteFile(filePath, []byte("hello touka"), 0o644); err != nil {
t.Fatalf("write temp file: %v", err)
}
rr := httptest.NewRecorder()
c, _ := CreateTestContext(rr)
c.FileText(http.StatusCreated, filePath)
if rr.Code != http.StatusCreated {
t.Fatalf("expected status %d, got %d", http.StatusCreated, rr.Code)
}
if got := rr.Header().Get("Content-Type"); got != "text/plain; charset=utf-8" {
t.Fatalf("unexpected content type: %q", got)
}
if body := rr.Body.String(); body != "hello touka" {
t.Fatalf("unexpected body: %q", body)
}
}
func TestMaxBytesReaderAllowsExactLimit(t *testing.T) {
t.Helper()
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcd")), 4)
defer reader.Close()
data, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("expected exact limit read to succeed, got %v", err)
}
if string(data) != "abcd" {
t.Fatalf("unexpected data: %q", string(data))
}
}
func TestMaxBytesReaderRejectsOverLimit(t *testing.T) {
t.Helper()
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcde")), 4)
defer reader.Close()
_, err := io.ReadAll(reader)
if !errors.Is(err, ErrBodyTooLarge) {
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
}
}
func TestMaxBytesReaderAllowsZeroNilThenEOFAtExactLimit(t *testing.T) {
t.Helper()
reader := NewMaxBytesReader(&zeroNilThenEOFReader{}, 1)
defer reader.Close()
buf := make([]byte, 1)
n, err := reader.Read(buf)
if n != 0 || err != nil {
t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err)
}
n, err = reader.Read(buf)
if n != 0 || !errors.Is(err, io.EOF) {
t.Fatalf("expected EOF after retry, got n=%d err=%v", n, err)
}
}
func TestMaxBytesReaderTreatsZeroLimitAsUnlimited(t *testing.T) {
t.Helper()
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abc")), 0)
defer reader.Close()
data, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("expected zero limit to leave body unlimited, got %v", err)
}
if string(data) != "abc" {
t.Fatalf("unexpected data: %q", string(data))
}
}
func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) {
t.Helper()
body := strings.NewReader(`{"name":"abcdef"}`)
req := httptest.NewRequest(http.MethodPost, "/json", body)
req.Header.Set("Content-Type", "application/json")
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
c.SetMaxRequestBodySize(8)
var payload struct {
Name string `json:"name"`
}
err := c.ShouldBindJSON(&payload)
if !errors.Is(err, ErrBodyTooLarge) {
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
}
}
func TestShouldBindFormHonorsMaxRequestBodySize(t *testing.T) {
t.Helper()
body := strings.NewReader("name=abcdef")
req := httptest.NewRequest(http.MethodPost, "/form", body)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
c.SetMaxRequestBodySize(4)
var payload struct {
Name string `form:"name"`
}
err := c.ShouldBindForm(&payload)
if !errors.Is(err, ErrBodyTooLarge) {
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
}
}
func TestPostFormHonorsMaxRequestBodySize(t *testing.T) {
t.Helper()
body := strings.NewReader("name=abcdef")
req := httptest.NewRequest(http.MethodPost, "/form", body)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
c.SetMaxRequestBodySize(4)
if got := c.PostForm("name"); got != "" {
t.Fatalf("expected empty value on over-limit form body, got %q", got)
}
if len(c.Errors) == 0 {
t.Fatal("expected parse error to be recorded")
}
if !errors.Is(c.Errors[0], ErrBodyTooLarge) {
t.Fatalf("expected recorded error to wrap ErrBodyTooLarge, got %v", c.Errors[0])
}
}

58
context_httpc.go Normal file
View file

@ -0,0 +1,58 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2024 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"context"
"github.com/WJQSERVER-STUDIO/httpc"
)
// contextHTTPClient 包装 httpc.Client自动关联请求的 Context
// 当请求被取消时,出站 HTTP 请求也会自动取消
type contextHTTPClient struct {
client *httpc.Client
ctx context.Context
}
// NewRequestBuilder 创建请求构建器,自动关联请求 Context
func (c *contextHTTPClient) NewRequestBuilder(method, urlStr string) *httpc.RequestBuilder {
return c.client.NewRequestBuilder(method, urlStr).WithContext(c.ctx)
}
// GET 创建 GET 请求构建器
func (c *contextHTTPClient) GET(urlStr string) *httpc.RequestBuilder {
return c.client.GET(urlStr).WithContext(c.ctx)
}
// POST 创建 POST 请求构建器
func (c *contextHTTPClient) POST(urlStr string) *httpc.RequestBuilder {
return c.client.POST(urlStr).WithContext(c.ctx)
}
// PUT 创建 PUT 请求构建器
func (c *contextHTTPClient) PUT(urlStr string) *httpc.RequestBuilder {
return c.client.PUT(urlStr).WithContext(c.ctx)
}
// DELETE 创建 DELETE 请求构建器
func (c *contextHTTPClient) DELETE(urlStr string) *httpc.RequestBuilder {
return c.client.DELETE(urlStr).WithContext(c.ctx)
}
// PATCH 创建 PATCH 请求构建器
func (c *contextHTTPClient) PATCH(urlStr string) *httpc.RequestBuilder {
return c.client.PATCH(urlStr).WithContext(c.ctx)
}
// HEAD 创建 HEAD 请求构建器
func (c *contextHTTPClient) HEAD(urlStr string) *httpc.RequestBuilder {
return c.client.HEAD(urlStr).WithContext(c.ctx)
}
// OPTIONS 创建 OPTIONS 请求构建器
func (c *contextHTTPClient) OPTIONS(urlStr string) *httpc.RequestBuilder {
return c.client.OPTIONS(urlStr).WithContext(c.ctx)
}

View file

@ -44,7 +44,9 @@ r.SetTLSServerConfigurator(func(server *http.Server) {
Touka 支持配置 HTTP/1.1、HTTP/2 和 H2CHTTP/2 Cleartext Touka 支持配置 HTTP/1.1、HTTP/2 和 H2CHTTP/2 Cleartext
```go ```go
// 使用默认协议配置(仅 HTTP/1.1 // 使用默认协议配置
// 普通 HTTP 启动时默认为 HTTP/1.1;若使用 WithTLS(...) 且未手动覆盖协议集,
// HTTPS 服务器会默认启用 HTTP/1.1 与 HTTP/2。
r.SetDefaultProtocols() r.SetDefaultProtocols()
// 自定义协议配置 // 自定义协议配置
@ -57,33 +59,147 @@ r.SetProtocols(&touka.ProtocolsConfig{
### 启动方式 ### 启动方式
Touka 提供了多种服务器启动方式 Touka 统一通过 `Run(opts...)` 启动服务器
```go ```go
// 1. 简单启动(无优雅停机) // 1. 简单启动(无优雅停机)
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
// 2. 带优雅停机的启动 // 2. 带优雅停机的启动
r.RunShutdown(":8080", 10*time.Second) r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second))
// 3. 带上下文的优雅停机 // 3. 带上下文的优雅停机
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
r.RunShutdownWithContext(":8080", ctx, 10*time.Second) defer cancel()
r.Run(
touka.WithAddr(":8080"),
touka.WithGracefulShutdown(10*time.Second),
touka.WithShutdownContext(ctx),
)
// 4. HTTPS 启动 // 4. HTTPS 启动
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
// 其他 TLS 配置... // 其他 TLS 配置...
} }
r.RunTLS(":443", tlsConfig, 10*time.Second) // WithTLS(...) 与优雅关闭相互独立;这里演示 HTTPS + 默认优雅关闭超时。
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithGracefulShutdownDefault(),
)
// 5. HTTPS + HTTP 重定向 // 5. HTTPS + HTTP 重定向
r.RunTLSRedir(":80", ":443", tlsConfig, 10*time.Second) // WithHTTPRedirect(...) 需要与 WithTLS(...) 配合使用。
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(":80"),
touka.WithGracefulShutdown(10*time.Second),
)
// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(true),
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
),
)
// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(false),
touka.WithRedirectHost("example.com"),
),
)
``` ```
### HTTPS Redirect Host 策略
`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。
可用的 redirect 子选项:
- `touka.WithUseHeaderHost(true|false)`
- `touka.WithRedirectHostHeaders([]string{...})`
- `touka.WithRedirectHost("example.com")`
#### 模式一:使用请求输入侧的 host
`WithUseHeaderHost(true)` 时:
- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host`
- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header并使用第一个非空值
- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required`
示例:
```go
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(true),
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
),
)
```
#### 模式二:使用配置的固定 host
`WithUseHeaderHost(false)` 时:
- 不读取 `Request.Host`
- 不读取 `WithRedirectHostHeaders(...)`
- 必须配置 `WithRedirectHost("example.com")`
示例:
```go
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(false),
touka.WithRedirectHost("example.com"),
),
)
```
#### 严格校验规则
以下组合会直接返回配置错误:
- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)`
- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)`
- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)`
- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)`
- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)`
#### 优先级关系
1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式
2. `WithUseHeaderHost(...)` 决定 host 来源模式
3. 当 `WithUseHeaderHost(true)` 时:
- 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询
- 未配置时使用 `Request.Host`
4. 当 `WithUseHeaderHost(false)` 时:
- 只使用 `WithRedirectHost(...)`
**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。
## 优雅停机 (Graceful Shutdown) ## 优雅停机 (Graceful Shutdown)
在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。 在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。
```go ```go
r := touka.Default() r := touka.Default()
@ -91,7 +207,7 @@ r := touka.Default()
// 监听 SIGINT 和 SIGTERM 信号 // 监听 SIGINT 和 SIGTERM 信号
// 如果在 10 秒内未处理完,则强制关闭 // 如果在 10 秒内未处理完,则强制关闭
if err := r.RunShutdown(":8080", 10*time.Second); err != nil { if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
log.Fatal("服务器退出异常:", err) log.Fatal("服务器退出异常:", err)
} }
``` ```

188
docs/httpc.md Normal file
View file

@ -0,0 +1,188 @@
# HTTP Client (httpc)
Touka 内置了 [httpc](https://github.com/WJQSERVER-STUDIO/httpc) HTTP 客户端,方便在请求处理函数中发起出站 HTTP 请求。
## 核心特性
- **自动 Context 关联**:使用 `HTTPC()` 方法时,出站请求会自动关联当前请求的 Context
- **请求取消传播**:当客户端断开连接时,出站请求会自动取消,避免资源泄漏
- **链式调用**:保持 httpc 原有的组合式构建器风格
## 基本用法
### 简单 GET 请求
```go
r.GET("/proxy", func(c *touka.Context) {
body, err := c.HTTPC().
GET("https://api.example.com/data").
Text()
if err != nil {
c.JSON(500, touka.H{"error": err.Error()})
return
}
c.String(200, body)
})
```
### POST JSON 请求
```go
r.POST("/users", func(c *touka.Context) {
var req struct {
Name string `json:"name"`
Email string `json:"email"`
}
c.ShouldBindJSON(&req)
var result struct {
ID int `json:"id"`
Name string `json:"name"`
}
err := c.HTTPC().
POST("https://api.example.com/users").
SetHeader("Authorization", "Bearer "+token).
SetJSONBody(req).
DecodeJSON(&result)
if err != nil {
c.JSON(500, touka.H{"error": err.Error()})
return
}
c.JSON(200, result)
})
```
### 带查询参数
```go
r.GET("/search", func(c *touka.Context) {
query := c.Query("q")
var result SearchResult
err := c.HTTPC().
GET("https://api.example.com/search").
SetQueryParam("q", query).
SetQueryParam("limit", "10").
DecodeJSON(&result)
if err != nil {
c.JSON(500, touka.H{"error": err.Error()})
return
}
c.JSON(200, result)
})
```
## API 对比
### 旧方式Deprecated
```go
// 需要手动 WithContext容易忘记
resp, err := c.Client().
WithContext(c.Context()).
GET(url).
Execute()
```
### 新方式(推荐)
```go
// 自动关联请求 Context
resp, err := c.HTTPC().
GET(url).
Execute()
```
## Context 取消机制
使用 `HTTPC()` 时,当客户端断开连接(如关闭浏览器),出站请求会自动取消:
```go
r.GET("/long-task", func(c *touka.Context) {
// 这个请求会在客户端断开时自动取消
resp, err := c.HTTPC().
GET("https://slow-api.example.com/data").
Execute()
// 如果客户端已断开err 会包含 context.Canceled
if errors.Is(err, context.Canceled) {
return // 客户端已断开,无需处理
}
// ...
})
```
## 完整 API
### contextHTTPClient 方法
| 方法 | 返回类型 | 说明 |
|------|----------|------|
| `NewRequestBuilder(method, url)` | `*httpc.RequestBuilder` | 创建通用请求构建器 |
| `GET(url)` | `*httpc.RequestBuilder` | 创建 GET 请求 |
| `POST(url)` | `*httpc.RequestBuilder` | 创建 POST 请求 |
| `PUT(url)` | `*httpc.RequestBuilder` | 创建 PUT 请求 |
| `DELETE(url)` | `*httpc.RequestBuilder` | 创建 DELETE 请求 |
| `PATCH(url)` | `*httpc.RequestBuilder` | 创建 PATCH 请求 |
| `HEAD(url)` | `*httpc.RequestBuilder` | 创建 HEAD 请求 |
| `OPTIONS(url)` | `*httpc.RequestBuilder` | 创建 OPTIONS 请求 |
### httpc.RequestBuilder 链式方法
返回 `*httpc.RequestBuilder`(用于链式调用):
| 方法 | 说明 |
|------|------|
| `WithContext(ctx)` | 设置 Context通常不需要已自动关联 |
| `NoDefaultHeaders()` | 不添加默认 Header |
| `SetHeader(key, value)` | 设置 Header |
| `AddHeader(key, value)` | 添加 Header可重复 |
| `SetHeaders(map)` | 批量设置 Headers |
| `SetQueryParam(key, value)` | 设置查询参数 |
| `AddQueryParam(key, value)` | 添加查询参数(可重复) |
| `SetQueryParams(map)` | 批量设置查询参数 |
| `SetBody(io.Reader)` | 设置请求 Body |
| `SetRawBody([]byte)` | 设置字节 Body |
返回 `(*httpc.RequestBuilder, error)`(可能失败):
| 方法 | 说明 |
|------|------|
| `SetJSONBody(any)` | 设置 JSON Body |
| `SetXMLBody(any)` | 设置 XML Body |
| `SetGOBBody(any)` | 设置 GOB Body |
### 终结方法
| 方法 | 返回类型 | 说明 |
|------|----------|------|
| `Build()` | `(*http.Request, error)` | 构建请求但不执行 |
| `Execute()` | `(*http.Response, error)` | 执行并返回原始响应 |
| `DecodeJSON(v)` | `error` | 执行并解码 JSON |
| `DecodeXML(v)` | `error` | 执行并解码 XML |
| `DecodeGOB(v)` | `error` | 执行并解码 GOB |
| `Text()` | `(string, error)` | 执行并返回文本 |
| `Bytes()` | `([]byte, error)` | 执行并返回字节 |
| `SSE()` | `(*SSEStream, error)` | 建立 SSE 流连接 |
## 迁移指南
### go:fix inline 兼容
旧代码 `c.GetHTTPC()` 可通过 `go fix` 自动迁移到 `c.Client()`
```bash
go fix ./...
```
### 手动迁移
| 旧代码 | 新代码 |
|--------|--------|
| `c.GetHTTPC()` | `c.Client()``c.HTTPC()` |
| `c.Client().WithContext(ctx).GET(url)` | `c.HTTPC().GET(url)` |
## 示例
完整示例请参考 [examples/httpc](../examples/httpc)。

View file

@ -22,6 +22,6 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其
1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。 1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。
2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。 2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。
3. **健壮性**: 内置优雅停机支持,确保在服务器更新或关闭时请求能得到正确处理 3. **健壮性**: 通过 `Run(...)` 的启动选项提供优雅停机支持,使服务在更新或关闭时能更稳妥地处理进行中的请求
Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。 Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。

View file

@ -0,0 +1,400 @@
# Touka Logger 接口迁移方案
## 基于 Go 1.26 `go:fix inline` 的自动化迁移设计
---
## 一、问题分析
当前架构问题:
```
Engine.LogReco → *reco.Logger (公开字段, 直接访问)
Context.GetLogger() → 返回 *reco.Logger (具体类型)
Context.Debugf/Infof... → 硬编码 c.engine.LogReco.Debugf(...)
```
这导致用户无法替换日志实现(如 zap/logrus
---
## 二、目标架构
```
Engine.logger → Logger 接口 (私有)
Engine.LogReco → *reco.Logger (公开, Deprecated - 保持向后兼容)
Engine.GetLogger() → 返回 Logger 接口
Engine.SetLogger(Logger)→ 设置日志实现
Context.GetLogger() → 返回 Logger 接口
Context.Debugf/Infof... → 调用 c.engine.logger.Debugf(...)
```
---
## 三、Logger 接口定义
```go
// logger.go
package touka
// Logger 是日志接口,支持任意日志库实现
type Logger interface {
Debugf(format string, args ...any)
Infof(format string, args ...any)
Warnf(format string, args ...any)
Errorf(format string, args ...any)
Fatalf(format string, args ...any)
Panicf(format string, args ...any)
}
// CloserLogger 可选扩展,支持关闭操作
type CloserLogger interface {
Logger
Close() error
}
```
---
## 四、Engine 结构变更
```go
// engine.go 变更
type Engine struct {
// ... 其他字段保持不变
// logger 是新的日志接口 (私有)
logger Logger
// logReco 是保留的 reco.Logger 引用 (私有)
// 用于向后兼容,当通过 SetLoggerReco 设置时同步到 logger
logReco *reco.Logger
// 其他字段...
}
```
新增/修改方法:
```go
// GetLogger 返回日志接口
func (engine *Engine) GetLogger() Logger {
return engine.logger
}
// SetLogger 设置任意 Logger 实现
func (engine *Engine) SetLogger(l Logger) {
engine.logger = l
// 如果是 *reco.Logger 类型,同步更新 logReco
if rl, ok := l.(*reco.Logger); ok {
engine.logReco = rl
} else {
engine.logReco = nil
}
}
// SetLoggerCfg 使用 reco.Config 配置日志
func (engine *Engine) SetLoggerCfg(logcfg reco.Config) {
logger := NewLogger(logcfg)
engine.logger = logger
engine.logReco = logger
}
```
---
## 五、`go:fix inline` 兼容性函数
### 5.1 旧 API 包装函数
`compat.go` 中定义:
```go
// compat.go
package touka
import "github.com/fenthope/reco"
// GetLogReco 返回 reco.Logger用于向后兼容
//
//go:fix inline
func (engine *Engine) GetLogReco() *reco.Logger {
return engine.logReco
}
// SetLogReco 设置 reco.Logger用于向后兼容
//
//go:fix inline
func (engine *Engine) SetLogReco(l *reco.Logger) {
engine.logReco = l
engine.logger = l
}
```
### 5.2 Context 日志方法的 inline 包装
```go
// context_compat.go
package touka
// Debugf 记录 Debug 级别日志
//
//go:fix inline
func (c *Context) Debugf(format string, args ...any) {
c.engine.logger.Debugf(format, args...)
}
// Infof 记录 Info 级别日志
//
//go:fix inline
func (c *Context) Infof(format string, args ...any) {
c.engine.logger.Infof(format, args...)
}
// Warnf 记录 Warn 级别日志
//
//go:fix inline
func (c *Context) Warnf(format string, args ...any) {
c.engine.logger.Warnf(format, args...)
}
// Errorf 记录 Error 级别日志
//
//go:fix inline
func (c *Context) Errorf(format string, args ...any) {
c.engine.logger.Errorf(format, args...)
}
// Fatalf 记录 Fatal 级别日志
//
//go:fix inline
func (c *Context) Fatalf(format string, args ...any) {
c.engine.logger.Fatalf(format, args...)
}
// Panicf 记录 Panic 级别日志
//
//go:fix inline
func (c *Context) Panicf(format string, args ...any) {
c.engine.logger.Panicf(format, args...)
}
```
### 5.3 GetLogger 返回类型的兼容处理
由于 `GetLogger()` 返回类型从 `*reco.Logger` 变为 `Logger`,需要提供兼容函数:
```go
// context_compat.go (续)
// GetLoggerReco 返回 *reco.Logger 类型,用于需要具体类型的场景
//
//go:fix inline
func (c *Context) GetLoggerReco() *reco.Logger {
if rl, ok := c.engine.logger.(*reco.Logger); ok {
return rl
}
return nil
}
```
---
## 六、go:fix inline 工作原理
### 迁移前用户代码:
```go
func handler(c *touka.Context) {
// 旧 API 调用
c.Debugf("request: %s", c.Request.URL.Path)
c.engine.LogReco.Infof("server started")
}
```
### go fix 执行后(自动替换):
```go
func handler(c *touka.Context) {
// Debugf 被替换为函数体
c.engine.logger.Debugf("request: %s", c.Request.URL.Path)
// LogReco 访问无法通过 inline 自动处理,需要手动迁移
// 或者通过 getter 调用
}
```
### 对于字段访问的处理策略:
`engine.LogReco` 字段访问无法直接用 `go:fix inline` 处理,采用以下策略:
1. **保留字段但标记 deprecated**:继续导出 `LogReco` 但文档标记为 deprecated
2. **提供 getter/setter**:通过 `go:fix inline` 提供 `GetLogReco/SetLogReco`
3. **渐进迁移**:用户可以在方便时手动迁移到 `GetLogger()/SetLogger()`
---
## 七、迁移前后对比
### 场景 1基本日志调用
**迁移前:**
```go
func myHandler(c *touka.Context) {
c.Debugf("processing request %s", c.Request.URL.Path)
c.Infof("user %s logged in", username)
c.Warnf("slow query: %v", duration)
c.Errorf("db error: %v", err)
}
```
**迁移后(自动替换):**
```go
func myHandler(c *touka.Context) {
c.engine.logger.Debugf("processing request %s", c.Request.URL.Path)
c.engine.logger.Infof("user %s logged in", username)
c.engine.logger.Warnf("slow query: %v", duration)
c.engine.logger.Errorf("db error: %v", err)
}
```
### 场景 2Engine 配置日志
**迁移前:**
```go
engine := touka.New()
engine.LogReco = myLogger // 直接赋值
logger := engine.LogReco // 直接读取
```
**迁移后(手动 + 自动混合):**
```go
engine := touka.New()
// 方式 1使用新 API推荐
engine.SetLogger(myLogger)
logger := engine.GetLogger()
// 方式 2通过 go:fix inline 自动替换为 getter
// engine.SetLogReco(myLogger) ← go fix 替换
// logger := engine.GetLogReco() ← go fix 替换
```
### 场景 3使用第三方日志库新功能
```go
import "go.uber.org/zap"
func main() {
zapLogger, _ := zap.NewProduction()
defer zapLogger.Sync()
engine := touka.New()
// 使用 zap 替代默认的 reco.Logger
engine.SetLogger(&ZapAdapter{logger: zapLogger})
engine.GET("/api", func(c *touka.Context) {
c.Infof("api called") // 自动使用 zap 输出
})
}
// ZapAdapter 适配 zap 到 touka.Logger 接口
type ZapAdapter struct {
logger *zap.Logger
}
func (z *ZapAdapter) Debugf(format string, args ...any) {
z.logger.Debug(fmt.Sprintf(format, args...))
}
func (z *ZapAdapter) Infof(format string, args ...any) {
z.logger.Info(fmt.Sprintf(format, args...))
}
func (z *ZapAdapter) Warnf(format string, args ...any) {
z.logger.Warn(fmt.Sprintf(format, args...))
}
func (z *ZapAdapter) Errorf(format string, args ...any) {
z.logger.Error(fmt.Sprintf(format, args...))
}
func (z *ZapAdapter) Fatalf(format string, args ...any) {
z.logger.Fatal(fmt.Sprintf(format, args...))
}
func (z *ZapAdapter) Panicf(format string, args ...any) {
z.logger.Panic(fmt.Sprintf(format, args...))
}
```
---
## 八、内部使用迁移
框架内部代码也需要迁移,将直接调用 `engine.LogReco` 改为 `engine.logger`
需要修改的文件:
- `context.go`: writeResponseBody 中的 `c.engine.LogReco.Errorf`
- `recovery.go`: 如有使用日志
- `logreco.go`: CloseLogger 方法
```go
// context.go 修改前
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
if _, err := c.Writer.Write(data); err != nil {
if c.engine.LogReco != nil {
c.engine.LogReco.Errorf("%s: %v", contextMsg, err)
}
}
}
// context.go 修改后
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
if _, err := c.Writer.Write(data); err != nil {
if c.engine.logger != nil {
c.engine.logger.Errorf("%s: %v", contextMsg, err)
}
}
}
```
---
## 九、完整文件结构
```
touka/
├── logger.go # Logger 接口定义
├── logreco.go # reco.Logger 相关工具函数
├── compat.go # go:fix inline 兼容性函数 (Engine)
├── context_compat.go # go:fix inline 兼容性函数 (Context)
├── engine.go # Engine 结构变更
├── context.go # Context 日志方法变更
└── ...
```
---
## 十、版本策略
| 版本 | 变更内容 |
|------|---------|
| v1.x | 引入 Logger 接口LogReco 标记 deprecated |
| v2.x | 移除 LogReco 公开字段,仅通过 getter/setter 访问 |
| v3.x | 移除 go:fix inline 兼容函数 |
---
## 十一、go:fix inline 限制说明
1. **字段访问无法自动迁移**`engine.LogReco` 字段访问需要用户手动修改
2. **返回类型变更需谨慎**`GetLogger()` 返回类型变更会导致依赖具体类型的代码失败
3. **inline 函数有大小限制**:函数体过大会影响内联效果
4. **跨包迁移**`go:fix inline` 支持跨包,但用户必须运行 `go fix`
---
## 十二、推荐迁移步骤
1. **框架侧**:添加 Logger 接口,添加 go:fix inline 函数
2. **用户侧**:运行 `go fix ./...` 自动迁移可处理的部分
3. **用户侧**:手动将 `engine.LogReco` 字段访问改为 `engine.SetLogger()/GetLogger()`
4. **用户侧**:如需使用第三方日志,实现 Logger 接口并通过 SetLogger 设置

View file

@ -26,6 +26,41 @@ api.Use(AuthMiddleware())
} }
``` ```
也可以在创建组时直接传入中间件:
```go
api := r.Group("/api", AuthMiddleware(), RateLimitMiddleware())
{
api.GET("/user", handleUser)
api.POST("/data", handleData)
}
```
### 路由级中间件
为单个路由注册中间件,仅对该路由生效。
```go
// 单个路由中间件
r.GET("/protected", AuthMiddleware(), func(c *touka.Context) {
c.String(http.StatusOK, "Protected content")
})
// 多个路由中间件(按顺序执行)
r.POST("/upload",
RateLimitMiddleware(),
AuthMiddleware(),
PermissionCheckMiddleware(),
func(c *touka.Context) {
// 处理上传
},
)
// 路由组中的单个路由也可以使用路由级中间件
api := r.Group("/api")
api.GET("/admin", AdminAuthMiddleware(), adminHandler)
```
## 编写自定义中间件 ## 编写自定义中间件
中间件的函数签名是 `touka.HandlerFunc` 中间件的函数签名是 `touka.HandlerFunc`
@ -67,6 +102,36 @@ func APIKeyAuth() touka.HandlerFunc {
} }
``` ```
## 中间件执行顺序
理解中间件的执行顺序对于构建正确的处理流程至关重要。**注意:注册顺序决定了执行逻辑**,中间件必须在注册路由之前调用(全局中间件应在创建组或定义路由前注册)。中间件按照以下顺序执行:
```go
// 全局中间件
r.Use(GlobalMiddleware1())
r.Use(GlobalMiddleware2())
// 组中间件
api := r.Group("/api", GroupMiddleware1())
api.Use(GroupMiddleware2())
// 路由级中间件
api.GET("/users", RouteMiddleware1(), RouteMiddleware2(), userHandler)
```
对于 `/api/users` 请求,执行顺序为:
1. `GlobalMiddleware1()` - 全局中间件
2. `GlobalMiddleware2()` - 全局中间件
3. `GroupMiddleware1()` - 路由组中间件
4. `GroupMiddleware2()` - 路由组中间件
5. `RouteMiddleware1()` - 路由级中间件
6. `RouteMiddleware2()` - 路由级中间件
7. `userHandler` - 最终处理函数
```
请求进入 → 全局中间件 → 路由组中间件 → 路由级中间件 → 最终处理函数 → 路由级中间件后置逻辑 → 路由组中间件后置逻辑 → 全局中间件后置逻辑 → 响应
```
## 内置中间件 ## 内置中间件
- **Recovery**: 捕获任何发生的 panic恢复运行并返回 500 错误。它还负责调用全局错误处理器。 - **Recovery**: 捕获任何发生的 panic恢复运行并返回 500 错误。它还负责调用全局错误处理器。

View file

@ -46,7 +46,7 @@ func main() {
// 4. 启动服务器并监听 8080 端口 // 4. 启动服务器并监听 8080 端口
log.Println("Touka server is running on :8080") log.Println("Touka server is running on :8080")
if err := r.Run(":8080"); err != nil { if err := r.Run(touka.WithAddr(":8080")); err != nil {
log.Fatalf("Server failed: %v", err) log.Fatalf("Server failed: %v", err)
} }
} }
@ -66,11 +66,11 @@ go run main.go
## 优雅停机 ## 优雅停机
在生产环境中,我们推荐使用 `RunShutdown` 方法来启动服务器,它会监听系统信号并在关闭前等待正在处理的请求完成 在生产环境中,我们推荐`Run` 追加优雅关闭选项。启用后Touka 会监听 `SIGINT`/`SIGTERM`,在关闭时取消活动请求的上下文,并在超时前等待正在处理的请求完成。如需由应用内部事件触发关闭,还可以额外配合 `touka.WithShutdownContext(ctx)`
```go ```go
// 等待 10 秒以处理剩余请求 // 等待 10 秒以处理剩余请求
if err := r.RunShutdown(":8080", 10*time.Second); err != nil { if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
log.Fatalf("Server forced to shutdown: %v", err) log.Fatalf("Server forced to shutdown: %v", err)
} }
``` ```

View file

@ -28,7 +28,7 @@ func main() {
Target: target, Target: target,
})) }))
_ = r.Run(":8080") _ = r.Run(touka.WithAddr(":8080"))
} }
``` ```
@ -60,10 +60,15 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
```go ```go
type ReverseProxyConfig struct { type ReverseProxyConfig struct {
Target *url.URL Target *url.URL
Targets []string
LoadBalancing ReverseProxyLoadBalancingConfig
PassiveHealth ReverseProxyPassiveHealthConfig
Transport http.RoundTripper Transport http.RoundTripper
FlushInterval time.Duration FlushInterval time.Duration
BufferPool BufferPool BufferPool BufferPool
AllowH2CUpstream bool
ModifyRequest func(*http.Request) ModifyRequest func(*http.Request)
ModifyResponse func(*http.Response) error ModifyResponse func(*http.Response) error
@ -78,12 +83,133 @@ type ReverseProxyConfig struct {
### `Target` ### `Target`
必填。表示后端目标地址,至少需要提供 `scheme``host` `Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme``host`
```go ```go
target, _ := url.Parse("http://backend:9000") target, _ := url.Parse("http://backend:9000")
``` ```
### `Targets`
可选。用于配置多个后端目标地址。
- `Target``Targets` 互斥,只能使用其中一种
- `Targets` 的每一项都必须是完整 URL
- 每个 target 仍然可以自带 base path 和 query
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Targets: []string{
"http://127.0.0.1:9001/base?from=a",
"http://127.0.0.1:9002/base?from=b",
},
}))
```
这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。
### `LoadBalancing`
用于配置 upstream 选择策略和重试行为。
```go
type ReverseProxyLoadBalancingConfig struct {
Policy ReverseProxyLBPolicy
Retries int
TryDuration time.Duration
TryInterval time.Duration
}
```
当前内置策略:
- `touka.LBRandom()`
- `touka.LBRoundRobin()`
- `touka.LBFirst()`
- `touka.LBLeastConn()`
- `touka.LBIPHash()`
- `touka.LBClientIPHash()`
- `touka.LBURIHash()`
- `touka.LBHeader("X-Upstream", fallback)`
- `touka.LBQuery("tenant", fallback)`
其中:
- `LBFirst()` 适合主备/故障转移顺序
- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback
- 如果 `LBHeader` / `LBQuery` 没有显式 fallback则默认回退到 `LBRandom()`
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Targets: []string{
"http://127.0.0.1:9001",
"http://127.0.0.1:9002",
},
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
Policy: touka.LBHeader("X-Upstream", touka.LBFirst()),
Retries: 1,
},
}))
```
重试说明:
- 只对未开始收到上游响应的失败进行重试
- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试
- `Retries` 表示额外重试次数
- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机
- `TryInterval` 表示两次重试之间的等待间隔
### `PassiveHealth`
用于配置被动健康检查。它不会后台探测 upstream而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。
```go
type ReverseProxyPassiveHealthConfig struct {
FailDuration time.Duration
MaxFails int
UnhealthyStatus []int
}
```
- `FailDuration > 0` 时启用被动健康跟踪
- `MaxFails <= 0` 时默认按 `1` 处理
- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Targets: []string{
"http://127.0.0.1:9001",
"http://127.0.0.1:9002",
},
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
Policy: touka.LBFirst(),
},
PassiveHealth: touka.ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
UnhealthyStatus: []int{http.StatusServiceUnavailable},
},
}))
```
### `AllowH2CUpstream`
允许代理使用未加密 HTTP/2h2c`http://` upstream 通信。
- 默认关闭
- 这是一个显式配置项
- 启用后Touka 会为该 upstream 使用 h2c prior-knowledge 方式连接上游
- 这意味着上游本身也必须显式支持 h2c它不是“先试 h2c失败再自动回退到 h1”的协商模式
```go
r.GET("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Target: target,
AllowH2CUpstream: true,
}))
```
对于下游 HTTP/2 extended `CONNECT` websocket 场景Touka 会只在该特殊桥接路径上强制与上游使用 HTTP/1.1 websocket upgrade以匹配 Caddy 风格的桥接语义;普通 HTTP 请求不会因为这个特性而被强制降级为 HTTP/1.1。
### `Transport` ### `Transport`
可选。用于自定义底层转发所使用的 `http.RoundTripper` 可选。用于自定义底层转发所使用的 `http.RoundTripper`
@ -150,6 +276,8 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
在请求真正发往后端前,对出站请求做最后修改。 在请求真正发往后端前,对出站请求做最后修改。
如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。
常见用途: 常见用途:
- 覆盖 `Host` - 覆盖 `Host`
@ -242,11 +370,20 @@ const (
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Target: target, Target: target,
ForwardedHeaders: touka.ForwardedBoth, ForwardedHeaders: touka.ForwardedBoth,
ForwardedBy: "gateway-1", ForwardedBy: "_gateway-1",
Via: "edge-1", Via: "edge-1",
})) }))
``` ```
如果您配置了 `ForwardedBy`,它必须是一个符合 RFC 7239 的 node identifier。
- IPv4`203.0.113.43`
- IPv6 / 带端口:`[2001:db8::17]:443`
- 匿名标识:`_gateway-1`
- 未知:`unknown`
`gateway-1` 这类普通 token 不再被视为合法的 `by=` 值。
`Via` 不是“留空即禁用”的开关。当前实现中: `Via` 不是“留空即禁用”的开关。当前实现中:
- 如果 `Via` 非空,则使用该值追加 `Via` - 如果 `Via` 非空,则使用该值追加 `Via`
@ -282,11 +419,14 @@ Touka 会尽量遵循代理链语义:
Touka 的反向代理实现支持以下能力: Touka 的反向代理实现支持以下能力:
- `CONNECT` 隧道转发HTTP/1.x
- HTTP/2 extended `CONNECT`
- `Connection: Upgrade` / `Upgrade` 协议升级转发 - `Connection: Upgrade` / `Upgrade` 协议升级转发
- WebSocket 等 101 Switching Protocols 场景 - WebSocket 等 101 Switching Protocols 场景
- SSEServer-Sent Events立即刷新 - SSEServer-Sent Events立即刷新
- Trailer 透传 - Trailer 透传
- 1xx 响应透传 - 1xx 响应透传
- `TRACE` / `OPTIONS` 上的 `Max-Forwards` 递减与本地终止处理
例如,代理 WebSocket 服务: 例如,代理 WebSocket 服务:
@ -341,7 +481,7 @@ func main() {
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Target: target, Target: target,
ForwardedHeaders: touka.ForwardedBoth, ForwardedHeaders: touka.ForwardedBoth,
ForwardedBy: "gateway-1", ForwardedBy: "_gateway-1",
Via: "gateway-1", Via: "gateway-1",
FlushInterval: 100 * time.Millisecond, FlushInterval: 100 * time.Millisecond,
ModifyRequest: func(req *http.Request) { ModifyRequest: func(req *http.Request) {
@ -357,7 +497,7 @@ func main() {
}, },
})) }))
if err := r.RunShutdown(":8080", 10*time.Second); err != nil { if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }

View file

@ -22,6 +22,8 @@ r.ANY("/any", handle)
r.HandleFunc([]string{"GET", "POST"}, "/multi", handle) r.HandleFunc([]string{"GET", "POST"}, "/multi", handle)
``` ```
服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。
## 路径参数 (Named Parameters) ## 路径参数 (Named Parameters)
使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。 使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。
@ -140,7 +142,7 @@ func main() {
r := touka.Default() r := touka.Default()
fsroot, _ := fs.Sub(content, "dist") fsroot, _ := fs.Sub(content, "dist")
r.StaticFS("/", http.FS(fsroot)) r.StaticFS("/", http.FS(fsroot))
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
``` ```

View file

@ -40,43 +40,40 @@ r.GET("/events", func(c *touka.Context) {
## 模式二:通道模式 (EventStreamChan) ## 模式二:通道模式 (EventStreamChan)
如果您需要更高级的并发控制(例如从多个异步源接收数据),可以使用通道模式。 如果您需要更高级的并发控制(例如从多个异步源接收数据),可以使用通道模式。与回调模式类似,此方法是**阻塞的**handler 会在此方法中停留,直到事件 channel 被关闭或客户端断开连接。
```go ```go
r.GET("/events-chan", func(c *touka.Context) { r.GET("/events-chan", func(c *touka.Context) {
eventChan, errChan := c.EventStreamChan() eventChan := make(chan touka.Event)
ctx := c.Request.Context()
// 监听错误/断开连接 // 在独立的 goroutine 中发送事件.
go func() { go func() {
if err := <-errChan; err != nil { defer close(eventChan) // 务必在结束时关闭以结束事件流.
log.Printf("SSE 错误: %v", err)
}
}()
// 发送数据
go func() {
defer close(eventChan) // 务必在结束时关闭
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
select { select {
case <-c.Request.Context().Done(): case <-ctx.Done():
return return // 客户端已断开, 退出 goroutine.
default: case eventChan <- touka.Event{
eventChan <- touka.Event{
Data: fmt.Sprintf("消息 #%d", i), Data: fmt.Sprintf("消息 #%d", i),
}:
} }
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
}
}() }()
// EventStreamChan 会阻塞直到流结束.
c.EventStreamChan(eventChan)
}) })
``` ```
## 最佳实践 ## 最佳实践
1. **资源回收**: 确保在 `EventStreamChan` 模式下正确监听 `c.Request.Context().Done()` 以避免 Goroutine 泄漏。 1. **资源回收**: `EventStreamChan` 是阻塞的handler 在事件流结束前不会返回。将 `c.Request.Context().Done()``eventChan <- ...` 作为同一个 `select` 的两个分支,确保发送操作本身能够响应客户端断开。
2. **数据格式**: SSE 协议要求数据为 UTF-8。Touka 的 `Render` 方法会自动处理多行数据并加上必要的 `data:` 前缀。 2. **关闭 Channel**: 生产者完成发送后必须 `close(eventChan)`,否则 handler 会永远阻塞。
3. **超时管理**: SSE 连接通常是长连接,请确保您的反向代理(如 Nginx配置了足够大的写超时时间。 3. **数据格式**: SSE 协议要求数据为 UTF-8。Touka 的 `Render` 方法会自动处理多行数据并加上必要的 `data:` 前缀。
4. **超时管理**: SSE 连接通常是长连接,请确保您的反向代理(如 Nginx配置了足够大的写超时时间。
## 优雅关闭与资源清理 ## 优雅关闭与资源清理
@ -128,4 +125,4 @@ r.GET("/events-graceful", func(c *touka.Context) {
2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。 2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。
3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。 3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。
**注意:** 请务必使用 `RunShutdown``RunTLS``RunTLSRedir` 来启动服务器,以便框架能自动管理这些信号 **注意:** 请务必通过 `r.Run(...)` 并显式传入优雅关闭选项来启动服务器,例如 `touka.WithGracefulShutdown(...)``touka.WithGracefulShutdownDefault()`。只有启用了优雅关闭,框架才会在服务退出时取消这些请求上下文

View file

@ -39,7 +39,7 @@ func main() {
// 您也可以使用 StaticFS 服务根路径 // 您也可以使用 StaticFS 服务根路径
// r.StaticFS("/", http.FS(fsroot)) // r.StaticFS("/", http.FS(fsroot))
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
``` ```

2
ecw.go
View file

@ -197,7 +197,7 @@ func (ecw *errorCapturingResponseWriter) Written() bool {
func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := ecw.w.(http.Hijacker) hijacker, ok := ecw.w.(http.Hijacker)
if !ok { if !ok {
return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface") return nil, nil, http.ErrNotSupported
} }
return hijacker.Hijack() return hijacker.Hijack()
} }

59
ecw_benchmark_test.go Normal file
View file

@ -0,0 +1,59 @@
package touka
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestErrorCapturingResponseWriterResetClearsHeaderSnapshot(t *testing.T) {
c, _ := CreateTestContext(nil)
ecw := AcquireErrorCapturingResponseWriter(c)
defer ReleaseErrorCapturingResponseWriter(ecw)
ecw.capturedErrorSignal = true
ecw.Header().Set("Content-Type", "text/plain")
ecw.Header().Add("X-Test", "one")
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatalf("failed to build request: %v", err)
}
ecw.reset(httptest.NewRecorder(), req, c, c.engine.errorHandle.handler)
if len(ecw.headerSnapshot) != 0 {
t.Fatalf("expected header snapshot to be empty after reset, got %#v", ecw.headerSnapshot)
}
}
func BenchmarkErrorCapturingResponseWriterReset(b *testing.B) {
c, _ := CreateTestContext(nil)
ecw := AcquireErrorCapturingResponseWriter(c)
defer ReleaseErrorCapturingResponseWriter(ecw)
rawWriter := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
b.Fatalf("failed to build request: %v", err)
}
keys := make([]string, 16)
for i := range keys {
keys[i] = http.CanonicalHeaderKey("X-Test-" + string(rune('A'+i)))
}
values := []string{"one", "two", "three"}
for _, key := range keys {
ecw.headerSnapshot[key] = values
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ecw.reset(rawWriter, req, c, c.engine.errorHandle.handler)
for _, key := range keys {
ecw.headerSnapshot[key] = values
}
}
}

384
engine.go
View file

@ -7,9 +7,11 @@ package touka
import ( import (
"context" "context"
"errors" "errors"
"io"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"unicode/utf8"
"net/http" "net/http"
@ -17,6 +19,7 @@ import (
"github.com/WJQSERVER-STUDIO/httpc" "github.com/WJQSERVER-STUDIO/httpc"
"github.com/fenthope/reco" "github.com/fenthope/reco"
"github.com/go-json-experiment/json"
) )
// Last 返回链中的最后一个处理函数 // Last 返回链中的最后一个处理函数
@ -49,8 +52,14 @@ type Engine struct {
HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求 HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求
// LogReco 保留的 reco.Logger 字段
// Deprecated: 使用 SetLogger/GetLogger 替代
LogReco *reco.Logger LogReco *reco.Logger
// logger 是新的日志接口,支持任意 Logger 实现
// 优先级: logger > LogReco
logger Logger
HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口
routesInfo []RouteInfo // 存储所有注册的路由信息 routesInfo []RouteInfo // 存储所有注册的路由信息
@ -81,6 +90,11 @@ type Engine struct {
// GlobalMaxRequestBodySize 全局请求体Body大小限制 // GlobalMaxRequestBodySize 全局请求体Body大小限制
GlobalMaxRequestBodySize int64 GlobalMaxRequestBodySize int64
notFoundChain HandlersChain
notFoundNoMethodChain HandlersChain
unmatchedFSChain HandlersChain
unmatchedFSNoMethodChain HandlersChain
} }
// HandleFunc 注册一个或多个 HTTP 方法的路由 // HandleFunc 注册一个或多个 HTTP 方法的路由
@ -116,6 +130,90 @@ type ErrorHandle struct {
type ErrorHandler func(c *Context, code int, err error) type ErrorHandler func(c *Context, code int, err error)
var errMethodNotAllowed = errors.New("method not allowed")
var errNotFound = errors.New("not found")
type defaultErrorResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Error string `json:"error"`
}
var defaultNotFoundBody = mustMarshalDefaultErrorBody(http.StatusNotFound, errNotFound.Error())
var defaultMethodNotAllowedBody = mustMarshalDefaultErrorBody(http.StatusMethodNotAllowed, errMethodNotAllowed.Error())
func mustMarshalDefaultErrorBody(code int, errMsg string) []byte {
body, err := json.Marshal(defaultErrorResponse{
Code: code,
Message: http.StatusText(code),
Error: errMsg,
})
if err != nil {
panic(err)
}
return body
}
func writeDefaultErrorJSON(c *Context, code int, body []byte) {
if c == nil || c.Writer == nil {
return
}
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
c.Writer.WriteHeader(code)
c.writeResponseBody(body, "failed to write default error response")
c.Writer.Flush()
c.Abort()
}
var methodNotAllowedHandler HandlerFunc = func(c *Context) {
httpMethod := c.Request.Method
requestPath := routeLookupPath(c.Request)
engine := c.engine
// 是否是OPTIONS方式
if httpMethod == http.MethodOptions {
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0])
c.allowedMethodsBuf = allowedMethods[:0]
if len(allowedMethods) > 0 {
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
allowHeader := c.allowHeaderBuf[:0]
for i, method := range allowedMethods {
if i > 0 {
allowHeader = append(allowHeader, ',', ' ')
}
allowHeader = append(allowHeader, method...)
}
c.allowHeaderBuf = allowHeader[:0]
c.Writer.Header().Set("Allow", string(allowHeader))
c.Status(http.StatusOK)
return
}
return
}
// 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径
tempSkippedNodes := GetTempSkippedNodes()
for _, treeIter := range engine.methodTrees {
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
continue
}
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
*tempSkippedNodes = (*tempSkippedNodes)[:0]
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
if value.handlers != nil {
PutTempSkippedNodes(tempSkippedNodes)
// 使用定义的ErrorHandle处理
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed)
return
}
}
PutTempSkippedNodes(tempSkippedNodes)
}
var notFoundHandler HandlerFunc = func(c *Context) {
engine := c.engine
engine.errorHandle.handler(c, http.StatusNotFound, errNotFound)
}
// defaultErrorHandle 默认错误处理 // defaultErrorHandle 默认错误处理
func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接 func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接
select { select {
@ -126,16 +224,22 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是
if c.Writer.Written() { if c.Writer.Written() {
return return
} }
if len(c.Errors) == 0 {
switch {
case code == http.StatusNotFound && errors.Is(err, errNotFound):
writeDefaultErrorJSON(c, code, defaultNotFoundBody)
return
case code == http.StatusMethodNotAllowed && errors.Is(err, errMethodNotAllowed):
writeDefaultErrorJSON(c, code, defaultMethodNotAllowedBody)
return
}
}
// 输出json 状态码与状态码对应描述 // 输出json 状态码与状态码对应描述
var errMsg string var errMsg string
if err != nil { if err != nil {
errMsg = err.Error() errMsg = err.Error()
} }
c.JSON(code, H{ c.JSON(code, defaultErrorResponse{Code: code, Message: http.StatusText(code), Error: errMsg})
"code": code,
"message": http.StatusText(code),
"error": errMsg,
})
c.Writer.Flush() c.Writer.Flush()
c.Abort() c.Abort()
return return
@ -210,6 +314,7 @@ func New() *Engine {
TLSServerConfigurator: nil, TLSServerConfigurator: nil,
GlobalMaxRequestBodySize: -1, GlobalMaxRequestBodySize: -1,
} }
engine.rebuildFallbackChains()
engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background()) engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background())
//engine.SetProtocols(GetDefaultProtocolsConfig()) //engine.SetProtocols(GetDefaultProtocolsConfig())
engine.SetDefaultProtocols() engine.SetDefaultProtocols()
@ -265,16 +370,30 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) {
// 是否开启MethodNotAllowed // 是否开启MethodNotAllowed
func (engine *Engine) SetHandleMethodNotAllowed(enable bool) { func (engine *Engine) SetHandleMethodNotAllowed(enable bool) {
engine.HandleMethodNotAllowed = enable engine.HandleMethodNotAllowed = enable
engine.rebuildFallbackChains()
} }
// SetLogger传入实例 // SetLogger 传入 Logger 接口实例
func (engine *Engine) SetLogger(logger *reco.Logger) { func (engine *Engine) SetLogger(logger Logger) {
engine.LogReco = logger engine.logger = logger
// 同步更新 LogReco 以保持向后兼容
if rl, ok := logger.(*reco.Logger); ok {
engine.LogReco = rl
} else {
engine.LogReco = nil
}
} }
// 配置日志LoggerCfg // GetLogger 返回 Logger 接口实例
func (engine *Engine) GetLogger() Logger {
return engine.logger
}
// SetLoggerCfg 使用 reco.Config 配置日志
func (engine *Engine) SetLoggerCfg(logcfg reco.Config) { func (engine *Engine) SetLoggerCfg(logcfg reco.Config) {
engine.LogReco = NewLogger(logcfg) logger := NewLogger(logcfg)
engine.logger = logger
engine.LogReco = logger
} }
// 设置自定义错误处理 // 设置自定义错误处理
@ -305,6 +424,7 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF
engine.unMatchFS.ServeUnmatchedAsFS = false engine.unMatchFS.ServeUnmatchedAsFS = false
engine.UnMatchFSRoutes = nil engine.UnMatchFSRoutes = nil
} }
engine.rebuildFallbackChains()
} }
// 获取默认Protocol配置 // 获取默认Protocol配置
@ -340,11 +460,28 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) {
}() }()
} }
func cloneServerProtocols(protocols *http.Protocols) *http.Protocols {
if protocols == nil {
return nil
}
cloned := *protocols
return &cloned
}
func applyServerProtocols(srv *http.Server, protocols *http.Protocols) {
if protocols != nil {
srv.Protocols = cloneServerProtocols(protocols)
if srv.Protocols.HTTP2() || srv.Protocols.UnencryptedHTTP2() {
if err := configureHTTP2ExtendedConnectServer(srv); err != nil {
panic(err)
}
}
}
}
// applyDefaultServerConfig 应用框架的默认配置到 http.Server // applyDefaultServerConfig 应用框架的默认配置到 http.Server
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
if engine.serverProtocols != nil { applyServerProtocols(srv, engine.serverProtocols)
srv.Protocols = engine.serverProtocols
}
} }
// 配置全局Req Body大小限制 // 配置全局Req Body大小限制
@ -473,66 +610,64 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) {
// 405中间件 // 405中间件
func MethodNotAllowed() HandlerFunc { func MethodNotAllowed() HandlerFunc {
return func(c *Context) { return methodNotAllowedHandler
httpMethod := c.Request.Method
requestPath := c.Request.URL.Path
engine := c.engine
// 是否是OPTIONS方式
if httpMethod == http.MethodOptions {
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
allowedMethods := []string{}
for _, treeIter := range engine.methodTrees {
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
tempSkippedNodes := GetTempSkippedNodes()
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
PutTempSkippedNodes(tempSkippedNodes)
if value.handlers != nil {
allowedMethods = append(allowedMethods, treeIter.method)
}
}
if len(allowedMethods) > 0 {
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", "))
c.Status(http.StatusOK)
return
}
}
// 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径
for _, treeIter := range engine.methodTrees {
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
continue
}
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
tempSkippedNodes := GetTempSkippedNodes()
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
PutTempSkippedNodes(tempSkippedNodes)
if value.handlers != nil {
// 使用定义的ErrorHandle处理
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
return
}
}
}
} }
// 404最后处理 // 404最后处理
func NotFound() HandlerFunc { func NotFound() HandlerFunc {
return func(c *Context) { return notFoundHandler
engine := c.engine
engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found"))
}
} }
// 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理) // 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理)
func (Engine *Engine) NoRoute(handler HandlerFunc) { func (Engine *Engine) NoRoute(handler HandlerFunc) {
Engine.noRoute = handler Engine.noRoute = handler
Engine.noRoutes = nil Engine.noRoutes = nil
Engine.rebuildFallbackChains()
} }
// 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理) // 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理)
func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) { func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) {
Engine.noRoute = nil Engine.noRoute = nil
Engine.noRoutes = handlerFuncs Engine.noRoutes = handlerFuncs
Engine.rebuildFallbackChains()
}
func (engine *Engine) rebuildFallbackChains() {
buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain {
finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound
if includeMethodNotAllowed {
finalSize++
}
if includeUnmatchedFS {
finalSize += len(engine.UnMatchFSRoutes)
}
if engine.noRoute != nil {
finalSize++
} else {
finalSize += len(engine.noRoutes)
}
chain := make(HandlersChain, 0, finalSize)
chain = append(chain, engine.globalHandlers...)
if includeMethodNotAllowed {
chain = append(chain, methodNotAllowedHandler)
}
if includeUnmatchedFS {
chain = append(chain, engine.UnMatchFSRoutes...)
}
if engine.noRoute != nil {
chain = append(chain, engine.noRoute)
} else if len(engine.noRoutes) > 0 {
chain = append(chain, engine.noRoutes...)
}
chain = append(chain, notFoundHandler)
return chain
}
engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false)
engine.notFoundNoMethodChain = buildChain(false, false)
engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS)
engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS)
} }
// combineHandlers 组合多个处理函数链为一个 // combineHandlers 组合多个处理函数链为一个
@ -547,8 +682,9 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle
// Use 将全局中间件添加到 Engine // Use 将全局中间件添加到 Engine
// 这些中间件将应用于所有注册的路由 // 这些中间件将应用于所有注册的路由
func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { func (engine *Engine) Use(middleware ...HandlerFunc) Router {
engine.globalHandlers = append(engine.globalHandlers, middleware...) engine.globalHandlers = append(engine.globalHandlers, middleware...)
engine.rebuildFallbackChains()
return engine return engine
} }
@ -615,7 +751,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo {
// Group 创建一个新的路由组 // Group 创建一个新的路由组
// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起 // 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起
func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter { func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router {
return &RouterGroup{ return &RouterGroup{
Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件 Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件
basePath: resolveRoutePath("/", relativePath), basePath: resolveRoutePath("/", relativePath),
@ -624,7 +760,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute
} }
// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由 // RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由
// 它也实现了 IRouter 接口,允许嵌套分组 // 它也实现了 Router 接口,允许嵌套分组
type RouterGroup struct { type RouterGroup struct {
Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由
basePath string // 组路径前缀 basePath string // 组路径前缀
@ -633,7 +769,7 @@ type RouterGroup struct {
// Use 将中间件应用于当前路由组 // Use 将中间件应用于当前路由组
// 这些中间件将应用于当前组及其子组的所有路由 // 这些中间件将应用于当前组及其子组的所有路由
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter { func (group *RouterGroup) Use(middleware ...HandlerFunc) Router {
group.Handlers = append(group.Handlers, middleware...) group.Handlers = append(group.Handlers, middleware...)
return group return group
} }
@ -679,7 +815,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) {
} }
// Group 为当前组创建一个新的子组 // Group 为当前组创建一个新的子组
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter { func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router {
return &RouterGroup{ return &RouterGroup{
Handlers: group.engine.combineHandlers(group.Handlers, handlers), Handlers: group.engine.combineHandlers(group.Handlers, handlers),
basePath: resolveRoutePath(group.basePath, relativePath), basePath: resolveRoutePath(group.basePath, relativePath),
@ -704,8 +840,13 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// handleRequest 负责根据请求查找路由并执行相应的处理函数链 // handleRequest 负责根据请求查找路由并执行相应的处理函数链
// 这是路由查找和执行的核心逻辑 // 这是路由查找和执行的核心逻辑
func (engine *Engine) handleRequest(c *Context) { func (engine *Engine) handleRequest(c *Context) {
if isGeneralOptionsRequest(c.Request) {
engine.handleGeneralOptions(c)
return
}
httpMethod := c.Request.Method httpMethod := c.Request.Method
requestPath := c.Request.URL.Path requestPath := routeLookupPath(c.Request)
// 查找对应的路由树的根节点 // 查找对应的路由树的根节点
rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型 rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型
@ -725,7 +866,7 @@ func (engine *Engine) handleRequest(c *Context) {
} }
// 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复)
if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向 if httpMethod != http.MethodConnect && requestPath != "/" && !isGeneralOptionsRequest(c.Request) { // CONNECT 方法、服务器级 OPTIONS 和根路径不进行重定向
if value.tsr && engine.RedirectTrailingSlash { if value.tsr && engine.RedirectTrailingSlash {
// 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/ // 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/
redirectPath := requestPath redirectPath := requestPath
@ -737,51 +878,98 @@ func (engine *Engine) handleRequest(c *Context) {
c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向
return return
} }
// 尝试不区分大小写的查找 if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) {
// 直接在 rootNode 上调用 findCaseInsensitivePath 方法 // 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历.
ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash)
if found && engine.RedirectFixedPath { if found {
c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 c.fixedPathBuf = ciPath[:0]
c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径
return return
} }
c.fixedPathBuf = c.fixedPathBuf[:0]
}
} }
} }
// 构建处理链
// 组合全局中间件和路由处理函数
handlers := engine.globalHandlers
// 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由
// 则在全局中间件之后添加 MethodNotAllowed 处理器
if engine.HandleMethodNotAllowed {
handlers = append(handlers, MethodNotAllowed())
}
// 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed
// 则在处理链的最后添加 UnMatchFS 处理器
if engine.unMatchFS.ServeUnmatchedAsFS { if engine.unMatchFS.ServeUnmatchedAsFS {
/* c.handlers = engine.unmatchedFSChain
var unMatchFSHandle = c.engine.unMatchFileServer } else {
handlers = append(handlers, unMatchFSHandle) c.handlers = engine.notFoundChain
*/
handlers = append(handlers, engine.UnMatchFSRoutes...)
} }
// 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS
// 则在处理链的最后添加 NoRoute 处理器
if engine.noRoute != nil {
handlers = append(handlers, engine.noRoute)
} else if len(engine.noRoutes) > 0 {
handlers = append(handlers, engine.noRoutes...)
}
handlers = append(handlers, NotFound())
c.handlers = handlers
c.Next() // 执行处理函数链 c.Next() // 执行处理函数链
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送 //c.Writer.Flush() // 确保所有缓冲的响应数据被发送
} }
func routeLookupPath(req *http.Request) string {
if req == nil {
return ""
}
if req.Method == http.MethodConnect && req.RequestURI != "" && req.RequestURI != "*" && !strings.HasPrefix(req.RequestURI, "/") && !strings.Contains(req.RequestURI, "://") {
return "/" + req.RequestURI
}
if isGeneralOptionsRequest(req) {
return ""
}
if req.URL == nil {
return ""
}
return req.URL.Path
}
func isGeneralOptionsRequest(req *http.Request) bool {
return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*"
}
func shouldTryFixedPathLookup(path string, root *node) bool {
if root != nil && root.hasCaseInsensitivePath {
return true
}
for i := 0; i < len(path); i++ {
c := path[i]
if c >= utf8.RuneSelf {
return true
}
if c >= 'A' && c <= 'Z' {
return true
}
}
return false
}
func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string {
if cap(allowedMethods) < len(engine.methodTrees) {
allowedMethods = make([]string, 0, len(engine.methodTrees))
} else {
allowedMethods = allowedMethods[:0]
}
tempSkippedNodes := GetTempSkippedNodes()
for _, treeIter := range engine.methodTrees {
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
*tempSkippedNodes = (*tempSkippedNodes)[:0]
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
if value.handlers != nil {
allowedMethods = append(allowedMethods, treeIter.method)
}
}
PutTempSkippedNodes(tempSkippedNodes)
return allowedMethods
}
func (engine *Engine) handleGeneralOptions(c *Context) {
if c == nil || c.Request == nil {
return
}
c.Writer.Header().Set("Content-Length", "0")
if c.Request.ContentLength != 0 {
mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10)
_, _ = io.Copy(io.Discard, mb)
}
c.Writer.WriteHeader(http.StatusOK)
c.Abort()
}
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. // Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
// 它可以用于在长连接 (如 SSE) 中监听关闭信号. // 它可以用于在长连接 (如 SSE) 中监听关闭信号.
func (engine *Engine) Context() context.Context { func (engine *Engine) Context() context.Context {

71
engine_benchmark_test.go Normal file
View file

@ -0,0 +1,71 @@
package touka
import (
"net/http"
"net/http/httptest"
"testing"
)
var benchmarkStatusCode int
func buildServeHTTPBenchmarkEngine() *Engine {
engine := New()
engine.GET("/api/v1/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.GET("/api/v1/users/:id", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.POST("/api/v1/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
return engine
}
func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) {
b.Helper()
req, err := http.NewRequest(method, path, nil)
if err != nil {
b.Fatalf("failed to build request: %v", err)
}
rr := httptest.NewRecorder()
engine.ServeHTTP(rr, req)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr = httptest.NewRecorder()
engine.ServeHTTP(rr, req)
}
benchmarkStatusCode = rr.Code
}
func BenchmarkServeHTTP(b *testing.B) {
engine := buildServeHTTPBenchmarkEngine()
b.Run("StaticHit", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users")
})
b.Run("NotFound", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist")
})
b.Run("MethodNotAllowed", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users")
})
b.Run("OptionsAllow", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users")
})
b.Run("FixedPathRedirect", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS")
})
}

306
engine_test.go Normal file
View file

@ -0,0 +1,306 @@
package touka
import (
"bufio"
"encoding/json"
"errors"
"html/template"
"net"
"net/http"
"testing"
)
type failingResponseWriter struct {
header http.Header
status int
err error
}
func (w *failingResponseWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *failingResponseWriter) WriteHeader(statusCode int) {
if w.status == 0 {
w.status = statusCode
}
}
func (w *failingResponseWriter) Write(p []byte) (int, error) {
if w.status == 0 {
w.status = http.StatusOK
}
if w.err != nil {
return 0, w.err
}
return len(p), nil
}
func (w *failingResponseWriter) Flush() {}
func (w *failingResponseWriter) Status() int {
return w.status
}
func (w *failingResponseWriter) Size() int {
return 0
}
func (w *failingResponseWriter) Written() bool {
return w.status != 0
}
func (w *failingResponseWriter) IsHijacked() bool {
return false
}
func (w *failingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, http.ErrNotSupported
}
func TestHandleRequestRedirectFixedPath(t *testing.T) {
engine := New()
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" {
t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location)
}
}
func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) {
engine := New()
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code)
}
}
func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) {
engine := New()
engine.GET("/Users/Profile", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code)
}
if location := rr.Header().Get("Location"); location != "/Users/Profile" {
t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location)
}
}
func TestHandleRequestFixedPathLookupMissDoesNotPanic(t *testing.T) {
engine := New()
engine.GET("/Users/Profile", func(c *Context) {
c.Status(http.StatusNoContent)
})
defer func() {
if r := recover(); r != nil {
t.Fatalf("unexpected panic for fixed-path miss: %v", r)
}
}()
rr := PerformRequest(engine, http.MethodGet, "/users/unknown", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected fixed-path miss to stay as 404, got %d", rr.Code)
}
}
func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) {
engine := New()
engine.NoRoute(func(c *Context) {
c.Writer.Header().Set("X-NoRoute", "hit")
c.Next()
})
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code)
}
if got := rr.Header().Get("X-NoRoute"); got != "hit" {
t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got)
}
}
func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) {
engine := New()
engine.GET("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.NoRoute(func(c *Context) {
c.Writer.Header().Set("X-NoRoute", "hit")
c.Next()
})
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
if rr.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
}
if got := rr.Header().Get("X-NoRoute"); got != "" {
t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got)
}
}
func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) {
engine := New()
engine.GET("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.POST("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil)
if rr.Code != http.StatusOK {
t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code)
}
allow := rr.Header().Get("Allow")
if allow != "GET, POST" && allow != "POST, GET" {
t.Fatalf("expected Allow header to list matching methods, got %q", allow)
}
}
func TestDefaultErrorHandleJSONShape(t *testing.T) {
engine := New()
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code)
}
var body struct {
Code int `json:"code"`
Message string `json:"message"`
Error string `json:"error"`
}
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err)
}
if body.Code != http.StatusNotFound || body.Message != http.StatusText(http.StatusNotFound) || body.Error != "not found" {
t.Fatalf("unexpected error payload: %+v", body)
}
}
func TestDefaultMethodNotAllowedJSONShape(t *testing.T) {
engine := New()
engine.GET("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
if rr.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
}
var body struct {
Code int `json:"code"`
Message string `json:"message"`
Error string `json:"error"`
}
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err)
}
if body.Code != http.StatusMethodNotAllowed || body.Message != http.StatusText(http.StatusMethodNotAllowed) || body.Error != "method not allowed" {
t.Fatalf("unexpected error payload: %+v", body)
}
}
func TestCustomErrorHandlerStillOverridesDefaultFastPath(t *testing.T) {
engine := New()
engine.SetErrorHandler(func(c *Context, code int, err error) {
c.Writer.Header().Set("X-Custom-Error", "1")
c.String(code, "custom:%v", err)
})
engine.GET("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
if rr.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
}
if got := rr.Header().Get("X-Custom-Error"); got != "1" {
t.Fatalf("expected custom error header, got %q", got)
}
if rr.Body.String() != "custom:method not allowed" {
t.Fatalf("expected custom error body, got %q", rr.Body.String())
}
}
func TestResponseHelpersCaptureWriteErrors(t *testing.T) {
testCases := []struct {
name string
run func(*Context)
}{
{name: "Raw", run: func(c *Context) { c.Raw(http.StatusOK, "application/octet-stream", []byte("payload")) }},
{name: "String", run: func(c *Context) { c.String(http.StatusOK, "value=%d", 1) }},
{name: "Text", run: func(c *Context) { c.Text(http.StatusOK, "payload") }},
{name: "JSONBuf", run: func(c *Context) { c.JSONBuf(http.StatusOK, map[string]string{"a": "b"}) }},
{name: "GOBBuf", run: func(c *Context) { c.GOBBuf(http.StatusOK, struct{ A string }{A: "b"}) }},
{name: "WANFBuf", run: func(c *Context) { c.WANFBuf(http.StatusOK, map[string]string{"a": "b"}) }},
{name: "HTMLFallback", run: func(c *Context) { c.HTML(http.StatusOK, "page", map[string]string{"a": "b"}) }},
{name: "HTMLBuf", run: func(c *Context) {
c.engine.HTMLRender = template.Must(template.New("page").Parse(`{{.a}}`))
c.HTMLBuf(http.StatusOK, "page", map[string]string{"a": "b"})
}},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
writerErr := errors.New("write failed")
w := &failingResponseWriter{err: writerErr}
c, _ := CreateTestContext(w)
tc.run(c)
if got := len(c.Errors); got != 1 {
t.Fatalf("expected exactly one captured error, got %d", got)
}
if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) {
t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1])
}
})
}
}
func TestDefaultErrorFastPathCapturesWriteErrors(t *testing.T) {
writerErr := errors.New("write failed")
w := &failingResponseWriter{err: writerErr}
engine := New()
c, _ := CreateTestContext(w)
c.engine = engine
req, err := http.NewRequest(http.MethodGet, "/missing", nil)
if err != nil {
t.Fatalf("failed to build request: %v", err)
}
c.reset(w, req)
defaultErrorHandle(c, http.StatusNotFound, errNotFound)
if len(c.Errors) == 0 {
t.Fatal("expected write error to be captured")
}
if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) {
t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1])
}
if c.Writer.Status() != http.StatusNotFound {
t.Fatalf("expected status %d, got %d", http.StatusNotFound, c.Writer.Status())
}
if !c.IsAborted() {
t.Fatal("expected fast path to abort context")
}
}

103
examples/httpc/main.go Normal file
View file

@ -0,0 +1,103 @@
package main
import (
"fmt"
"net/http"
"github.com/infinite-iroha/touka"
)
func main() {
r := touka.Default()
// 示例 1简单 GET 请求(自动关联请求 Context
r.GET("/proxy", func(c *touka.Context) {
// 使用 HTTPC() 方法,自动关联请求 Context
// 当客户端断开连接时,出站请求也会自动取消
body, err := c.HTTPC().
GET("https://httpbin.org/get").
Text()
if err != nil {
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
return
}
c.String(http.StatusOK, "%s", body)
})
// 示例 2带 Header 的 POST 请求
r.POST("/users", func(c *touka.Context) {
var req struct {
Name string `json:"name"`
Email string `json:"email"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()})
return
}
var result struct {
ID int `json:"id"`
Name string `json:"name"`
}
// 链式调用,保持 httpc 风格
// 注意SetJSONBody 返回 (*RequestBuilder, error)
rb, err := c.HTTPC().
POST("https://httpbin.org/post").
SetHeader("X-API-Key", "secret").
SetJSONBody(req)
if err != nil {
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
return
}
if err := rb.DecodeJSON(&result); err != nil {
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, result)
})
// 示例 3带查询参数的请求
r.GET("/search", func(c *touka.Context) {
query := c.DefaultQuery("q", "")
page := c.DefaultQuery("page", "1")
var result struct {
Items []string `json:"items"`
Total int `json:"total"`
}
err := c.HTTPC().
GET("https://httpbin.org/get").
SetQueryParam("q", query).
SetQueryParam("page", page).
DecodeJSON(&result)
if err != nil {
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, result)
})
// 示例 4使用底层 httpc.Client旧方式仍可用但不推荐
r.GET("/legacy", func(c *touka.Context) {
// 旧方式:需要手动 WithContext
body, err := c.Client().
GET("https://httpbin.org/get").
WithContext(c.Context()).
Text()
if err != nil {
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
return
}
c.String(http.StatusOK, "%s", body)
})
fmt.Println("Server running on :8080")
fmt.Println("Try:")
fmt.Println(" curl http://localhost:8080/proxy")
fmt.Println(" curl -X POST -d '{\"name\":\"test\",\"email\":\"test@example.com\"}' http://localhost:8080/users")
fmt.Println(" curl 'http://localhost:8080/search?q=golang&page=1'")
// r.Run(touka.WithAddr(":8080"))
}

View file

@ -0,0 +1,71 @@
package main
import (
"fmt"
"log/slog"
"net/http"
"os"
"github.com/infinite-iroha/touka"
)
// SlogAdapter 将 slog.Logger 适配到 touka.Logger 接口
type SlogAdapter struct {
logger *slog.Logger
}
func NewSlogAdapter(handler slog.Handler) *SlogAdapter {
return &SlogAdapter{
logger: slog.New(handler),
}
}
func (s *SlogAdapter) Debugf(format string, args ...any) {
s.logger.Debug(fmt.Sprintf(format, args...))
}
func (s *SlogAdapter) Infof(format string, args ...any) {
s.logger.Info(fmt.Sprintf(format, args...))
}
func (s *SlogAdapter) Warnf(format string, args ...any) {
s.logger.Warn(fmt.Sprintf(format, args...))
}
func (s *SlogAdapter) Errorf(format string, args ...any) {
s.logger.Error(fmt.Sprintf(format, args...))
}
func (s *SlogAdapter) Fatalf(format string, args ...any) {
s.logger.Error(fmt.Sprintf(format, args...))
os.Exit(1)
}
func (s *SlogAdapter) Panicf(format string, args ...any) {
s.logger.Error(fmt.Sprintf(format, args...))
panic(fmt.Sprintf(format, args...))
}
func main() {
engine := touka.New()
// 使用 slog 替换默认的 reco.Logger
handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
Level: slog.LevelDebug,
})
slogAdapter := NewSlogAdapter(handler)
engine.SetLogger(slogAdapter)
engine.GET("/", func(c *touka.Context) {
c.Infof("request received: %s", c.Request.URL.Path)
c.JSON(http.StatusOK, map[string]string{"message": "hello"})
})
// 也可以获取 Logger 接口
logger := engine.GetLogger()
logger.Debugf("engine started")
// 也可以直接使用 slog
slog.Info("Server running", "addr", ":8080")
// engine.Run(":8080")
}

7
go.mod
View file

@ -3,14 +3,15 @@ module github.com/infinite-iroha/touka
go 1.26 go 1.26
require ( require (
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3
github.com/WJQSERVER-STUDIO/httpc v0.9.0 github.com/WJQSERVER-STUDIO/httpc v0.9.3
github.com/WJQSERVER/wanf v0.0.8 github.com/WJQSERVER/wanf v0.0.8
github.com/fenthope/reco v0.0.5 github.com/fenthope/reco v0.0.5
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433
golang.org/x/net v0.53.0
) )
require ( require (
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
golang.org/x/net v0.52.0 // indirect golang.org/x/text v0.36.0 // indirect
) )

14
go.sum
View file

@ -1,7 +1,7 @@
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbekZgxOuvVwiMfyM= github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 h1:Hc1O6D50U3URkdSzfQ/SgeUU750wUBCYhefdvAbE2Ck=
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok= github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3/go.mod h1:nFQzepAwwdj5Hp5U+X19l4FVvsaOSBTW41BzfI/CkMA=
github.com/WJQSERVER-STUDIO/httpc v0.9.0 h1:MpXcQQqukrSLHH/2tTfnXrhqD6nEDHB/gbzehXaS8o4= github.com/WJQSERVER-STUDIO/httpc v0.9.3 h1:wYZkz9f/+2WuDuzPlExebvnn0q6QeArM15Y51HJ5UUI=
github.com/WJQSERVER-STUDIO/httpc v0.9.0/go.mod h1:filzryrl4eAtFVyl4oVHcJqx1SpNFbrCn+ddQPLlCSg= github.com/WJQSERVER-STUDIO/httpc v0.9.3/go.mod h1:vtaDmN/8gN8Es1DJsGvvrFr8kErysJndu87i+KOWUHY=
github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8= github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8=
github.com/WJQSERVER/wanf v0.0.8/go.mod h1:R0Zw/1skEMVlQ9m5atbkmanlW+9h2bkdq7+wbPY+F/8= github.com/WJQSERVER/wanf v0.0.8/go.mod h1:R0Zw/1skEMVlQ9m5atbkmanlW+9h2bkdq7+wbPY+F/8=
github.com/fenthope/reco v0.0.5 h1:Z/bOunFf4LSgYP/IxG9fe2pTrIq7bPsDflflbNR5Agw= github.com/fenthope/reco v0.0.5 h1:Z/bOunFf4LSgYP/IxG9fe2pTrIq7bPsDflflbNR5Agw=
@ -10,5 +10,7 @@ github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 h1:vymEbVw
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg= github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=

88
http2xconnect.go Normal file
View file

@ -0,0 +1,88 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2026 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"crypto/tls"
"net"
"net/http"
"sync"
"time"
_ "unsafe"
"golang.org/x/net/http2"
)
var enableHTTP2ExtendedConnectOnce sync.Once
//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol
var xnetDisableHTTP2ExtendedConnectProtocol bool
func enableHTTP2ExtendedConnectProtocol() {
enableHTTP2ExtendedConnectOnce.Do(func() {
xnetDisableHTTP2ExtendedConnectProtocol = false
})
}
func configureHTTP2ExtendedConnectServer(srv *http.Server) error {
if srv == nil {
return nil
}
enableHTTP2ExtendedConnectProtocol()
return http2.ConfigureServer(srv, nil)
}
func newHTTP2ExtendedConnectTransport() http.RoundTripper {
enableHTTP2ExtendedConnectProtocol()
transport := cloneDefaultTransport()
transport.Protocols = new(http.Protocols)
transport.Protocols.SetHTTP1(true)
transport.Protocols.SetHTTP2(true)
return transport
}
func newHTTP1BridgeTransport() http.RoundTripper {
return newHTTP1BridgeTransportWithTLSConfig(&tls.Config{NextProtos: []string{"http/1.1"}})
}
func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper {
transport := cloneDefaultTransport()
transport.Protocols = new(http.Protocols)
transport.Protocols.SetHTTP1(true)
if tlsConfig == nil {
transport.TLSClientConfig = &tls.Config{}
} else {
transport.TLSClientConfig = tlsConfig.Clone()
}
if len(transport.TLSClientConfig.NextProtos) == 0 {
transport.TLSClientConfig.NextProtos = []string{"http/1.1"}
}
return transport
}
func newH2CTransport() http.RoundTripper {
transport := cloneDefaultTransport()
transport.Protocols = new(http.Protocols)
transport.Protocols.SetUnencryptedHTTP2(true)
return transport
}
func cloneDefaultTransport() *http.Transport {
if transport, ok := http.DefaultTransport.(*http.Transport); ok {
return transport.Clone()
}
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

150
iox_benchmark_test.go Normal file
View file

@ -0,0 +1,150 @@
package touka
import (
"bytes"
"io"
"testing"
"github.com/WJQSERVER-STUDIO/go-utils/iox"
)
type benchmarkResetReader struct {
data []byte
off int
}
func (r *benchmarkResetReader) Read(p []byte) (int, error) {
if r.off >= len(r.data) {
return 0, io.EOF
}
n := copy(p, r.data[r.off:])
r.off += n
return n, nil
}
func (r *benchmarkResetReader) Reset() {
r.off = 0
}
type benchmarkDiscardWriter struct{}
func (benchmarkDiscardWriter) Write(p []byte) (int, error) {
return len(p), nil
}
var benchmarkIOXResult int64
var benchmarkIOXBytes []byte
func BenchmarkIOXCopyComparison(b *testing.B) {
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
b.Run("io.Copy", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
w := benchmarkDiscardWriter{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
n, err := io.Copy(w, r)
if err != nil {
b.Fatalf("io.Copy failed: %v", err)
}
benchmarkIOXResult = n
}
})
b.Run("iox.Copy", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
w := benchmarkDiscardWriter{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
n, err := iox.Copy(w, r)
if err != nil {
b.Fatalf("iox.Copy failed: %v", err)
}
benchmarkIOXResult = n
}
})
}
func BenchmarkIOXCopyBufferComparison(b *testing.B) {
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
b.Run("io.CopyBuffer", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
w := benchmarkDiscardWriter{}
buf := make([]byte, 32*1024)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
n, err := io.CopyBuffer(w, r, buf)
if err != nil {
b.Fatalf("io.CopyBuffer failed: %v", err)
}
benchmarkIOXResult = n
}
})
b.Run("iox.CopyBuffer", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
w := benchmarkDiscardWriter{}
buf := make([]byte, 32*1024)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
n, err := iox.CopyBuffer(w, r, buf)
if err != nil {
b.Fatalf("iox.CopyBuffer failed: %v", err)
}
benchmarkIOXResult = n
}
})
}
func BenchmarkIOXReadAllComparison(b *testing.B) {
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
b.Run("io.ReadAll", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
data, err := io.ReadAll(r)
if err != nil {
b.Fatalf("io.ReadAll failed: %v", err)
}
benchmarkIOXBytes = data
}
})
b.Run("iox.ReadAll", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
data, err := io.ReadAll(r)
if err != nil {
b.Fatalf("iox.ReadAll failed: %v", err)
}
benchmarkIOXBytes = data
}
})
}

23
logger.go Normal file
View file

@ -0,0 +1,23 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2024 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
// Logger 是日志接口支持多种日志库实现reco、zap、logrus 等)
// 用户可以通过实现此接口来替换默认的日志实现
type Logger interface {
Debugf(format string, args ...any)
Infof(format string, args ...any)
Warnf(format string, args ...any)
Errorf(format string, args ...any)
Fatalf(format string, args ...any)
Panicf(format string, args ...any)
}
// CloserLogger 可选扩展接口,支持关闭操作
// 如果 Logger 实现了此接口Engine 在关闭时会调用 Close()
type CloserLogger interface {
Logger
Close() error
}

View file

@ -39,7 +39,16 @@ func CloseLogger(logger *reco.Logger) {
} }
} }
// CloseLogger 关闭 Engine 的日志实现
// 如果 logger 实现了 CloserLogger 接口,会调用其 Close 方法
func (engine *Engine) CloseLogger() { func (engine *Engine) CloseLogger() {
if cl, ok := engine.logger.(CloserLogger); ok {
if err := cl.Close(); err != nil {
log.Printf("Close Logger Error: %s", err)
}
return
}
// 兼容旧代码
if engine.LogReco != nil { if engine.LogReco != nil {
CloseLogger(engine.LogReco) CloseLogger(engine.LogReco)
} }

View file

@ -23,19 +23,21 @@ type maxBytesReader struct {
n int64 n int64
// read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数. // read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数.
read atomic.Int64 read atomic.Int64
// emptyAtLimit 记录在达到上限后是否已经遇到过一次 0,nil 读.
emptyAtLimit atomic.Bool
} }
// NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据, // NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据,
// 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误. // 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误.
// //
// 如果 r 为 nil, 会 panic. // 如果 r 为 nil, 会 panic.
// 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r. // 如果 n 小于等于 0, 则读取不受限制, 直接返回原始的 r.
func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
if r == nil { if r == nil {
panic("NewMaxBytesReader called with a nil reader") panic("NewMaxBytesReader called with a nil reader")
} }
// 如果限制为数, 意味着不限制, 直接返回原始的 ReadCloser. // 如果限制为非正数, 意味着不限制, 直接返回原始的 ReadCloser.
if n < 0 { if n <= 0 {
return r return r
} }
return &maxBytesReader{ return &maxBytesReader{
@ -46,48 +48,53 @@ func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
// Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制. // Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制.
func (mbr *maxBytesReader) Read(p []byte) (int, error) { func (mbr *maxBytesReader) Read(p []byte) (int, error) {
// 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. if len(p) == 0 {
readSoFar := mbr.read.Load() return 0, nil
// 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误.
if readSoFar >= mbr.n {
return 0, ErrBodyTooLarge
} }
// 计算当前还可以读取多少字节. // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销.
readSoFar := mbr.read.Load()
remaining := mbr.n - readSoFar remaining := mbr.n - readSoFar
if remaining < 0 {
return 0, ErrBodyTooLarge
}
if remaining == 0 {
var probe [1]byte
n, err := mbr.r.Read(probe[:])
if n > 0 {
mbr.read.Add(int64(n))
return 0, ErrBodyTooLarge
}
if err != nil {
return 0, err
}
if mbr.emptyAtLimit.Swap(true) {
return 0, ErrBodyTooLarge
}
return 0, nil
}
mbr.emptyAtLimit.Store(false)
// 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度. // 最多多读一个字节, 以区分“恰好到上限”和“已经超限”。
// 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数. if int64(len(p))-1 > remaining {
if int64(len(p)) > remaining { p = p[:remaining+1]
p = p[:remaining]
} }
// 从底层 Reader 读取数据. // 从底层 Reader 读取数据.
n, err := mbr.r.Read(p) n, err := mbr.r.Read(p)
// 如果实际读取到了数据, 更新原子计数器. if int64(n) <= remaining {
if n > 0 { if n > 0 {
readSoFar = mbr.read.Add(int64(n)) mbr.read.Add(int64(n))
} }
// 如果底层 Read 返回错误 (例如 io.EOF).
if err != nil {
// 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束.
// 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了.
return n, err return n, err
} }
// 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误. // 读取结果跨过了限制,只向上层暴露允许的部分。
// 这是处理"跨越"限制情况的关键. if remaining > 0 {
if readSoFar > mbr.n { mbr.read.Add(remaining)
// 返回实际读取的字节数 n, 并附上超限错误.
// 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭.
return n, ErrBodyTooLarge
} }
return int(remaining), ErrBodyTooLarge
// 一切正常, 返回读取的字节数和 nil 错误.
return n, nil
} }
// Close 方法关闭底层的 ReadCloser, 保证资源释放. // Close 方法关闭底层的 ReadCloser, 保证资源释放.

View file

@ -11,18 +11,16 @@ import (
) )
// mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型. // mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型.
// 嵌入 cancelCtx 作为基础 context, 支持 cause 传播.
// deadlineCtx 作为 cancelCtx 的子 context, 确保 deadline 到期时 cancelCtx 也被取消.
type mergedContext struct { type mergedContext struct {
// 嵌入一个基础 context, 它持有最早的 deadline 和取消信号.
context.Context context.Context
// 保存了所有的父 context, 用于 Value() 方法的查找.
parents []context.Context parents []context.Context
// 用于手动取消此 mergedContext 的函数.
cancel context.CancelFunc
} }
// MergeCtx 创建并返回一个新的 context.Context. // MergeCtx 创建并返回一个新的 context.Context.
// 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时, // 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时,
// 自动被取消 (逻辑或关系). // 自动被取消 (逻辑或关系). 父 context 的取消原因 (cause) 会自动传播到返回的 context.
// //
// 新的 context 会继承: // 新的 context 会继承:
// - Deadline: 所有父 context 中最早的截止时间. // - Deadline: 所有父 context 中最早的截止时间.
@ -32,7 +30,8 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C
return context.WithCancel(context.Background()) return context.WithCancel(context.Background())
} }
if len(parents) == 1 { if len(parents) == 1 {
return context.WithCancel(parents[0]) ctx, cancel := context.WithCancelCause(parents[0])
return ctx, func() { cancel(nil) }
} }
var earliestDeadline time.Time var earliestDeadline time.Time
@ -44,79 +43,93 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C
} }
} }
var baseCtx context.Context // cancelCtx 作为基础 context, 提供 CancelCauseFunc 以支持 cause 传播.
var baseCancel context.CancelFunc cancelCtx, cancelCause := context.WithCancelCause(context.Background())
// deadlineCtx 作为 cancelCtx 的子 context (如果有 deadline).
// 当 cancelCtx 被取消时, deadlineCtx 也会被取消;
// 当 deadline 到期时, deadlineCtx 自行取消, watcher 负责关闭 cancelCtx.
var deadlineCtx context.Context
var deadlineCancel context.CancelFunc
if !earliestDeadline.IsZero() { if !earliestDeadline.IsZero() {
baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline) deadlineCtx, deadlineCancel = context.WithDeadlineCause(cancelCtx, earliestDeadline, context.DeadlineExceeded)
} else { }
baseCtx, baseCancel = context.WithCancel(context.Background())
// 嵌入的 context: 有 deadline 时用 deadlineCtx (以返回正确的 Deadline),
// 否则用 cancelCtx.
embedCtx := cancelCtx
if deadlineCtx != nil {
embedCtx = deadlineCtx
} }
mc := &mergedContext{ mc := &mergedContext{
Context: baseCtx, Context: embedCtx,
parents: parents, parents: parents,
cancel: baseCancel,
} }
// 启动一个监控 goroutine. // 启动监控 goroutine, 监听 parent 取消或 deadline 到期.
go func() { go func() {
defer mc.cancel() // 将 cancelCtx 加入 orDone, 确保手动 cancel() 时 orDone goroutine 能退出, 防止泄漏.
parentDone := orDone(append(mc.parents, cancelCtx)...)
// orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭. if deadlineCtx != nil {
// 同时监听 baseCtx.Done() 以便支持手动取消. defer deadlineCancel()
select { select {
case <-orDone(mc.parents...): case <-parentDone:
case <-mc.Context.Done(): // parent 取消或手动 cancel()
for _, p := range mc.parents {
if p.Err() != nil {
cancelCause(context.Cause(p))
return
}
}
// 手动 cancel(), cause 已由 cancelCause() 设置
case <-deadlineCtx.Done():
// deadline 到期, 需要关闭 cancelCtx 并设置 cause
cancelCause(context.DeadlineExceeded)
}
} else {
<-parentDone
for _, p := range mc.parents {
if p.Err() != nil {
cancelCause(context.Cause(p))
return
}
}
} }
}() }()
return mc, mc.cancel return mc, func() { cancelCause(nil) }
} }
// Value 返回当前Ctx Value // Value 返回当前Ctx Value. 先检查嵌入的 context (以支持 context.Cause),
// 再按传入顺序从 parents 中查找.
func (mc *mergedContext) Value(key any) any { func (mc *mergedContext) Value(key any) any {
return mc.Context.Value(key) if v := mc.Context.Value(key); v != nil {
return v
}
for _, p := range mc.parents {
if val := p.Value(key); val != nil {
return val
}
}
return nil
} }
// Deadline 实现了 context.Context 的 Deadline 方法. // Deadline, Done, Err 均由嵌入的 context.Context 提供.
func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) {
return mc.Context.Deadline()
}
// Done 实现了 context.Context 的 Done 方法. // orDone 返回一个 channel, 当任意一个输入 context 的 Done() channel 关闭时关闭.
func (mc *mergedContext) Done() <-chan struct{} {
return mc.Context.Done()
}
// Err 实现了 context.Context 的 Err 方法.
func (mc *mergedContext) Err() error {
return mc.Context.Err()
}
// orDone 是一个辅助函数, 返回一个 channel.
// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭.
// 这是一个非阻塞的、不会泄漏 goroutine 的实现.
func orDone(contexts ...context.Context) <-chan struct{} { func orDone(contexts ...context.Context) <-chan struct{} {
done := make(chan struct{}) done := make(chan struct{})
var once sync.Once var once sync.Once
closeDone := func() {
once.Do(func() {
close(done)
})
}
// 为每个父 context 启动一个 goroutine.
for _, ctx := range contexts { for _, ctx := range contexts {
go func(c context.Context) { go func(c context.Context) {
select { select {
case <-c.Done(): case <-c.Done():
closeDone() once.Do(func() { close(done) })
case <-done: case <-done:
// orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出.
} }
}(ctx) }(ctx)
} }
return done return done
} }

256
mergectx_test.go Normal file
View file

@ -0,0 +1,256 @@
package touka
import (
"context"
"errors"
"testing"
"time"
)
func TestMergeCtx_NoParents(t *testing.T) {
ctx, cancel := MergeCtx()
defer cancel()
if ctx.Err() != nil {
t.Fatal("expected no error before cancel")
}
cancel()
if ctx.Err() == nil {
t.Fatal("expected error after cancel")
}
}
func TestMergeCtx_SingleParent(t *testing.T) {
parent, parentCancel := context.WithCancel(context.Background())
ctx, cancel := MergeCtx(parent)
defer cancel()
if ctx.Err() != nil {
t.Fatal("expected no error before parent cancel")
}
parentCancel()
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after parent cancel")
}
}
func TestMergeCtx_MultipleParents_FirstCancels(t *testing.T) {
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
cancel1()
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after p1 cancel")
}
// p2 should still be fine
if p2.Err() != nil {
t.Fatal("expected p2 to be unaffected")
}
}
func TestMergeCtx_MultipleParents_SecondCancels(t *testing.T) {
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel1()
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
cancel2()
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after p2 cancel")
}
}
func TestMergeCtx_ExternalCancel(t *testing.T) {
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel1()
defer cancel2()
ctx, cancel := MergeCtx(p1, p2)
cancel()
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after external cancel")
}
}
func TestMergeCtx_CausePropagation(t *testing.T) {
testErr := errors.New("test cause")
p1, cancel1 := context.WithCancelCause(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
cancel1(testErr)
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after p1 cancel")
}
cause := context.Cause(ctx)
if cause != testErr {
t.Fatalf("expected cause %v, got %v", testErr, cause)
}
cancel1(nil) // cleanup (already cancelled, no-op)
}
func TestMergeCtx_CausePropagation_SecondParent(t *testing.T) {
testErr := errors.New("second parent cause")
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancelCause(context.Background())
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
cancel2(testErr)
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after p2 cancel")
}
cause := context.Cause(ctx)
if cause != testErr {
t.Fatalf("expected cause %v, got %v", testErr, cause)
}
cancel1()
}
func TestMergeCtx_Deadline_Earliest(t *testing.T) {
now := time.Now()
early := now.Add(100 * time.Millisecond)
late := now.Add(1 * time.Hour)
p1, cancel1 := context.WithDeadline(context.Background(), late)
p2, cancel2 := context.WithDeadline(context.Background(), early)
defer cancel1()
defer cancel2()
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
dl, ok := ctx.Deadline()
if !ok {
t.Fatal("expected deadline to be set")
}
if !dl.Equal(early) {
t.Fatalf("expected deadline %v, got %v", early, dl)
}
}
func TestMergeCtx_Deadline_Expires(t *testing.T) {
p, cancelP := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancelP()
ctx, cancel := MergeCtx(p)
defer cancel()
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after deadline expires")
}
}
func TestMergeCtx_ValueLookup(t *testing.T) {
type key struct{}
p1 := context.WithValue(context.Background(), key{}, "from_p1")
p2 := context.WithValue(context.Background(), key{}, "from_p2")
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
val := ctx.Value(key{})
if val != "from_p1" {
t.Fatalf("expected 'from_p1', got %v", val)
}
}
func TestMergeCtx_ValueLookup_SecondParent(t *testing.T) {
type key1 struct{}
type key2 struct{}
p1 := context.WithValue(context.Background(), key1{}, "val1")
p2 := context.WithValue(context.Background(), key2{}, "val2")
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
if v := ctx.Value(key1{}); v != "val1" {
t.Fatalf("expected 'val1', got %v", v)
}
if v := ctx.Value(key2{}); v != "val2" {
t.Fatalf("expected 'val2', got %v", v)
}
if v := ctx.Value("missing"); v != nil {
t.Fatalf("expected nil, got %v", v)
}
}
func TestMergeCtx_ContextInterface(t *testing.T) {
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel1()
defer cancel2()
var ctx context.Context
ctx, _ = MergeCtx(p1, p2)
// Verify all Context interface methods work
_ = ctx.Done()
_ = ctx.Err()
_, _ = ctx.Deadline()
_ = ctx.Value("any")
}
func TestOrDone_SingleContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
done := orDone(ctx)
cancel()
<-done // should not block
}
func TestOrDone_MultipleContexts(t *testing.T) {
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
done := orDone(p1, p2)
cancel1()
<-done // should not block
}
func TestOrDone_SecondContextCancels(t *testing.T) {
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel1()
done := orDone(p1, p2)
cancel2()
<-done // should not block
}

View file

@ -70,42 +70,25 @@ func TestApplyDefaultServerConfig(t *testing.T) {
} }
} }
func TestRunTLSProtocolInheritance(t *testing.T) { func TestTLSRunDefaultsProtocolInheritance(t *testing.T) {
engine := New() engine := New()
// 模拟 RunTLS 中的逻辑: 如果使用默认协议, 则启用 HTTP/2 srv := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
if engine.useDefaultProtocols {
engine.setProtocols(&ProtocolsConfig{
Http1: true,
Http2: true,
})
}
srv := &http.Server{TLSConfig: &tls.Config{}}
engine.applyDefaultServerConfig(srv)
if !srv.Protocols.HTTP2() { if !srv.Protocols.HTTP2() {
t.Error("RunTLS simulation: Expected HTTP/2 to be enabled for default config") t.Error("TLS run defaults: expected HTTP/2 to be enabled for default config")
} }
// 模拟用户设置了自定义协议后调用 RunTLS // 模拟用户设置了自定义协议后进入 TLS 运行模式
engine = New() engine = New()
engine.SetProtocols(&ProtocolsConfig{ engine.SetProtocols(&ProtocolsConfig{
Http1: true, Http1: true,
Http2: false, // 用户明确不想要 HTTP/2 Http2: false, // 用户明确不想要 HTTP/2
}) })
if engine.useDefaultProtocols { srv2 := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
engine.setProtocols(&ProtocolsConfig{
Http1: true,
Http2: true,
})
}
srv2 := &http.Server{TLSConfig: &tls.Config{}}
engine.applyDefaultServerConfig(srv2)
if srv2.Protocols.HTTP2() { if srv2.Protocols.HTTP2() {
t.Error("RunTLS simulation: Expected HTTP/2 to be DISABLED if user set custom protocols previously") t.Error("TLS run defaults: expected HTTP/2 to remain disabled when user set custom protocols")
} }
} }

View file

@ -113,7 +113,7 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// 尝试从底层 ResponseWriter 获取 Hijacker 接口 // 尝试从底层 ResponseWriter 获取 Hijacker 接口
hj, ok := rw.ResponseWriter.(http.Hijacker) hj, ok := rw.ResponseWriter.(http.Hijacker)
if !ok { if !ok {
return nil, nil, errors.New("http.Hijacker interface not supported") return nil, nil, http.ErrNotSupported
} }
// 调用底层的 Hijack 方法 // 调用底层的 Hijack 方法

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,355 @@
package touka
import (
"bufio"
"bytes"
"errors"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
)
type benchmarkReadSeeker struct {
data []byte
off int
}
func (r *benchmarkReadSeeker) Read(p []byte) (int, error) {
if r.off >= len(r.data) {
return 0, io.EOF
}
n := copy(p, r.data[r.off:])
r.off += n
return n, nil
}
func (r *benchmarkReadSeeker) Reset() {
r.off = 0
}
type benchmarkResponseWriter struct {
header http.Header
status int
size int
}
func newBenchmarkResponseWriter() *benchmarkResponseWriter {
return &benchmarkResponseWriter{header: make(http.Header)}
}
func (w *benchmarkResponseWriter) Header() http.Header {
return w.header
}
func (w *benchmarkResponseWriter) WriteHeader(statusCode int) {
if w.status == 0 {
w.status = statusCode
}
}
func (w *benchmarkResponseWriter) Write(p []byte) (int, error) {
if w.status == 0 {
w.status = http.StatusOK
}
w.size += len(p)
return len(p), nil
}
func (w *benchmarkResponseWriter) Flush() {}
func (w *benchmarkResponseWriter) Status() int {
return w.status
}
func (w *benchmarkResponseWriter) Size() int {
return w.size
}
func (w *benchmarkResponseWriter) Written() bool {
return w.status != 0
}
func (w *benchmarkResponseWriter) IsHijacked() bool {
return false
}
func (w *benchmarkResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, http.ErrNotSupported
}
func (w *benchmarkResponseWriter) reset() {
clear(w.header)
w.status = 0
w.size = 0
}
var benchmarkReverseProxySink int
func BenchmarkReverseProxyCopyResponse(b *testing.B) {
body := bytes.Repeat([]byte("0123456789abcdef"), 4096)
proxy := newReverseProxyHandler(ReverseProxyConfig{})
dst := newBenchmarkResponseWriter()
src := &benchmarkReadSeeker{data: body}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
dst.reset()
src.Reset()
if err := proxy.copyResponse(dst, src, 0); err != nil {
b.Fatalf("copyResponse failed: %v", err)
}
}
benchmarkReverseProxySink = dst.Size()
}
func BenchmarkReverseProxyAvailableUpstreams(b *testing.B) {
proxy := &reverseProxyHandler{
upstreams: []*reverseProxyUpstream{
{key: "a", index: 0},
{key: "b", index: 1},
{key: "c", index: 2},
{key: "d", index: 3},
},
config: ReverseProxyConfig{
PassiveHealth: ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
MaxFails: 3,
},
},
}
now := time.Now()
proxy.upstreams[0].failures = []time.Time{now.Add(-30 * time.Second)}
proxy.upstreams[1].failures = []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkReverseProxySink = len(proxy.availableUpstreams(now, nil))
}
}
func BenchmarkReverseProxySelectUpstream(b *testing.B) {
proxy := &reverseProxyHandler{
upstreams: []*reverseProxyUpstream{
{key: "a", index: 0},
{key: "b", index: 1},
{key: "c", index: 2},
{key: "d", index: 3},
},
config: ReverseProxyConfig{
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()},
PassiveHealth: ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
MaxFails: 3,
},
},
}
proxy.upstreams[0].failures = []time.Time{time.Now().Add(-30 * time.Second)}
c, _ := CreateTestContext(nil)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
selected, err := proxy.selectUpstream(c, nil)
if err != nil {
b.Fatalf("selectUpstream failed: %v", err)
}
benchmarkReverseProxySink = selected.index
}
}
func BenchmarkReverseProxySelectUpstreamHeaderPolicy(b *testing.B) {
proxy := &reverseProxyHandler{
upstreams: []*reverseProxyUpstream{
{key: "a", index: 0},
{key: "b", index: 1},
{key: "c", index: 2},
{key: "d", index: 3},
},
config: ReverseProxyConfig{
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())},
},
}
c, _ := CreateTestContext(nil)
c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b", "tenant-c"}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
selected, err := proxy.selectUpstream(c, nil)
if err != nil {
b.Fatalf("selectUpstream failed: %v", err)
}
benchmarkReverseProxySink = selected.index
}
}
func TestReverseProxyCopyResponseWithoutBufferPool(t *testing.T) {
proxy := newReverseProxyHandler(ReverseProxyConfig{})
dst := newBenchmarkResponseWriter()
src := bytes.NewBufferString("hello, reverse proxy")
if err := proxy.copyResponse(dst, src, 0); err != nil {
t.Fatalf("copyResponse failed: %v", err)
}
if got, want := dst.Size(), len("hello, reverse proxy"); got != want {
t.Fatalf("expected %d bytes copied, got %d", want, got)
}
}
type fixedLenBufferPool struct {
buf []byte
}
func (p *fixedLenBufferPool) Get() []byte {
return p.buf
}
func (p *fixedLenBufferPool) Put(buf []byte) {
p.buf = buf
}
type recordingReader struct {
chunk int
reads []int
left int
}
func (r *recordingReader) Read(p []byte) (int, error) {
if r.left == 0 {
return 0, io.EOF
}
n := min(r.chunk, len(p), r.left)
if n == 0 {
return 0, errors.New("reader received zero-length buffer")
}
for i := range n {
p[i] = 'x'
}
r.left -= n
r.reads = append(r.reads, len(p))
return n, nil
}
func TestReverseProxyCopyResponseRespectsCustomBufferLength(t *testing.T) {
pool := &fixedLenBufferPool{buf: make([]byte, 8, 32*1024)}
proxy := newReverseProxyHandler(ReverseProxyConfig{BufferPool: pool})
dst := newBenchmarkResponseWriter()
src := &recordingReader{chunk: 8, left: 24}
if err := proxy.copyResponse(dst, src, 0); err != nil {
t.Fatalf("copyResponse failed: %v", err)
}
if len(src.reads) == 0 {
t.Fatal("expected reader to be used")
}
for _, size := range src.reads {
if size != 8 {
t.Fatalf("expected custom buffer length 8 to be preserved, got read size %d", size)
}
}
}
func TestReverseProxyAvailableUpstreamsFiltersExcludedAndUnhealthy(t *testing.T) {
now := time.Now()
proxy := &reverseProxyHandler{
upstreams: []*reverseProxyUpstream{
{key: "a"},
{key: "b", failures: []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)}},
{key: "c"},
},
config: ReverseProxyConfig{
PassiveHealth: ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
MaxFails: 2,
},
},
}
available := proxy.availableUpstreams(now, map[string]struct{}{"c": {}})
if len(available) != 1 {
t.Fatalf("expected only one available upstream, got %d", len(available))
}
if available[0].key != "a" {
t.Fatalf("expected upstream 'a', got %q", available[0].key)
}
}
func TestReverseProxyHeaderPolicyUsesAllHeaderValues(t *testing.T) {
proxy := &reverseProxyHandler{
upstreams: []*reverseProxyUpstream{
{key: "a", index: 0},
{key: "b", index: 1},
{key: "c", index: 2},
},
config: ReverseProxyConfig{
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())},
},
}
c, _ := CreateTestContext(nil)
c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b"}
selectedA, err := proxy.selectUpstream(c, nil)
if err != nil {
t.Fatalf("selectUpstream failed: %v", err)
}
selectedB, err := proxy.selectUpstream(c, nil)
if err != nil {
t.Fatalf("selectUpstream failed: %v", err)
}
if selectedA.key != selectedB.key {
t.Fatalf("expected stable selection for identical multi-value header, got %q and %q", selectedA.key, selectedB.key)
}
c.Request.Header["X-Tenant"] = []string{"tenant-b", "tenant-a"}
selectedC, err := proxy.selectUpstream(c, nil)
if err != nil {
t.Fatalf("selectUpstream failed: %v", err)
}
if selectedC == nil {
t.Fatal("expected upstream for reordered multi-value header")
}
}
func TestReverseProxyHeaderPolicyMatchesJoinCompatibility(t *testing.T) {
candidates := []*reverseProxyUpstream{
{key: "a", index: 0},
{key: "b", index: 1},
{key: "c", index: 2},
}
testCases := [][]string{
{"tenant-a"},
{"tenant-a", "tenant-b"},
{"", "tenant-b"},
{"tenant-a", ""},
{"", ""},
}
for _, values := range testCases {
got := reverseProxySelectHRWValues(candidates, values)
want := reverseProxySelectHRW(candidates, strings.Join(values, ","))
if got == nil || want == nil {
t.Fatalf("expected non-nil upstreams for values %v", values)
}
if got.key != want.key {
t.Fatalf("expected joined compatibility for values %v, got %q want %q", values, got.key, want.key)
}
}
}
var _ io.Writer = (*benchmarkResponseWriter)(nil)

View file

@ -0,0 +1,530 @@
package touka
import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"testing"
)
func TestReverseProxyHeaderOpsReplaceSubstring(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Server"); got != "Caddy" {
t.Errorf("expected X-Server=Caddy, got %q", got)
}
if got := r.Header.Get("X-Location"); got != "/api/v2/resource" {
t.Errorf("expected X-Location=/api/v2/resource, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Replace: map[string][]Replacement{
"X-Server": {{Search: "NGINX", Replace: "Caddy"}},
"X-Location": {{Search: "v1", Replace: "v2"}},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
req.Header.Set("X-Server", "NGINX")
req.Header.Set("X-Location", "/api/v1/resource")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyHeaderOpsReplaceRegexp(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Route"); got != "/proxy-upstream" {
t.Errorf("expected X-Route=/proxy-upstream, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Replace: map[string][]Replacement{
"X-Route": {{SearchRegexp: `^/([^/]+)/(.+)$`, Replace: "/proxy-$2"}},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
req.Header.Set("X-Route", "/original/upstream")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyHeaderOpsReplaceWildcard(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Host-A"); got != "new.example.com" {
t.Errorf("expected X-Host-A=new.example.com, got %q", got)
}
if got := r.Header.Get("X-Host-B"); got != "new.example.com" {
t.Errorf("expected X-Host-B=new.example.com, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Replace: map[string][]Replacement{
"*": {{Search: "old.example.com", Replace: "new.example.com"}},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
req.Header.Set("X-Host-A", "old.example.com")
req.Header.Set("X-Host-B", "old.example.com")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyHeaderOpsReplaceResponse(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Backend", "backend-internal:8080")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
ResponseHeaders: &RespHeaderOps{
HeaderOps: &HeaderOps{
Replace: map[string][]Replacement{
"X-Backend": {{Search: "backend-internal:8080", Replace: "public.example.com"}},
},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
resp, err := http.Get(proxy.URL + "/test")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if got := resp.Header.Get("X-Backend"); got != "public.example.com" {
t.Errorf("expected X-Backend=public.example.com, got %q", got)
}
}
func TestReverseProxyHeaderOpsProvisionInvalidRegexp(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Replace: map[string][]Replacement{
"X-Test": {{SearchRegexp: "[invalid"}},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusInternalServerError {
t.Errorf("expected status 500, got %d", resp.StatusCode)
}
}
func TestReplacementApply(t *testing.T) {
tests := []struct {
name string
r *Replacement
s string
want string
}{
{name: "nil replacement", r: nil, s: "hello", want: "hello"},
{name: "empty string", r: &Replacement{Search: "x", Replace: "y"}, s: "", want: ""},
{name: "substring match", r: &Replacement{Search: "world", Replace: "go"}, s: "hello world", want: "hello go"},
{name: "substring no match", r: &Replacement{Search: "foo", Replace: "bar"}, s: "hello world", want: "hello world"},
{name: "substring multiple", r: &Replacement{Search: "a", Replace: "b"}, s: "aaa", want: "bbb"},
{name: "regexp match", r: &Replacement{SearchRegexp: `\d+`, Replace: "N", re: regexp.MustCompile(`\d+`)}, s: "abc123def", want: "abcNdef"},
{name: "regexp no match", r: &Replacement{SearchRegexp: `z+`, Replace: "Z", re: regexp.MustCompile(`z+`)}, s: "abc", want: "abc"},
{name: "empty search and regexp", r: &Replacement{}, s: "unchanged", want: "unchanged"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.r.apply(tt.s); got != tt.want {
t.Errorf("Replacement.apply() = %q, want %q", got, tt.want)
}
})
}
}
func BenchmarkHeaderOpsAdd(b *testing.B) {
ops := &HeaderOps{
Add: map[string][]string{
"X-Custom-1": {"value-1"},
"X-Custom-2": {"value-2"},
"X-Custom-3": {"value-3"},
},
}
hdr := make(http.Header)
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr = make(http.Header)
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsSet(b *testing.B) {
ops := &HeaderOps{
Set: map[string][]string{
"X-Frame-Options": {"DENY"},
"X-Content-Type-Options": {"nosniff"},
"X-XSS-Protection": {"1; mode=block"},
},
}
hdr := make(http.Header)
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr = make(http.Header)
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsDeleteSingle(b *testing.B) {
ops := &HeaderOps{
Delete: []string{"X-Powered-By"},
}
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr := make(http.Header)
hdr.Set("X-Powered-By", "Express")
hdr.Set("X-Keep", "value")
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsDeleteWildcard(b *testing.B) {
ops := &HeaderOps{
Delete: []string{"X-Debug-*"},
}
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr := make(http.Header)
hdr.Set("X-Debug-1", "v1")
hdr.Set("X-Debug-2", "v2")
hdr.Set("X-Keep", "value")
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsReplaceSubstring(b *testing.B) {
ops := &HeaderOps{
Replace: map[string][]Replacement{
"Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}},
},
}
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr := make(http.Header)
hdr.Set("Location", "http://internal:8080/api/v1/users")
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsReplaceRegexp(b *testing.B) {
re := regexp.MustCompile(`^http://([^/]+)(/.*)$`)
ops := &HeaderOps{
Replace: map[string][]Replacement{
"Location": {{SearchRegexp: `^http://([^/]+)(/.*)$`, Replace: "https://public.example.com$2", re: re}},
},
}
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr := make(http.Header)
hdr.Set("Location", "http://internal:8080/api/v1/users")
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsReplaceWildcard(b *testing.B) {
ops := &HeaderOps{
Replace: map[string][]Replacement{
"*": {{Search: "internal.example.com", Replace: "public.example.com"}},
},
}
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr := make(http.Header)
hdr.Set("X-Host", "internal.example.com")
hdr.Set("X-Origin", "internal.example.com")
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsMixed(b *testing.B) {
ops := &HeaderOps{
Add: map[string][]string{
"X-Request-ID": {"req-123"},
},
Set: map[string][]string{
"X-Frame-Options": {"DENY"},
},
Delete: []string{"X-Powered-By"},
Replace: map[string][]Replacement{
"Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}},
},
}
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr := make(http.Header)
hdr.Set("X-Powered-By", "Express")
hdr.Set("Location", "http://internal:8080/api")
ops.applyTo(hdr, repl)
}
}
func BenchmarkReplacementApplySubstring(b *testing.B) {
r := &Replacement{Search: "old.example.com", Replace: "new.example.com"}
s := "https://old.example.com/api/v1/resource"
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = r.apply(s)
}
}
func BenchmarkReplacementApplyRegexp(b *testing.B) {
r := &Replacement{SearchRegexp: `^https?://[^/]+`, Replace: "https://new.example.com", re: regexp.MustCompile(`^https?://[^/]+`)}
s := "https://old.example.com/api/v1/resource"
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = r.apply(s)
}
}
func TestReverseProxyReplacerDynamicVars(t *testing.T) {
req, _ := http.NewRequest(http.MethodGet, "http://example.com/api/v1/users?sort=name&limit=10", nil)
req.Host = "example.com"
repl := newReverseProxyReplacer(req)
tests := []struct {
name string
input string
want string
}{
{"method", "{method}", "GET"},
{"host", "{host}", "example.com"},
{"path", "{path}", "/api/v1/users"},
{"query", "{query}", "sort=name&limit=10"},
{"scheme", "{scheme}", "http"},
{"proto", "{proto}", "HTTP/1.1"},
{"combined", "X-{method}-{path}", "X-GET-/api/v1/users"},
{"no vars", "static-value", "static-value"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := repl.Replace(tt.input); got != tt.want {
t.Errorf("Replace(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestReverseProxyReplacerNilRequest(t *testing.T) {
repl := newReverseProxyReplacer(nil)
if got := repl.Replace("{method}"); got != "{method}" {
t.Errorf("expected unchanged string with nil request, got %q", got)
}
}
func TestReverseProxyReplacerNilReplacer(t *testing.T) {
var repl *reverseProxyReplacer
if got := repl.Replace("{method}"); got != "{method}" {
t.Errorf("expected unchanged string with nil replacer, got %q", got)
}
}
func TestReverseProxyReplacerFromHeader(t *testing.T) {
hdr := make(http.Header)
repl := newReverseProxyReplacerFromHeader(hdr)
if got := repl.Replace("{method}"); got != "{method}" {
t.Errorf("expected unchanged string from header replacer, got %q", got)
}
}
func TestReverseProxyHeaderOpsWithDynamicVars(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Forwarded-Path"); got != "/dynamic/path" {
t.Errorf("expected X-Forwarded-Path=/dynamic/path, got %q", got)
}
if got := r.Header.Get("X-Forwarded-Method"); got != "GET" {
t.Errorf("expected X-Forwarded-Method=GET, got %q", got)
}
if got := r.Header.Get("X-Forwarded-Host"); got != "client.example" {
t.Errorf("expected X-Forwarded-Host=client.example, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/dynamic/path", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Add: map[string][]string{
"X-Forwarded-Path": {"{path}"},
"X-Forwarded-Method": {"{method}"},
"X-Forwarded-Host": {"{host}"},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/dynamic/path", nil)
req.Host = "client.example"
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}

View file

@ -0,0 +1,220 @@
package touka
import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
func TestReverseProxyHeaderOpsAdd(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Custom-Header"); got != "test-value" {
t.Errorf("expected X-Custom-Header=test-value, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Add: map[string][]string{
"X-Custom-Header": {"test-value"},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
resp, err := http.Get(proxy.URL + "/test")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyHeaderOpsDelete(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Sensitive") != "" {
t.Errorf("expected X-Sensitive header to be deleted, got %q", r.Header.Get("X-Sensitive"))
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Delete: []string{"X-Sensitive"},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
req.Header.Set("X-Sensitive", "should-be-removed")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyHeaderOpsSet(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got := r.Header.Get("X-Replace")
if got != "new-value" {
t.Errorf("expected X-Replace=new-value, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Set: map[string][]string{
"X-Replace": {"new-value"},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
req.Header.Set("X-Replace", "old-value")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyResponseHeaderOps(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Backend", "backend-server")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
ResponseHeaders: &RespHeaderOps{
HeaderOps: &HeaderOps{
Set: map[string][]string{
"X-Custom": {"custom-value"},
},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
resp, err := http.Get(proxy.URL + "/test")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if got := resp.Header.Get("X-Custom"); got != "custom-value" {
t.Errorf("expected X-Custom=custom-value, got %q", got)
}
}
func TestReverseProxyResponseHeaderOpsDelete(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Powered-By", "Express")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
ResponseHeaders: &RespHeaderOps{
HeaderOps: &HeaderOps{
Delete: []string{"X-Powered-By"},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
resp, err := http.Get(proxy.URL + "/test")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if got := resp.Header.Get("X-Powered-By"); got != "" {
t.Errorf("expected X-Powered-By to be deleted, got %q", got)
}
}

409
reverseproxy_lb.go Normal file
View file

@ -0,0 +1,409 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2026 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"fmt"
"math/rand/v2"
"net/http"
"net/textproto"
"net/url"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
)
// ReverseProxyLoadBalancingConfig configures upstream selection and retries.
type ReverseProxyLoadBalancingConfig struct {
Policy ReverseProxyLBPolicy
Retries int
TryDuration time.Duration
TryInterval time.Duration
}
// ReverseProxyPassiveHealthConfig configures inline passive health tracking.
type ReverseProxyPassiveHealthConfig struct {
FailDuration time.Duration
MaxFails int
UnhealthyStatus []int
}
// ReverseProxyLBPolicy selects an upstream from the configured target pool.
// Use the helper constructors such as LBRandom or LBHeader to build a policy.
type ReverseProxyLBPolicy struct {
kind reverseProxyLBPolicyKind
key string
fallback *ReverseProxyLBPolicy
}
type reverseProxyLBPolicyKind uint8
const (
reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota
reverseProxyLBPolicyRoundRobin
reverseProxyLBPolicyFirst
reverseProxyLBPolicyLeastConn
reverseProxyLBPolicyIPHash
reverseProxyLBPolicyClientIPHash
reverseProxyLBPolicyURIHash
reverseProxyLBPolicyHeader
reverseProxyLBPolicyQuery
)
type reverseProxyUpstream struct {
key string
target *url.URL
index int
useH2C bool
extendedConnectTransport http.RoundTripper
bridgeTransport http.RoundTripper
h2cTransport http.RoundTripper
inFlight atomic.Int64
passiveMu sync.Mutex
failures []time.Time
}
func LBRandom() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom}
}
func LBRoundRobin() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin}
}
func LBFirst() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst}
}
func LBLeastConn() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn}
}
func LBIPHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash}
}
func LBClientIPHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash}
}
func LBURIHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash}
}
func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))}
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
policy.fallback = &fallback
}
return policy
}
func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)}
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
policy.fallback = &fallback
}
return policy
}
func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error {
switch policy.kind {
case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst,
reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash,
reverseProxyLBPolicyURIHash:
return nil
case reverseProxyLBPolicyHeader:
if policy.key == "" {
return fmt.Errorf("reverse proxy header load-balancing policy requires a header field")
}
case reverseProxyLBPolicyQuery:
if policy.key == "" {
return fmt.Errorf("reverse proxy query load-balancing policy requires a query key")
}
default:
return fmt.Errorf("reverse proxy load-balancing policy is invalid")
}
if policy.fallback != nil {
return validateReverseProxyLBPolicy(*policy.fallback)
}
return nil
}
func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) {
now := time.Now()
policy := p.config.LoadBalancing.Policy
candidateBuf := reverseProxyCandidatePool.Get().(*[]*reverseProxyUpstream)
candidates := p.availableUpstreamsInto(now, excluded, *candidateBuf)
if len(candidates) == 0 && len(excluded) > 0 {
candidates = p.availableUpstreamsInto(now, nil, candidates[:0])
}
if len(candidates) == 0 {
*candidateBuf = candidates[:0]
reverseProxyCandidatePool.Put(candidateBuf)
return nil, errReverseProxyNoAvailableUpstreams
}
selected := p.selectUpstreamWithPolicy(c, candidates, policy)
*candidateBuf = candidates[:0]
reverseProxyCandidatePool.Put(candidateBuf)
return selected, nil
}
func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream {
return p.availableUpstreamsInto(now, excluded, nil)
}
func (p *reverseProxyHandler) availableUpstreamsInto(now time.Time, excluded map[string]struct{}, candidates []*reverseProxyUpstream) []*reverseProxyUpstream {
if cap(candidates) < len(p.upstreams) {
candidates = make([]*reverseProxyUpstream, 0, len(p.upstreams))
} else {
candidates = candidates[:0]
}
for _, upstream := range p.upstreams {
if _, skip := excluded[upstream.key]; skip {
continue
}
if !upstream.healthy(now, p.config.PassiveHealth) {
continue
}
candidates = append(candidates, upstream)
}
return candidates
}
func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
switch policy.kind {
case reverseProxyLBPolicyRoundRobin:
return candidates[p.nextRoundRobinIndex(len(candidates))]
case reverseProxyLBPolicyFirst:
return candidates[0]
case reverseProxyLBPolicyLeastConn:
return p.selectLeastConnUpstream(candidates)
case reverseProxyLBPolicyIPHash:
return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr))
case reverseProxyLBPolicyClientIPHash:
return reverseProxySelectHRW(candidates, c.RequestIP())
case reverseProxyLBPolicyURIHash:
if c.Request == nil || c.Request.URL == nil {
return reverseProxySelectRandom(candidates)
}
return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI())
case reverseProxyLBPolicyHeader:
if c.Request != nil && c.Request.Header != nil {
if values, ok := c.Request.Header[policy.key]; ok {
return reverseProxySelectHRWValues(candidates, values)
}
}
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
case reverseProxyLBPolicyQuery:
if c.Request != nil && c.Request.URL != nil {
if values, ok := c.Request.URL.Query()[policy.key]; ok {
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
}
}
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
case reverseProxyLBPolicyRandom:
fallthrough
default:
return reverseProxySelectRandom(candidates)
}
}
func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int {
if size <= 1 {
return 0
}
return int((p.roundRobin.Add(1) - 1) % uint64(size))
}
func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
selected := candidates[0]
lowest := selected.inFlight.Load()
ties := []*reverseProxyUpstream{selected}
for _, upstream := range candidates[1:] {
count := upstream.inFlight.Load()
switch {
case count < lowest:
selected = upstream
lowest = count
ties = []*reverseProxyUpstream{upstream}
case count == lowest:
ties = append(ties, upstream)
}
}
if len(ties) == 1 {
return selected
}
return ties[p.nextRoundRobinIndex(len(ties))]
}
func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
if len(candidates) == 1 {
return candidates[0]
}
return candidates[rand.IntN(len(candidates))]
}
func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
if key == "" {
return reverseProxySelectRandom(candidates)
}
selected := candidates[0]
bestScore := reverseProxyHRWScore(key, selected.key)
for _, upstream := range candidates[1:] {
score := reverseProxyHRWScore(key, upstream.key)
if score > bestScore {
selected = upstream
bestScore = score
}
}
return selected
}
func reverseProxySelectHRWValues(candidates []*reverseProxyUpstream, values []string) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
if len(values) == 0 {
return reverseProxySelectRandom(candidates)
}
selected := candidates[0]
bestScore := reverseProxyHRWValuesScore(values, selected.key)
for _, upstream := range candidates[1:] {
score := reverseProxyHRWValuesScore(values, upstream.key)
if score > bestScore {
selected = upstream
bestScore = score
}
}
return selected
}
func reverseProxyHRWScore(key, upstreamKey string) uint64 {
const (
offset64 = 14695981039346656037
prime64 = 1099511628211
)
h := uint64(offset64)
for i := 0; i < len(key); i++ {
h ^= uint64(key[i])
h *= prime64
}
h ^= 0xff
h *= prime64
for i := 0; i < len(upstreamKey); i++ {
h ^= uint64(upstreamKey[i])
h *= prime64
}
return h
}
func reverseProxyHRWValuesScore(values []string, upstreamKey string) uint64 {
const (
offset64 = 14695981039346656037
prime64 = 1099511628211
)
h := uint64(offset64)
for valueIndex, value := range values {
for i := 0; i < len(value); i++ {
h ^= uint64(value[i])
h *= prime64
}
if valueIndex+1 < len(values) {
h ^= ','
h *= prime64
}
}
h ^= 0xff
h *= prime64
for i := 0; i < len(upstreamKey); i++ {
h ^= uint64(upstreamKey[i])
h *= prime64
}
return h
}
func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy {
if policy.fallback != nil {
return *policy.fallback
}
return LBRandom()
}
func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool {
maxFails := reverseProxyPassiveMaxFails(config)
if config.FailDuration <= 0 || maxFails <= 0 {
return true
}
u.passiveMu.Lock()
defer u.passiveMu.Unlock()
u.pruneFailuresLocked(now, config.FailDuration)
return len(u.failures) < maxFails
}
func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) {
maxFails := reverseProxyPassiveMaxFails(config)
if config.FailDuration <= 0 || maxFails <= 0 {
return
}
u.passiveMu.Lock()
defer u.passiveMu.Unlock()
u.pruneFailuresLocked(now, config.FailDuration)
u.failures = append(u.failures, now)
}
func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) {
if len(u.failures) == 0 || window <= 0 {
if window <= 0 {
u.failures = nil
}
return
}
cutoff := now.Add(-window)
keep := 0
for _, failureAt := range u.failures {
if failureAt.Before(cutoff) {
continue
}
u.failures[keep] = failureAt
keep++
}
u.failures = u.failures[:keep]
}
func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int {
if config.FailDuration <= 0 {
return 0
}
if config.MaxFails <= 0 {
return 1
}
return config.MaxFails
}
func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool {
if status <= 0 {
return false
}
return slices.Contains(config.UnhealthyStatus, status)
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,130 @@
package touka
import "testing"
var (
benchmarkRouteHandlers HandlersChain
benchmarkRouteFullPath string
benchmarkRouteParamsLen int
benchmarkRouteCIPath []byte
benchmarkRouteCIFound bool
)
func buildRouteMatchBenchmarkTree() *node {
tree := &node{}
routes := []string{
"/",
"/health",
"/contact",
"/api/v1/users",
"/api/v1/users/:id",
"/api/v1/users/:id/settings",
"/assets/*filepath",
"/abc/b",
"/abc/:p1/cde",
"/abc/:p1/:p2/def/*filepath",
}
for _, route := range routes {
tree.addRoute(route, fakeHandler(route))
}
return tree
}
func benchmarkRouteLookup(b *testing.B, tree *node, path string, wantFullPath string) {
b.Helper()
params := make(Params, 0, 4)
skipped := make([]skippedNode, 0, 8)
value := tree.getValue(path, &params, &skipped, true)
if wantFullPath == "" {
if value.handlers != nil {
b.Fatalf("expected no match for %q, got %q", path, value.fullPath)
}
} else {
if value.handlers == nil {
b.Fatalf("expected match for %q, got nil handlers", path)
}
if value.fullPath != wantFullPath {
b.Fatalf("expected full path %q for %q, got %q", wantFullPath, path, value.fullPath)
}
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
params = params[:0]
skipped = skipped[:0]
value = tree.getValue(path, &params, &skipped, true)
}
benchmarkRouteHandlers = value.handlers
benchmarkRouteFullPath = value.fullPath
if value.params != nil {
benchmarkRouteParamsLen = len(*value.params)
} else {
benchmarkRouteParamsLen = 0
}
}
func BenchmarkRouteMatch(b *testing.B) {
tree := buildRouteMatchBenchmarkTree()
b.Run("StaticHit", func(b *testing.B) {
benchmarkRouteLookup(b, tree, "/api/v1/users", "/api/v1/users")
})
b.Run("ParamHit", func(b *testing.B) {
benchmarkRouteLookup(b, tree, "/api/v1/users/123", "/api/v1/users/:id")
})
b.Run("BacktrackingHit", func(b *testing.B) {
benchmarkRouteLookup(b, tree, "/abc/b/d/def/some/file.txt", "/abc/:p1/:p2/def/*filepath")
})
b.Run("Miss", func(b *testing.B) {
benchmarkRouteLookup(b, tree, "/does/not/exist", "")
})
b.Run("CaseInsensitiveHit", func(b *testing.B) {
path := "/API/V1/USERS/123/SETTINGS"
out, found := tree.findCaseInsensitivePath(path, true)
if !found {
b.Fatalf("expected fixed-path match for %q", path)
}
if got := string(out); got != "/api/v1/users/123/settings" {
b.Fatalf("expected fixed-path result %q, got %q", "/api/v1/users/123/settings", got)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
out, found = tree.findCaseInsensitivePath(path, true)
}
benchmarkRouteCIPath = out
benchmarkRouteCIFound = found
})
b.Run("CaseInsensitiveMiss", func(b *testing.B) {
path := "/DOES/NOT/EXIST"
out, found := tree.findCaseInsensitivePath(path, true)
if found || out != nil {
b.Fatalf("expected no fixed-path match for %q, got %q, %t", path, string(out), found)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
out, found = tree.findCaseInsensitivePath(path, true)
}
benchmarkRouteCIPath = out
benchmarkRouteCIFound = found
})
}

757
serve.go
View file

@ -14,6 +14,7 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"strings"
"sync" "sync"
"syscall" "syscall"
"time" "time"
@ -21,329 +22,322 @@ import (
"github.com/fenthope/reco" "github.com/fenthope/reco"
) )
// defaultShutdownTimeout 定义了在强制关闭前等待优雅关闭的最长时间
const defaultShutdownTimeout = 5 * time.Second const defaultShutdownTimeout = 5 * time.Second
// --- 内部辅助函数 --- type runMode uint8
// resolveAddress 解析传入的地址参数,如果没有则返回默认的 ":8080" const (
func resolveAddress(addr []string) string { runModeHTTP runMode = iota
switch len(addr) { runModeHTTPS
case 0: runModeHTTPSRedirect
return ":8080" )
case 1:
return addr[0] type runConfig struct {
default: addr string
panic("too many parameters provided for server address") httpRedirectAddr string
tlsConfig *tls.Config
redirectHost string
redirectHostHeaders []string
useHeaderHost bool
useHeaderHostSet bool
graceful bool
shutdownTimeout time.Duration
gracefulCtx context.Context
mode runMode
shutdownDefaultSet bool
shutdownTimeoutSet bool
}
type RunOption interface {
apply(*runConfig) error
}
type runOptionFunc func(*runConfig) error
func (f runOptionFunc) apply(cfg *runConfig) error {
return f(cfg)
}
func defaultRunConfig() runConfig {
return runConfig{
addr: ":8080",
shutdownTimeout: defaultShutdownTimeout,
mode: runModeHTTP,
useHeaderHost: true,
} }
} }
// getShutdownTimeout 解析可选的超时参数,如果无效或未提供则返回默认值 type HTTPRedirectOption interface {
func getShutdownTimeout(timeouts []time.Duration) time.Duration { applyRedirect(*runConfig) error
if len(timeouts) > 0 && timeouts[0] > 0 {
return timeouts[0]
}
return defaultShutdownTimeout
} }
// runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server, type redirectOptionFunc func(*runConfig) error
// 并处理其启动失败的致命错误
// serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS") func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error {
func runServer(serverType string, srv *http.Server) { return f(cfg)
}
func WithAddr(addr string) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if addr == "" {
return errors.New("run address must not be empty")
}
cfg.addr = addr
return nil
})
}
func WithTLS(tlsConfig *tls.Config) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if tlsConfig == nil {
return errors.New("tls.Config must not be nil")
}
cfg.tlsConfig = tlsConfig
if cfg.mode == runModeHTTP {
cfg.mode = runModeHTTPS
}
return nil
})
}
func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if addr == "" {
return errors.New("http redirect address must not be empty")
}
cfg.httpRedirectAddr = addr
cfg.mode = runModeHTTPSRedirect
for _, opt := range opts {
if opt == nil {
continue
}
if err := opt.applyRedirect(cfg); err != nil {
return err
}
}
return nil
})
}
func WithUseHeaderHost(enabled bool) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
cfg.useHeaderHost = enabled
cfg.useHeaderHostSet = true
return nil
})
}
func WithRedirectHost(host string) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
if host == "" {
return errors.New("redirect host must not be empty")
}
cfg.redirectHost = host
return nil
})
}
func WithRedirectHostHeaders(headers []string) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0]
for _, header := range headers {
trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header))
if trimmed != "" {
cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed)
}
}
return nil
})
}
func WithGracefulShutdown(timeout time.Duration) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
cfg.graceful = true
cfg.shutdownTimeoutSet = true
if timeout > 0 {
cfg.shutdownTimeout = timeout
} else {
cfg.shutdownTimeout = defaultShutdownTimeout
}
return nil
})
}
func WithGracefulShutdownDefault() RunOption {
return runOptionFunc(func(cfg *runConfig) error {
cfg.graceful = true
cfg.shutdownDefaultSet = true
cfg.shutdownTimeout = defaultShutdownTimeout
return nil
})
}
func WithShutdownContext(ctx context.Context) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if ctx == nil {
return errors.New("shutdown context must not be nil")
}
cfg.gracefulCtx = ctx
return nil
})
}
func serveServer(srv *http.Server, serveTLS bool) error {
if serveTLS {
return srv.ListenAndServeTLS("", "")
}
return srv.ListenAndServe()
}
func runServer(serverType string, srv *http.Server, serveTLS bool) {
go func() { go func() {
var err error
protocol := "http" protocol := "http"
if srv.TLSConfig != nil { if serveTLS {
protocol = "https" protocol = "https"
} }
log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr) log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr)
if srv.TLSConfig != nil { err := serveServer(srv, serveTLS)
// 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置,
// ListenAndServeTLS 的前两个参数可以为空字符串
err = srv.ListenAndServeTLS("", "")
} else {
err = srv.ListenAndServe()
}
// 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed),
// 则认为是一个严重错误,并终止程序
if err != nil && !errors.Is(err, http.ErrServerClosed) { if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Touka %s server failed: %v", serverType, err) log.Fatalf("Touka %s server failed: %v", serverType, err)
} }
}() }()
} }
// handleGracefulShutdown 监听系统信号 (SIGINT, SIGTERM) 并优雅地关闭所有提供的服务器 func cloneTLSConfig(tlsConfig *tls.Config) *tls.Config {
// 这是所有支持优雅关闭的 RunXXX 方法的最终归宿
func handleGracefulShutdown(servers []*http.Server, timeout time.Duration, logger *reco.Logger) error {
// 创建一个 channel 来接收操作系统信号
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号
<-quit // 阻塞,直到接收到上述信号之一
log.Println("Shutting down Touka server(s)...")
// 关闭日志记录器
if logger != nil {
go func() {
log.Println("Closing Touka logger...")
CloseLogger(logger)
}()
}
// 创建一个带超时的上下文,用于 Shutdown
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
var wg sync.WaitGroup
errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel
// 并发地关闭所有服务器
for _, srv := range servers {
wg.Add(1)
go func(s *http.Server) {
defer wg.Done()
if err := s.Shutdown(ctx); err != nil {
// 将错误发送到 channel
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
}
}(srv)
}
wg.Wait() // 等待所有服务器的关闭 goroutine 完成
close(errChan) // 关闭 channel,以便可以安全地遍历它
// 收集所有关闭过程中发生的错误
var shutdownErrors []error
for err := range errChan {
shutdownErrors = append(shutdownErrors, err)
log.Printf("Shutdown error: %v", err)
}
if len(shutdownErrors) > 0 {
return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误
}
log.Println("Touka server(s) exited gracefully.")
return nil
}
func handleGracefulShutdownWithContext(servers []*http.Server, ctx context.Context, timeout time.Duration, logger *reco.Logger) error {
// 创建一个 channel 来接收操作系统信号
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号
// 启动服务器
serverStopped := make(chan error, 1)
for _, srv := range servers {
go func(s *http.Server) {
serverStopped <- s.ListenAndServe()
}(srv)
}
select {
case <-ctx.Done():
// Context 被取消 (例如,通过外部取消函数)
log.Println("Context cancelled, shutting down Touka server(s)...")
case err := <-serverStopped:
// 服务器自身停止 (例如,端口被占用,或 ListenAndServe 返回错误)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("Touka HTTP server failed: %w", err)
}
log.Println("Touka HTTP server stopped gracefully.")
return nil // 服务器已自行优雅关闭,无需进一步处理
case <-quit:
// 接收到操作系统信号
log.Println("Shutting down Touka server(s) due to OS signal...")
}
// 关闭日志记录器
if logger != nil {
go func() {
log.Println("Closing Touka logger...")
CloseLogger(logger)
}()
}
// 创建一个带超时的上下文,用于 Shutdown
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
var wg sync.WaitGroup
errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel
// 并发地关闭所有服务器
for _, srv := range servers {
wg.Add(1)
go func(s *http.Server) {
defer wg.Done()
if err := s.Shutdown(shutdownCtx); err != nil {
// 将错误发送到 channel
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
}
}(srv)
}
wg.Wait()
close(errChan) // 关闭 channel,以便可以安全地遍历它
// 收集所有关闭过程中发生的错误
var shutdownErrors []error
for err := range errChan {
shutdownErrors = append(shutdownErrors, err)
log.Printf("Shutdown error: %v", err)
}
if len(shutdownErrors) > 0 {
return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误
}
log.Println("Touka server(s) exited gracefully.")
return nil
}
// --- 公共 Run 方法 ---
// Run 启动一个不支持优雅关闭的 HTTP 服务器
// 这是一个阻塞调用,主要用于简单的场景或快速测试
// 建议在生产环境中使用 RunShutdown 或其他支持优雅关闭的方法
func (engine *Engine) Run(addr ...string) error {
address := resolveAddress(addr)
srv := &http.Server{Addr: address, Handler: engine}
// 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性
engine.applyDefaultServerConfig(srv)
if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv)
}
log.Printf("Starting Touka HTTP server on %s (no graceful shutdown)", address)
return srv.ListenAndServe()
}
// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器
func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error {
srv := &http.Server{
Addr: addr,
Handler: engine,
BaseContext: func(l net.Listener) context.Context {
return engine.shutdownCtx
},
}
srv.RegisterOnShutdown(engine.shutdownCancel)
// 应用框架的默认配置和用户提供的自定义配置
engine.applyDefaultServerConfig(srv)
if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv)
}
runServer("HTTP", srv)
return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco)
}
// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器
func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, timeouts ...time.Duration) error {
srv := &http.Server{
Addr: addr,
Handler: engine,
BaseContext: func(l net.Listener) context.Context {
return engine.shutdownCtx
},
}
srv.RegisterOnShutdown(engine.shutdownCancel)
// 应用框架的默认配置和用户提供的自定义配置
engine.applyDefaultServerConfig(srv)
if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv)
}
return handleGracefulShutdownWithContext([]*http.Server{srv}, ctx, getShutdownTimeout(timeouts), engine.LogReco)
}
// RunTLS 启动一个支持优雅关闭的 HTTPS 服务器
func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
if tlsConfig == nil { if tlsConfig == nil {
return errors.New("tls.Config must not be nil for RunTLS") return nil
}
return tlsConfig.Clone()
} }
// 配置 HTTP/2 支持 (如果使用默认配置) func parseHTTPSPort(addr string) (string, error) {
if engine.useDefaultProtocols { _, port, err := net.SplitHostPort(addr)
engine.setProtocols(&ProtocolsConfig{ if err != nil {
Http1: true, return "", fmt.Errorf("https address %q must include a port: %w", addr, err)
Http2: true, // 默认在 TLS 上启用 HTTP/2 }
}) return port, nil
} }
srv := &http.Server{ func applyMainServerConfig(engine *Engine, srv *http.Server, serveTLS bool) {
Addr: addr, if serveTLS {
Handler: engine,
TLSConfig: tlsConfig,
BaseContext: func(l net.Listener) context.Context {
return engine.shutdownCtx
},
}
srv.RegisterOnShutdown(engine.shutdownCancel)
// 应用框架的默认配置和用户提供的自定义配置
// 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator
engine.applyDefaultServerConfig(srv)
if engine.TLSServerConfigurator != nil { if engine.TLSServerConfigurator != nil {
engine.TLSServerConfigurator(srv) engine.TLSServerConfigurator(srv)
} else if engine.ServerConfigurator != nil { return
}
}
if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv) engine.ServerConfigurator(srv)
} }
runServer("HTTPS", srv)
return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco)
} }
// RunWithTLS 是 RunTLS 的别名,为了保持向后兼容性或更直观的命名 func applyRedirectServerConfig(engine *Engine, srv *http.Server) {
func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { applyServerProtocols(srv, engine.serverProtocols)
return engine.RunTLS(addr, tlsConfig, timeouts...) if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv)
}
} }
// RunTLSRedir 启动 HTTP 重定向服务器和 HTTPS 应用服务器,两者都支持优雅关闭 func effectiveServerProtocols(engine *Engine, serveTLS bool) *http.Protocols {
func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { if engine == nil {
if tlsConfig == nil { return nil
return errors.New("tls.Config must not be nil for RunTLSRedir") }
if serveTLS && engine.useDefaultProtocols {
protocols := &http.Protocols{}
protocols.SetHTTP1(true)
protocols.SetHTTP2(true)
return protocols
}
return cloneServerProtocols(engine.serverProtocols)
} }
// --- HTTPS 服务器 --- func buildMainServer(engine *Engine, cfg runConfig) *http.Server {
if engine.useDefaultProtocols { serveTLS := cfg.mode != runModeHTTP
engine.setProtocols(&ProtocolsConfig{Http1: true, Http2: true}) server := &http.Server{
} Addr: cfg.addr,
httpsSrv := &http.Server{
Addr: httpsAddr,
Handler: engine, Handler: engine,
TLSConfig: tlsConfig, TLSConfig: cloneTLSConfig(cfg.tlsConfig),
BaseContext: func(l net.Listener) context.Context { }
if cfg.graceful {
server.BaseContext = func(net.Listener) context.Context {
return engine.shutdownCtx return engine.shutdownCtx
},
} }
httpsSrv.RegisterOnShutdown(engine.shutdownCancel) server.RegisterOnShutdown(engine.shutdownCancel)
engine.applyDefaultServerConfig(httpsSrv) }
if engine.TLSServerConfigurator != nil { applyServerProtocols(server, effectiveServerProtocols(engine, serveTLS))
engine.TLSServerConfigurator(httpsSrv) applyMainServerConfig(engine, server, serveTLS)
} else if engine.ServerConfigurator != nil { return server
engine.ServerConfigurator(httpsSrv) }
func firstRedirectHeaderHost(r *http.Request, headers []string) string {
if r == nil {
return ""
}
for _, header := range headers {
value := strings.TrimSpace(r.Header.Get(header))
if value == "" {
continue
}
if comma := strings.IndexByte(value, ','); comma >= 0 {
value = strings.TrimSpace(value[:comma])
}
if value != "" {
return value
}
}
return ""
}
func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) {
if cfg.useHeaderHostSet && !cfg.useHeaderHost {
if cfg.redirectHost == "" {
return "", http.StatusInternalServerError, false
}
return cfg.redirectHost, 0, true
}
if len(cfg.redirectHostHeaders) > 0 {
host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders)
if host == "" {
return "", http.StatusUpgradeRequired, false
}
return host, 0, true
}
if r == nil {
return "", http.StatusUpgradeRequired, false
}
host := strings.TrimSpace(r.Host)
if host == "" {
return "", http.StatusUpgradeRequired, false
}
return host, 0, true
}
func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) {
httpsAddr := cfg.addr
httpAddr := cfg.httpRedirectAddr
httpsPort, err := parseHTTPSPort(httpsAddr)
if err != nil {
return nil, err
} }
// --- HTTP 重定向服务器 ---
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host) host, statusCode, ok := redirectTargetHost(r, cfg)
if err != nil { if !ok {
host = r.Host http.Error(w, http.StatusText(statusCode), statusCode)
return
} }
_, httpsPort, err := net.SplitHostPort(httpsAddr) if parsedHost, _, err := net.SplitHostPort(host); err == nil {
if err != nil { host = parsedHost
// 如果 httpsAddr 没有端口,这是一个配置错误 if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
host = "[" + host + "]"
log.Fatalf("Invalid HTTPS address for redirection '%s': must include a port.", httpsAddr) }
} }
targetURL := "https://" + host targetURL := "https://" + host
// 只有在非标准 HTTPS 端口 (443) 时才附加端口号
if httpsPort != "443" { if httpsPort != "443" {
targetURL = "https://" + net.JoinHostPort(host, httpsPort) targetURL = "https://" + net.JoinHostPort(host, httpsPort)
} }
@ -351,22 +345,205 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con
http.Redirect(w, r, targetURL, http.StatusMovedPermanently) http.Redirect(w, r, targetURL, http.StatusMovedPermanently)
}) })
httpSrv := &http.Server{
Addr: httpAddr, server := &http.Server{Addr: httpAddr, Handler: redirectHandler}
Handler: redirectHandler, applyRedirectServerConfig(engine, server)
} return server, nil
engine.applyDefaultServerConfig(httpSrv)
if engine.ServerConfigurator != nil {
engine.ServerConfigurator(httpSrv)
} }
// --- 启动服务器和优雅关闭 --- func validateRunConfig(cfg runConfig) error {
runServer("HTTPS", httpsSrv) if cfg.mode == runModeHTTPSRedirect && cfg.tlsConfig == nil {
runServer("HTTP Redirect", httpSrv) return errors.New("WithHTTPRedirect requires WithTLS")
return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco) }
if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil {
return errors.New("https mode requires WithTLS")
}
if cfg.gracefulCtx != nil && !cfg.graceful {
return errors.New("WithShutdownContext requires graceful shutdown")
}
if len(cfg.redirectHostHeaders) > 0 {
if !cfg.useHeaderHostSet || !cfg.useHeaderHost {
return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)")
}
}
if cfg.useHeaderHostSet && cfg.useHeaderHost {
if cfg.redirectHost != "" {
return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)")
}
} else if cfg.useHeaderHostSet && !cfg.useHeaderHost {
if cfg.redirectHost == "" {
return errors.New("WithUseHeaderHost(false) requires WithRedirectHost")
}
if len(cfg.redirectHostHeaders) > 0 {
return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)")
}
}
return nil
} }
// RunWithTLSRedir 是 RunTLSRedir 的别名,为了保持向后兼容性 func effectiveShutdownTimeout(cfg runConfig) time.Duration {
func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { if cfg.shutdownTimeoutSet || cfg.shutdownDefaultSet {
return engine.RunTLSRedir(httpAddr, httpsAddr, tlsConfig, timeouts...) if cfg.shutdownTimeout > 0 {
return cfg.shutdownTimeout
}
}
return defaultShutdownTimeout
}
func closeLoggerAsync(logger *reco.Logger) {
if logger == nil {
return
}
go func() {
log.Println("Closing Touka logger...")
CloseLogger(logger)
}()
}
func shutdownServers(servers []*http.Server, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
var wg sync.WaitGroup
errChan := make(chan error, len(servers))
for _, srv := range servers {
wg.Add(1)
go func(s *http.Server) {
defer wg.Done()
if err := s.Shutdown(ctx); err != nil {
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
}
}(srv)
}
wg.Wait()
close(errChan)
var shutdownErrors []error
for err := range errChan {
shutdownErrors = append(shutdownErrors, err)
log.Printf("Shutdown error: %v", err)
}
if len(shutdownErrors) > 0 {
return errors.Join(shutdownErrors...)
}
return nil
}
func gracefulServe(servers []*http.Server, serveTLS []bool, timeout time.Duration, logger *reco.Logger, shutdownCtx context.Context) error {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(quit)
serverStopped := make(chan error, len(servers))
for i, srv := range servers {
serveTLSFlag := serveTLS[i]
go func(server *http.Server, useTLS bool) {
serverStopped <- serveServer(server, useTLS)
}(srv, serveTLSFlag)
}
select {
case err := <-serverStopped:
if err != nil && !errors.Is(err, http.ErrServerClosed) {
if shutdownErr := shutdownServers(servers, timeout); shutdownErr != nil {
return errors.Join(err, shutdownErr)
}
return err
}
log.Println("Touka server stopped gracefully.")
return nil
case <-quit:
log.Println("Shutting down Touka server(s) due to OS signal...")
case <-shutdownCtx.Done():
log.Println("Context cancelled, shutting down Touka server(s)...")
}
closeLoggerAsync(logger)
if err := shutdownServers(servers, timeout); err != nil {
return err
}
log.Println("Touka server(s) exited gracefully.")
return nil
}
// Run starts the engine with the provided startup options.
//
// Default behavior with no options:
// - HTTP only
// - listens on :8080
// - no graceful shutdown orchestration
//
// Add WithGracefulShutdown(...) or WithGracefulShutdownDefault() to enable
// signal-aware graceful shutdown and request-context cancellation semantics.
// Add WithTLS(...) to run HTTPS; this is independent from graceful shutdown.
func (engine *Engine) Run(opts ...RunOption) error {
cfg := defaultRunConfig()
for _, opt := range opts {
if opt == nil {
continue
}
if err := opt.apply(&cfg); err != nil {
return err
}
}
if cfg.httpRedirectAddr != "" {
cfg.mode = runModeHTTPSRedirect
} else if cfg.tlsConfig != nil {
cfg.mode = runModeHTTPS
}
if err := validateRunConfig(cfg); err != nil {
return err
}
serveTLS := cfg.mode != runModeHTTP
mainServer := buildMainServer(engine, cfg)
servers := []*http.Server{mainServer}
serveTLSFlags := []bool{serveTLS}
if cfg.mode == runModeHTTPSRedirect {
redirectServer, err := buildRedirectServer(engine, cfg)
if err != nil {
return err
}
servers = append(servers, redirectServer)
serveTLSFlags = append(serveTLSFlags, false)
}
if !cfg.graceful {
if len(servers) > 1 {
serverStopped := make(chan error, len(servers))
for i, srv := range servers {
serveTLSFlag := serveTLSFlags[i]
go func(server *http.Server, useTLS bool) {
serverStopped <- serveServer(server, useTLS)
}(srv, serveTLSFlag)
}
err := <-serverStopped
if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil {
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return errors.Join(err, shutdownErr)
}
return shutdownErr
}
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
}
protocolLabel := "HTTP"
if serveTLS {
protocolLabel = "HTTPS"
}
log.Printf("Starting Touka %s server on %s", protocolLabel, cfg.addr)
return serveServer(mainServer, serveTLS)
}
shutdownCtx := context.Background()
if cfg.gracefulCtx != nil {
shutdownCtx = cfg.gracefulCtx
}
return gracefulServe(servers, serveTLSFlags, effectiveShutdownTimeout(cfg), engine.LogReco, shutdownCtx)
} }

492
serve_test.go Normal file
View file

@ -0,0 +1,492 @@
package touka
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func generateSelfSignedCert(t *testing.T) tls.Certificate {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate private key: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "127.0.0.1"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatalf("create self-signed cert: %v", err)
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatalf("parse self-signed cert: %v", err)
}
return cert
}
func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen on ephemeral port: %v", err)
}
addr := listener.Addr().String()
if err := listener.Close(); err != nil {
t.Fatalf("close temporary listener: %v", err)
}
srv := &http.Server{
Addr: addr,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("ok"))
}),
// RunShutdown uses the HTTP startup path and must not let a shared
// ServerConfigurator accidentally turn it into HTTPS.
TLSConfig: &tls.Config{},
}
errCh := make(chan error, 1)
go func() {
errCh <- serveServer(srv, false)
}()
client := &http.Client{Timeout: 200 * time.Millisecond}
var resp *http.Response
requestURL := "http://" + addr
deadline := time.Now().Add(3 * time.Second)
for time.Now().Before(deadline) {
resp, err = client.Get(requestURL)
if err == nil {
break
}
time.Sleep(20 * time.Millisecond)
}
if err != nil {
select {
case serveErr := <-errCh:
t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: request error=%v, serve error=%v", err, serveErr)
default:
t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: %v", err)
}
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read response body: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: got %d want %d", resp.StatusCode, http.StatusOK)
}
if string(body) != "ok" {
t.Fatalf("unexpected body: got %q want %q", string(body), "ok")
}
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
t.Fatalf("shutdown server: %v", err)
}
if err := <-errCh; !errors.Is(err, http.ErrServerClosed) {
t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err)
}
}
func TestRunRejectsRedirectWithoutTLS(t *testing.T) {
engine := New()
err := engine.Run(WithHTTPRedirect(":80"))
if err == nil {
t.Fatal("expected redirect mode without TLS to fail")
}
}
func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) {
engine := New()
err := engine.Run(
WithAddr(":443"),
WithTLS(&tls.Config{}),
WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})),
)
if err == nil {
t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail")
}
}
func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) {
cfg := defaultRunConfig()
if err := WithGracefulShutdownDefault().apply(&cfg); err != nil {
t.Fatalf("apply graceful default option: %v", err)
}
if !cfg.graceful {
t.Fatal("expected graceful shutdown to be enabled")
}
if cfg.shutdownTimeout != defaultShutdownTimeout {
t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout)
}
}
func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) {
cfg := defaultRunConfig()
tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
if err := WithTLS(tlsConfig).apply(&cfg); err != nil {
t.Fatalf("apply TLS option: %v", err)
}
if cfg.mode != runModeHTTPS {
t.Fatalf("expected HTTPS mode, got %v", cfg.mode)
}
if cfg.graceful {
t.Fatal("expected TLS option to remain independent from graceful shutdown")
}
if cfg.tlsConfig != tlsConfig {
t.Fatal("expected TLS config to be preserved in run config")
}
}
func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) {
engine := New()
if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil {
t.Fatal("expected redirect server builder to reject https address without port")
}
}
func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) {
cfg := defaultRunConfig()
ctx := t.Context()
if err := WithShutdownContext(ctx).apply(&cfg); err != nil {
t.Fatalf("apply shutdown context option: %v", err)
}
if err := validateRunConfig(cfg); err == nil {
t.Fatal("expected shutdown context without graceful shutdown to fail validation")
}
}
func TestValidateRunConfigDoesNotMutateMode(t *testing.T) {
cfg := defaultRunConfig()
cfg.httpRedirectAddr = ":80"
if err := validateRunConfig(cfg); err != nil {
t.Fatalf("validate run config: %v", err)
}
if cfg.mode != runModeHTTP {
t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode)
}
}
func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) {
cfg := defaultRunConfig()
cfg.mode = runModeHTTPSRedirect
cfg.tlsConfig = &tls.Config{}
cfg.useHeaderHost = false
cfg.useHeaderHostSet = true
if err := validateRunConfig(cfg); err == nil {
t.Fatal("expected configured host mode without redirect host to fail validation")
}
}
func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) {
cfg := defaultRunConfig()
cfg.mode = runModeHTTPSRedirect
cfg.tlsConfig = &tls.Config{}
cfg.useHeaderHost = true
cfg.useHeaderHostSet = true
cfg.redirectHost = "configured.example"
if err := validateRunConfig(cfg); err == nil {
t.Fatal("expected redirect host to be rejected when header host mode is enabled")
}
}
func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) {
engine := New()
server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP})
if server.BaseContext == nil {
t.Fatal("expected graceful main server to set BaseContext")
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen for base context check: %v", err)
}
defer listener.Close()
if got := server.BaseContext(listener); got != engine.shutdownCtx {
t.Fatal("expected graceful main server to use engine shutdown context")
}
}
func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) {
engine := New()
serverConfigured := false
tlsConfigured := false
engine.SetServerConfigurator(func(s *http.Server) {
serverConfigured = true
s.ReadTimeout = time.Second
})
engine.SetTLSServerConfigurator(func(s *http.Server) {
tlsConfigured = true
s.IdleTimeout = time.Second
})
server := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
if !tlsConfigured {
t.Fatal("expected TLS configurator to run for HTTPS main server")
}
if serverConfigured {
t.Fatal("expected generic server configurator to be skipped when TLS configurator is set")
}
if server.IdleTimeout != time.Second {
t.Fatal("expected TLS configurator changes to be applied to HTTPS main server")
}
}
func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) {
engine := New()
configured := false
engine.SetServerConfigurator(func(s *http.Server) {
configured = true
s.ReadTimeout = time.Second
})
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
if !configured {
t.Fatal("expected redirect server to use generic server configurator")
}
if server.ReadTimeout != time.Second {
t.Fatal("expected redirect server configurator changes to be applied")
}
}
func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) {
engine := New()
httpsServer := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
if !httpsServer.Protocols.HTTP2() {
t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings")
}
httpServer := buildMainServer(engine, defaultRunConfig())
if httpServer.Protocols.HTTP2() {
t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled")
}
}
func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://example.com/plain/path?q=1" {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{
addr: ":443",
httpRedirectAddr: ":80",
useHeaderHost: true,
useHeaderHostSet: true,
redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"},
})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
req.Header.Set("X-Forwarded-Host", "forwarded.example")
req.Header.Set("X-First-Host", "first.example")
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{
addr: ":443",
httpRedirectAddr: ":80",
useHeaderHost: true,
useHeaderHostSet: true,
redirectHostHeaders: []string{"X-Forwarded-Host"},
})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUpgradeRequired {
t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code)
}
}
func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{
addr: ":443",
httpRedirectAddr: ":80",
useHeaderHost: false,
useHeaderHostSet: true,
redirectHost: "configured.example",
})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
req.Header.Set("X-Forwarded-Host", "forwarded.example")
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil)
req.Host = "[::1]:80"
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" {
t.Fatalf("unexpected IPv6 redirect location: %q", location)
}
}
func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
occupied, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen on occupied addr: %v", err)
}
occupiedAddr := occupied.Addr().String()
defer occupied.Close()
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen for redirect addr: %v", err)
}
redirectAddr := redirectListener.Addr().String()
if err := redirectListener.Close(); err != nil {
t.Fatalf("close redirect addr probe: %v", err)
}
engine := New()
redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
mainServer := &http.Server{Addr: occupiedAddr, Handler: engine}
err = gracefulServe([]*http.Server{mainServer, redirectServer}, []bool{false, false}, 200*time.Millisecond, nil, context.Background())
if err == nil {
t.Fatal("expected gracefulServe to fail when one server cannot bind")
}
if !strings.Contains(err.Error(), occupiedAddr) {
t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, err)
}
conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond)
if dialErr == nil {
conn.Close()
t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr)
}
if !strings.Contains(dialErr.Error(), "refused") && !strings.Contains(dialErr.Error(), "reset") {
t.Fatalf("unexpected dial result after shutdown, got %v", dialErr)
}
}
func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) {
occupied, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen on occupied addr: %v", err)
}
occupiedAddr := occupied.Addr().String()
defer occupied.Close()
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen for redirect addr: %v", err)
}
redirectAddr := redirectListener.Addr().String()
if err := redirectListener.Close(); err != nil {
t.Fatalf("close redirect addr probe: %v", err)
}
engine := New()
err = engine.Run(
WithAddr(occupiedAddr),
WithTLS(&tls.Config{}),
WithHTTPRedirect(redirectAddr),
)
if err == nil {
t.Fatal("expected non-graceful TLS redirect startup to return bind error")
}
if !strings.Contains(err.Error(), occupiedAddr) {
t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err)
}
}

66
sse.go
View file

@ -111,46 +111,40 @@ func (c *Context) EventStream(streamer func(w io.Writer) bool) {
// EventStreamChan 返回用于 SSE 事件流的 channel. // EventStreamChan 返回用于 SSE 事件流的 channel.
// 这是为高级并发场景设计的、更灵活的API. // 这是为高级并发场景设计的、更灵活的API.
// //
// 重要: // 与 EventStream 回调模式类似, 此方法是阻塞的: handler 会在此方法中停留,
// - 调用者必须 close(eventChan) 来结束事件流. // 直到事件 channel 被关闭 (close eventChan) 或客户端断开连接.
// - 调用者必须在独立的 goroutine 中消费 errChan 来处理错误和连接断开. // 这保证了 Context 不会在 SSE 流期间被 pool 回收.
// - 为防止 goroutine 泄漏, 建议发送方在 select 中同时监听 c.Request.Context().Done(). //
// eventChan 必须在调用此方法之前创建, 以便调用者可以在独立的 goroutine 中发送事件.
// 调用者必须在完成后 close(eventChan) 来结束流.
// 生产者 goroutine 必须在 select 中监听 c.Request.Context().Done(), 否则在客户端断开时会产生 goroutine 泄漏.
// //
// 详细用法: // 详细用法:
// //
// r.GET("/sse/channel", func(c *touka.Context) { // r.GET("/sse/channel", func(c *touka.Context) {
// eventChan, errChan := c.EventStreamChan() // eventChan := make(chan touka.Event)
// //
// // 必须在独立的goroutine中处理错误和连接断开. // // 在独立的 goroutine 中异步发送事件.
// go func() { // go func() {
// if err := <-errChan; err != nil { // defer close(eventChan) // 完成后关闭 channel 以结束事件流.
// c.Errorf("SSE channel error: %v", err)
// }
// }()
//
// // 在另一个goroutine中异步发送事件.
// go func() {
// // 重要: 必须在逻辑结束时关闭channel, 以通知框架.
// defer close(eventChan)
// //
// for i := 1; i <= 5; i++ { // for i := 1; i <= 5; i++ {
// select { // select {
// case <-c.Request.Context().Done(): // case <-c.Request.Context().Done():
// return // 客户端已断开, 退出 goroutine. // return // 客户端已断开, 退出 goroutine.
// default: // case eventChan <- touka.Event{
// eventChan <- touka.Event{
// Id: fmt.Sprintf("%d", i), // Id: fmt.Sprintf("%d", i),
// Data: "hello from channel", // Data: "hello from channel",
// }:
// } // }
// time.Sleep(2 * time.Second) // time.Sleep(2 * time.Second)
// } // }
// }
// }() // }()
//
// // 阻塞直到事件流结束.
// c.EventStreamChan(eventChan)
// }) // })
func (c *Context) EventStreamChan() (chan<- Event, <-chan error) { func (c *Context) EventStreamChan(eventChan <-chan Event) {
eventChan := make(chan Event)
errChan := make(chan error, 1)
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8") c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
c.Writer.Header().Set("Cache-Control", "no-cache, no-transform") c.Writer.Header().Set("Cache-Control", "no-cache, no-transform")
c.Writer.Header().Del("Connection") c.Writer.Header().Del("Connection")
@ -159,8 +153,16 @@ func (c *Context) EventStreamChan() (chan<- Event, <-chan error) {
c.Writer.WriteHeader(http.StatusOK) c.Writer.WriteHeader(http.StatusOK)
c.Writer.Flush() c.Writer.Flush()
// 捕获稳定的引用, 不持有 *Context 指针, 以免 Context 被 pool 回收后出现竞态.
w := c.Writer
fl, _ := w.(http.Flusher)
reqCtx := c.Request.Context()
goroutineExited := make(chan struct{})
// 写入 goroutine: 从 eventChan 消费事件并写入响应.
go func() { go func() {
defer close(errChan) defer close(goroutineExited)
for { for {
select { select {
@ -168,17 +170,23 @@ func (c *Context) EventStreamChan() (chan<- Event, <-chan error) {
if !ok { if !ok {
return return
} }
if err := event.Render(c.Writer); err != nil { if err := event.Render(w); err != nil {
errChan <- err
return return
} }
c.Writer.Flush() if fl != nil {
case <-c.Request.Context().Done(): fl.Flush()
errChan <- c.Request.Context().Err() }
case <-reqCtx.Done():
return return
} }
} }
}() }()
return eventChan, errChan // 阻塞直到:
// 1. 写入 goroutine 退出 (eventChan 关闭或写入失败)
// 2. 客户端断开连接 (reqCtx 取消)
select {
case <-goroutineExited:
case <-reqCtx.Done():
}
} }

142
sse_test.go Normal file
View file

@ -0,0 +1,142 @@
package touka
import (
"context"
"net/http/httptest"
"strings"
"testing"
"time"
)
// TestEventStreamChanBlocksHandler verifies that EventStreamChan blocks until
// the event channel is closed.
func TestEventStreamChanBlocksHandler(t *testing.T) {
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/sse", nil)
c, _ := CreateTestContextWithRequest(rr, req)
handlerReturned := make(chan struct{})
eventChan := make(chan Event)
// Start producer goroutine before EventStreamChan blocks
go func() {
defer close(eventChan)
time.Sleep(30 * time.Millisecond)
eventChan <- Event{Data: "hello"}
time.Sleep(30 * time.Millisecond)
}()
go func() {
c.EventStreamChan(eventChan)
close(handlerReturned)
}()
// Wait for goroutine to start
time.Sleep(10 * time.Millisecond)
// Handler should NOT have returned (eventChan not closed)
select {
case <-handlerReturned:
t.Fatal("Handler returned before eventChan was closed - EventStreamChan is not blocking")
case <-time.After(40 * time.Millisecond):
// good, still blocking
}
// Wait for producer to finish (30+30ms + margin)
select {
case <-handlerReturned:
// good, handler returned
case <-time.After(200 * time.Millisecond):
t.Fatal("Handler did not return after eventChan was closed")
}
}
// TestEventStreamChanUnblocksOnClientDisconnect verifies the handler returns
// when the request context is cancelled, even if eventChan is never closed.
func TestEventStreamChanUnblocksOnClientDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/sse", nil).WithContext(ctx)
c, _ := CreateTestContextWithRequest(rr, req)
eventChan := make(chan Event)
handlerReturned := make(chan struct{})
// Producer never closes eventChan
go func() {
for {
select {
case <-ctx.Done():
return
case eventChan <- Event{Data: "tick"}:
time.Sleep(10 * time.Millisecond)
}
}
}()
go func() {
c.EventStreamChan(eventChan)
close(handlerReturned)
}()
// Handler should NOT have returned
select {
case <-handlerReturned:
t.Fatal("Handler returned before stream ended")
case <-time.After(60 * time.Millisecond):
// good, still blocked
}
// Cancel context to simulate client disconnect
cancel()
select {
case <-handlerReturned:
// good
case <-time.After(200 * time.Millisecond):
t.Fatal("Handler did not return after client disconnect")
}
}
// TestEventStreamChanWritesEvents verifies the SSE event format is correct.
func TestEventStreamChanWritesEvents(t *testing.T) {
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/sse", nil)
c, _ := CreateTestContextWithRequest(rr, req)
eventChan := make(chan Event)
go func() {
defer close(eventChan)
eventChan <- Event{Id: "1", Event: "tick", Data: "hello\nworld"}
eventChan <- Event{Id: "2", Data: "second"}
}()
c.EventStreamChan(eventChan)
body := rr.Body.String()
ct := rr.Header().Get("Content-Type")
if !strings.Contains(ct, "text/event-stream") {
t.Fatalf("expected text/event-stream content type, got %q", ct)
}
if !strings.Contains(body, "id: 1") {
t.Fatal("missing id field in first event")
}
if !strings.Contains(body, "event: tick") {
t.Fatal("missing event field in first event")
}
if !strings.Contains(body, "data: hello") {
t.Fatal("missing data line 1 in first event")
}
if !strings.Contains(body, "data: world") {
t.Fatal("missing data line 2 in first event")
}
if !strings.Contains(body, "id: 2") {
t.Fatal("missing id field in second event")
}
if !strings.Contains(body, "data: second") {
t.Fatal("missing data in second event")
}
}

View file

@ -22,10 +22,10 @@ type HandlerFunc func(*Context)
// HandlersChain 定义处理函数链(中间件栈)的类型。 // HandlersChain 定义处理函数链(中间件栈)的类型。
type HandlersChain []HandlerFunc type HandlersChain []HandlerFunc
// IRouter 定义了路由注册的接口提供路由分组和HTTP方法注册的能力。 // Router 定义了路由注册的接口提供路由分组和HTTP方法注册的能力。
type IRouter interface { type Router interface {
Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组 Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组
Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组 Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组
Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法 Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法
GET(relativePath string, handlers ...HandlerFunc) GET(relativePath string, handlers ...HandlerFunc)

82
tree.go
View file

@ -124,6 +124,7 @@ type node struct {
path string // 当前节点的路径段 path string // 当前节点的路径段
indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点
wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) wildChild bool // 是否包含通配符子节点(:param 或 *catchAll)
hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由
nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有)
priority uint32 // 节点的优先级, 用于查找时优先匹配 priority uint32 // 节点的优先级, 用于查找时优先匹配
children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾
@ -131,6 +132,19 @@ type node struct {
fullPath string // 完整路径, 用于调试和错误信息 fullPath string // 完整路径, 用于调试和错误信息
} }
func routeNeedsCaseInsensitiveLookup(path string) bool {
for i := 0; i < len(path); i++ {
c := path[i]
if c >= utf8.RuneSelf {
return true
}
if c >= 'A' && c <= 'Z' {
return true
}
}
return false
}
// incrementChildPrio 增加给定子节点的优先级并在必要时重新排序. // incrementChildPrio 增加给定子节点的优先级并在必要时重新排序.
func (n *node) incrementChildPrio(pos int) int { func (n *node) incrementChildPrio(pos int) int {
cs := n.children // 获取子节点切片 cs := n.children // 获取子节点切片
@ -162,6 +176,9 @@ func (n *node) incrementChildPrio(pos int) int {
func (n *node) addRoute(path string, handlers HandlersChain) { func (n *node) addRoute(path string, handlers HandlersChain) {
fullPath := path // 记录完整的路径 fullPath := path // 记录完整的路径
n.priority++ // 增加当前节点的优先级 n.priority++ // 增加当前节点的优先级
if routeNeedsCaseInsensitiveLookup(path) {
n.hasCaseInsensitivePath = true
}
// 如果是空树(根节点) // 如果是空树(根节点)
if len(n.path) == 0 && len(n.children) == 0 { if len(n.path) == 0 && len(n.children) == 0 {
@ -452,12 +469,14 @@ type skippedNode struct {
// 建议进行 TSR(尾部斜杠重定向). // 建议进行 TSR(尾部斜杠重定向).
func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) { func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) {
var globalParamsCount int16 // 全局参数计数 var globalParamsCount int16 // 全局参数计数
var backtrackToWildChild bool
walk: // 外部循环用于遍历路由树 walk: // 外部循环用于遍历路由树
for { for {
prefix := n.path // 当前节点的路径前缀 prefix := n.path // 当前节点的路径前缀
if len(path) > len(prefix) { if len(path) > len(prefix) {
if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头 if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头
pathAtNode := path
path = path[len(prefix):] // 移除已匹配的前缀 path = path[len(prefix):] // 移除已匹配的前缀
// 在访问 path[0] 之前进行安全检查 // 在访问 path[0] 之前进行安全检查
@ -467,23 +486,16 @@ walk: // 外部循环用于遍历路由树
// 优先尝试所有非通配符子节点, 通过匹配索引字符 // 优先尝试所有非通配符子节点, 通过匹配索引字符
idxc := path[0] // 剩余路径的第一个字符 idxc := path[0] // 剩余路径的第一个字符
for i, c := range []byte(n.indices) { if !backtrackToWildChild {
if c == idxc { // 如果找到匹配的索引字符 for i := 0; i < len(n.indices); i++ {
if n.indices[i] == idxc { // 如果找到匹配的索引字符
// 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯
if n.wildChild { if n.wildChild {
index := len(*skippedNodes) index := len(*skippedNodes)
*skippedNodes = (*skippedNodes)[:index+1] *skippedNodes = (*skippedNodes)[:index+1]
(*skippedNodes)[index] = skippedNode{ (*skippedNodes)[index] = skippedNode{
path: prefix + path, // 记录跳过的路径 path: pathAtNode, // 记录进入当前节点时的剩余路径
node: &node{ // 复制当前节点的状态 node: n,
path: n.path,
wildChild: n.wildChild,
nType: n.nType,
priority: n.priority,
children: n.children,
handlers: n.handlers,
fullPath: n.fullPath,
},
paramsCount: globalParamsCount, // 记录当前参数计数 paramsCount: globalParamsCount, // 记录当前参数计数
} }
} }
@ -492,6 +504,9 @@ walk: // 外部循环用于遍历路由树
continue walk // 继续外部循环 continue walk // 继续外部循环
} }
} }
} else {
backtrackToWildChild = false
}
if !n.wildChild { if !n.wildChild {
// 如果路径在循环结束时不等于 '/' 且当前节点没有子节点 // 如果路径在循环结束时不等于 '/' 且当前节点没有子节点
@ -507,6 +522,7 @@ walk: // 外部循环用于遍历路由树
*value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片 *value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片
} }
globalParamsCount = skippedNode.paramsCount // 恢复参数计数 globalParamsCount = skippedNode.paramsCount // 恢复参数计数
backtrackToWildChild = true
continue walk // 继续外部循环 continue walk // 继续外部循环
} }
} }
@ -547,7 +563,7 @@ walk: // 外部循环用于遍历路由树
i := len(*value.params) i := len(*value.params)
*value.params = (*value.params)[:i+1] // 扩展切片 *value.params = (*value.params)[:i+1] // 扩展切片
val := path[:end] // 提取参数值 val := path[:end] // 提取参数值
if unescape { // 如果需要进行 URL 解码 if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) {
if v, err := url.QueryUnescape(val); err == nil { if v, err := url.QueryUnescape(val); err == nil {
val = v // 解码成功则更新值 val = v // 解码成功则更新值
} }
@ -599,7 +615,7 @@ walk: // 外部循环用于遍历路由树
i := len(*value.params) i := len(*value.params)
*value.params = (*value.params)[:i+1] // 扩展切片 *value.params = (*value.params)[:i+1] // 扩展切片
val := path // 参数值是剩余的整个路径 val := path // 参数值是剩余的整个路径
if unescape { // 如果需要进行 URL 解码 if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) {
if v, err := url.QueryUnescape(path); err == nil { if v, err := url.QueryUnescape(path); err == nil {
val = v // 解码成功则更新值 val = v // 解码成功则更新值
} }
@ -634,6 +650,7 @@ walk: // 外部循环用于遍历路由树
*value.params = (*value.params)[:skippedNode.paramsCount] *value.params = (*value.params)[:skippedNode.paramsCount]
} }
globalParamsCount = skippedNode.paramsCount globalParamsCount = skippedNode.paramsCount
backtrackToWildChild = true
continue walk continue walk
} }
} }
@ -658,8 +675,8 @@ walk: // 外部循环用于遍历路由树
} }
// 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议 // 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议
for i, c := range []byte(n.indices) { for i := 0; i < len(n.indices); i++ {
if c == '/' { // 如果索引中包含 '/' if n.indices[i] == '/' { // 如果索引中包含 '/'
n = n.children[i] // 移动到对应的子节点 n = n.children[i] // 移动到对应的子节点
value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
(n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数
@ -688,6 +705,7 @@ walk: // 外部循环用于遍历路由树
*value.params = (*value.params)[:skippedNode.paramsCount] *value.params = (*value.params)[:skippedNode.paramsCount]
} }
globalParamsCount = skippedNode.paramsCount globalParamsCount = skippedNode.paramsCount
backtrackToWildChild = true
continue walk continue walk
} }
} }
@ -701,13 +719,15 @@ walk: // 外部循环用于遍历路由树
// 它还可以选择修复尾部斜杠. // 它还可以选择修复尾部斜杠.
// 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功. // 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功.
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
const stackBufSize = 128 // 栈上缓冲区的默认大小 return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash)
}
// 在常见情况下使用栈上静态大小的缓冲区. func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) {
// 如果路径太长, 则在堆上分配缓冲区. if buf != nil {
buf := make([]byte, 0, stackBufSize) buf = buf[:0]
if length := len(path) + 1; length > stackBufSize { }
buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区 if cap(buf) < len(path)+1 {
buf = make([]byte, 0, len(path)+1)
} }
ciPath := n.findCaseInsensitivePathRec( ciPath := n.findCaseInsensitivePathRec(
@ -758,8 +778,8 @@ walk: // 外部循环用于遍历路由树
// 未找到处理函数. // 未找到处理函数.
// 尝试通过添加尾部斜杠来修复路径 // 尝试通过添加尾部斜杠来修复路径
if fixTrailingSlash { if fixTrailingSlash {
for i, c := range []byte(n.indices) { for i := 0; i < len(n.indices); i++ {
if c == '/' { // 如果索引中包含 '/' if n.indices[i] == '/' { // 如果索引中包含 '/'
n = n.children[i] // 移动到对应的子节点 n = n.children[i] // 移动到对应的子节点
if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
(n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数
@ -781,8 +801,8 @@ walk: // 外部循环用于遍历路由树
if rb[0] != 0 { if rb[0] != 0 {
// 旧 rune 未处理完 // 旧 rune 未处理完
idxc := rb[0] idxc := rb[0]
for i, c := range []byte(n.indices) { for i := 0; i < len(n.indices); i++ {
if c == idxc { if n.indices[i] == idxc {
// 继续处理子节点 // 继续处理子节点
n = n.children[i] n = n.children[i]
npLen = len(n.path) npLen = len(n.path)
@ -813,9 +833,9 @@ walk: // 外部循环用于遍历路由树
rb = shiftNRuneBytes(rb, off) rb = shiftNRuneBytes(rb, off)
idxc := rb[0] idxc := rb[0]
for i, c := range []byte(n.indices) { for i := 0; i < len(n.indices); i++ {
// 小写匹配 // 小写匹配
if c == idxc { if n.indices[i] == idxc {
// 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在 // 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在
if out := n.children[i].findCaseInsensitivePathRec( if out := n.children[i].findCaseInsensitivePathRec(
path, ciPath, rb, fixTrailingSlash, path, ciPath, rb, fixTrailingSlash,
@ -832,9 +852,9 @@ walk: // 外部循环用于遍历路由树
rb = shiftNRuneBytes(rb, off) rb = shiftNRuneBytes(rb, off)
idxc := rb[0] idxc := rb[0]
for i, c := range []byte(n.indices) { for i := 0; i < len(n.indices); i++ {
// 大写匹配 // 大写匹配
if c == idxc { if n.indices[i] == idxc {
// 继续处理子节点 // 继续处理子节点
n = n.children[i] n = n.children[i]
npLen = len(n.path) npLen = len(n.path)
@ -852,7 +872,7 @@ walk: // 外部循环用于遍历路由树
return nil // 未找到, 返回 nil return nil // 未找到, 返回 nil
} }
n = n.children[0] // 移动到通配符子节点(通常是唯一一个) n = n.children[len(n.children)-1] // 通配符子节点约定始终位于末尾
switch n.nType { switch n.nType {
case param: // 参数节点 case param: // 参数节点
// 查找参数结束位置('/' 或路径末尾) // 查找参数结束位置('/' 或路径末尾)

View file

@ -11,6 +11,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"testing" "testing"
"time"
) )
// Used as a workaround since we can't compare functions or their addresses // Used as a workaround since we can't compare functions or their addresses
@ -39,6 +40,23 @@ func getSkippedNodes() *[]skippedNode {
return &ps return &ps
} }
func getValueWithTimeout(t *testing.T, tree *node, path string, unescape bool) nodeValue {
t.Helper()
resultCh := make(chan nodeValue, 1)
go func() {
resultCh <- tree.getValue(path, getParams(), getSkippedNodes(), unescape)
}()
select {
case value := <-resultCh:
return value
case <-time.After(2 * time.Second):
t.Fatalf("lookup for path %q timed out, likely stuck in backtracking", path)
return nodeValue{}
}
}
func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) { func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) {
unescape := false unescape := false
if len(unescapes) >= 1 { if len(unescapes) >= 1 {
@ -901,6 +919,34 @@ func TestTreeInvalidNodeType(t *testing.T) {
} }
} }
func TestFindCaseInsensitivePathWithStaticAndParamRoutesDoesNotPanicOnMiss(t *testing.T) {
tree := &node{}
routes := [...]string{
"/:user/:repo/info/refs",
"/healthz",
"/api/db/data",
"/api/db/sum",
}
for _, route := range routes {
tree.addRoute(route, fakeHandler(route))
}
defer func() {
if r := recover(); r != nil {
t.Fatalf("unexpected panic while looking up missing path: %v", r)
}
}()
if out, found := tree.findCaseInsensitivePath("/does-not-exist", true); found || out != nil {
t.Fatalf("expected missing path lookup to return no match, got %q, %t", string(out), found)
}
if out, found := tree.findCaseInsensitivePath("/does-not-exist", false); found || out != nil {
t.Fatalf("expected missing path lookup without trailing slash fix to return no match, got %q, %t", string(out), found)
}
}
func TestTreeInvalidParamsType(t *testing.T) { func TestTreeInvalidParamsType(t *testing.T) {
tree := &node{} tree := &node{}
// add a child with wildcard // add a child with wildcard
@ -1076,3 +1122,51 @@ func TestComplexBacktrackingWithCatchAll(t *testing.T) {
t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams) t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams)
} }
} }
func TestBacktrackingFallsThroughToWildcardBranch(t *testing.T) {
tests := []struct {
name string
routes []string
requestPath string
wantFullPath string
wantParams Params
}{
{
name: "param route after static dead end",
routes: []string{"/foo/bar", "/foo/:id/details"},
requestPath: "/foo/bar/details",
wantFullPath: "/foo/:id/details",
wantParams: Params{{Key: "id", Value: "bar"}},
},
{
name: "catch-all route after static dead end",
routes: []string{"/foo/bar", "/foo/:id/*rest"},
requestPath: "/foo/bar/baz.txt",
wantFullPath: "/foo/:id/*rest",
wantParams: Params{
{Key: "id", Value: "bar"},
{Key: "rest", Value: "/baz.txt"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tree := &node{}
for _, route := range tt.routes {
tree.addRoute(route, fakeHandler(route))
}
value := getValueWithTimeout(t, tree, tt.requestPath, false)
if value.handlers == nil {
t.Fatalf("expected handlers for %q", tt.requestPath)
}
if value.fullPath != tt.wantFullPath {
t.Fatalf("expected full path %q for %q, got %q", tt.wantFullPath, tt.requestPath, value.fullPath)
}
if value.params == nil || !reflect.DeepEqual(*value.params, tt.wantParams) {
t.Fatalf("expected params %v for %q, got %v", tt.wantParams, tt.requestPath, value.params)
}
})
}
}