Compare commits

...

19 commits
v0.3.6 ... main

Author SHA1 Message Date
WJQSERVER
a6e278d458 print errlog (jsonv2 marshal) 2026-01-26 08:08:01 +08:00
WJQSERVER
7b536ac137
Merge pull request #59 from infinite-iroha/fix-slice-panic
refactor: Improve engine's tree processing and context handling.
2025-12-15 00:05:02 +08:00
WJQSERVER
b348d7d41f update TempSkippedNodesPool 2025-12-14 23:42:50 +08:00
WJQSERVER
60b2936eff add TempSkippedNodesPool 2025-12-14 23:16:29 +08:00
WJQSERVER
9cfc82a347 chore: update go module dependencies. 2025-12-14 22:57:48 +08:00
WJQSERVER
904aea5df8 refactor: Improve engine's tree processing and context handling. 2025-12-14 22:56:37 +08:00
WJQSERVER
ee0ebc986c
Merge pull request #54 from infinite-iroha/dev
context added FileText method
2025-10-21 15:06:39 +08:00
wjqserver
e4aaaa1583 fix path to filepath 2025-10-21 15:06:26 +08:00
wjqserver
1361f6e237 update 2025-10-21 14:47:29 +08:00
WJQSERVER
a6458cca16
Merge pull request #53 from infinite-iroha/dev
update
2025-10-12 15:48:48 +08:00
wjqserver
76a89800a2 update 2025-10-12 15:47:02 +08:00
WJQSERVER
4955fb9d03
Merge pull request #52 from infinite-iroha/dev
fix StaticFS
2025-09-14 08:27:29 +08:00
wjqserver
5b98310de5 fix StaticFS 2025-09-14 08:24:01 +08:00
WJQSERVER
f1ac0dd6ff
Merge pull request #51 from infinite-iroha/dev
0.3.7
2025-09-10 02:40:51 +08:00
wjqserver
38ff5126e3 fix 2025-09-10 02:40:41 +08:00
WJQSERVER
b4e073ae2f
Update sse.go 2025-09-07 02:24:28 +08:00
WJQSERVER
af0a99acda add sse intn support 2025-09-06 17:55:45 +00:00
wjqserver
3ffde5742c add wanf 2025-08-20 16:50:26 +08:00
WJQSERVER
016df0efe4
Merge pull request #50 from infinite-iroha/dev
0.3.6
2025-08-01 10:27:01 +08:00
7 changed files with 428 additions and 56 deletions

View file

@ -18,11 +18,12 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"os" "os"
"path" "path/filepath"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/WJQSERVER/wanf"
"github.com/fenthope/reco" "github.com/fenthope/reco"
"github.com/go-json-experiment/json" "github.com/go-json-experiment/json"
@ -42,7 +43,7 @@ type Context struct {
index int8 // 当前执行到处理链的哪个位置 index int8 // 当前执行到处理链的哪个位置
mu sync.RWMutex mu sync.RWMutex
Keys map[string]interface{} // 用于在中间件之间传递数据 Keys map[string]any // 用于在中间件之间传递数据
Errors []error // 用于收集处理过程中的错误 Errors []error // 用于收集处理过程中的错误
@ -64,6 +65,10 @@ type Context struct {
// 请求体Body大小限制 // 请求体Body大小限制
MaxRequestBodySize int64 MaxRequestBodySize int64
// skippedNodes 用于记录跳过的节点信息,以便回溯
// 通常在处理嵌套路由时使用
SkippedNodes []skippedNode
} }
// --- Context 相关方法实现 --- // --- Context 相关方法实现 ---
@ -77,20 +82,30 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
} else { } else {
c.Writer = newResponseWriter(w) c.Writer = newResponseWriter(w)
} }
//c.Writer = newResponseWriter(w)
c.Request = req c.Request = req
c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组 //c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组
//避免params长度为0
if cap(c.Params) > 0 {
c.Params = c.Params[:0]
} else {
c.Params = make(Params, 0, c.engine.maxParams)
}
c.handlers = nil c.handlers = nil
c.index = -1 // 初始为 -1`Next()` 将其设置为 0 c.index = -1 // 初始为 -1`Next()` 将其设置为 0
c.Keys = make(map[string]interface{}) // 每次请求重新创建 map避免数据污染 c.Keys = make(map[string]any) // 每次请求重新创建 map避免数据污染
c.Errors = c.Errors[:0] // 清空 Errors 切片 c.Errors = c.Errors[:0] // 清空 Errors 切片
c.queryCache = nil // 清空查询参数缓存 c.queryCache = nil // 清空查询参数缓存
c.formCache = nil // 清空表单数据缓存 c.formCache = nil // 清空表单数据缓存
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
// c.HTTPClient 和 c.engine 保持不变,它们引用 Engine 实例的成员
if cap(c.SkippedNodes) > 0 {
c.SkippedNodes = c.SkippedNodes[:0]
} else {
c.SkippedNodes = make([]skippedNode, 0, 256)
}
} }
// Next 在处理链中执行下一个处理函数 // Next 在处理链中执行下一个处理函数
@ -122,10 +137,10 @@ func (c *Context) AbortWithStatus(code int) {
// Set 将一个键值对存储到 Context 中 // Set 将一个键值对存储到 Context 中
// 这是一个线程安全的操作,用于在中间件之间传递数据 // 这是一个线程安全的操作,用于在中间件之间传递数据
func (c *Context) Set(key string, value interface{}) { func (c *Context) Set(key string, value any) {
c.mu.Lock() // 加写锁 c.mu.Lock() // 加写锁
if c.Keys == nil { if c.Keys == nil {
c.Keys = make(map[string]interface{}) c.Keys = make(map[string]any)
} }
c.Keys[key] = value c.Keys[key] = value
c.mu.Unlock() // 解写锁 c.mu.Unlock() // 解写锁
@ -133,7 +148,7 @@ func (c *Context) Set(key string, value interface{}) {
// Get 从 Context 中获取一个值 // Get 从 Context 中获取一个值
// 这是一个线程安全的操作 // 这是一个线程安全的操作
func (c *Context) Get(key string) (value interface{}, exists bool) { func (c *Context) Get(key string) (value any, exists bool) {
c.mu.RLock() // 加读锁 c.mu.RLock() // 加读锁
value, exists = c.Keys[key] value, exists = c.Keys[key]
c.mu.RUnlock() // 解读锁 c.mu.RUnlock() // 解读锁
@ -208,7 +223,7 @@ func (c *Context) GetDuration(key string) (value time.Duration, exists bool) {
// MustGet 从 Context 中获取一个值,如果不存在则 panic // MustGet 从 Context 中获取一个值,如果不存在则 panic
// 适用于确定值一定存在的场景 // 适用于确定值一定存在的场景
func (c *Context) MustGet(key string) interface{} { func (c *Context) MustGet(key string) any {
if value, exists := c.Get(key); exists { if value, exists := c.Get(key); exists {
return value return value
} }
@ -269,7 +284,7 @@ func (c *Context) Raw(code int, contentType string, data []byte) {
} }
// String 向响应写入格式化的字符串 // String 向响应写入格式化的字符串
func (c *Context) String(code int, format string, values ...interface{}) { func (c *Context) String(code int, format string, values ...any) {
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.Writer.Write([]byte(fmt.Sprintf(format, values...))) c.Writer.Write([]byte(fmt.Sprintf(format, values...)))
} }
@ -281,13 +296,121 @@ func (c *Context) Text(code int, text string) {
c.Writer.Write([]byte(text)) c.Writer.Write([]byte(text))
} }
// FileText
func (c *Context) FileText(code int, filePath string) {
// 清理path
cleanPath := filepath.Clean(filePath)
if !filepath.IsAbs(cleanPath) {
c.AddError(fmt.Errorf("relative path not allowed: %s", cleanPath))
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("relative path not allowed"))
return
}
// 检查文件是否存在
if _, err := os.Stat(cleanPath); os.IsNotExist(err) {
c.AddError(fmt.Errorf("file not found: %s", cleanPath))
c.ErrorUseHandle(http.StatusNotFound, fmt.Errorf("file not found"))
return
}
// 打开文件
file, err := os.Open(cleanPath)
if err != nil {
c.AddError(fmt.Errorf("failed to open file %s: %w", cleanPath, err))
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to open file: %w", err))
return
}
defer file.Close()
// 获取文件信息以获取文件大小
fileInfo, err := file.Stat()
if err != nil {
c.AddError(fmt.Errorf("failed to get file info for %s: %w", cleanPath, err))
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to get file info: %w", err))
return
}
// 判断是否是dir
if fileInfo.IsDir() {
c.AddError(fmt.Errorf("path is a directory, not a file: %s", cleanPath))
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path is a directory"))
return
}
c.SetHeader("Content-Type", "text/plain; charset=utf-8")
c.SetBodyStream(file, int(fileInfo.Size()))
}
/*
// not fot work
// FileTextSafeDir
func (c *Context) FileTextSafeDir(code int, filePath string, safeDir string) {
// 清理path
cleanPath := path.Clean(filePath)
if !filepath.IsAbs(cleanPath) {
c.AddError(fmt.Errorf("relative path not allowed: %s", cleanPath))
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("relative path not allowed"))
return
}
if strings.Contains(cleanPath, "..") {
c.AddError(fmt.Errorf("path traversal attempt detected: %s", cleanPath))
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path traversal attempt detected"))
return
}
// 判断filePath是否包含在safeDir内, 防止路径穿越
relPath, err := filepath.Rel(safeDir, cleanPath)
if err != nil {
c.AddError(fmt.Errorf("failed to get relative path: %w", err))
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("failed to get relative path: %w", err))
return
}
cleanPath = filepath.Join(safeDir, relPath)
// 检查文件是否存在
if _, err := os.Stat(cleanPath); os.IsNotExist(err) {
c.AddError(fmt.Errorf("file not found: %s", cleanPath))
c.ErrorUseHandle(http.StatusNotFound, fmt.Errorf("file not found"))
return
}
// 打开文件
file, err := os.Open(cleanPath)
if err != nil {
c.AddError(fmt.Errorf("failed to open file %s: %w", cleanPath, err))
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to open file: %w", err))
return
}
defer file.Close()
// 获取文件信息以获取文件大小
fileInfo, err := file.Stat()
if err != nil {
c.AddError(fmt.Errorf("failed to get file info for %s: %w", cleanPath, err))
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to get file info: %w", err))
return
}
// 判断是否是dir
if fileInfo.IsDir() {
c.AddError(fmt.Errorf("path is a directory, not a file: %s", cleanPath))
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path is a directory"))
return
}
c.SetHeader("Content-Type", "text/plain; charset=utf-8")
c.SetBodyStream(file, int(fileInfo.Size()))
}
*/
// JSON 向响应写入 JSON 数据 // JSON 向响应写入 JSON 数据
// 设置 Content-Type 为 application/json // 设置 Content-Type 为 application/json
func (c *Context) JSON(code int, obj interface{}) { func (c *Context) JSON(code int, obj any) {
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
if err := json.MarshalWrite(c.Writer, obj); err != nil { if err := json.MarshalWrite(c.Writer, obj); err != nil {
c.AddError(fmt.Errorf("failed to marshal JSON: %w", err)) c.AddError(fmt.Errorf("failed to marshal JSON: %w", err))
c.Errorf("failed to marshal JSON: %s", err)
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to marshal JSON: %w", err)) c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to marshal JSON: %w", err))
return return
} }
@ -295,7 +418,7 @@ func (c *Context) JSON(code int, obj interface{}) {
// GOB 向响应写入GOB数据 // GOB 向响应写入GOB数据
// 设置 Content-Type 为 application/octet-stream // 设置 Content-Type 为 application/octet-stream
func (c *Context) GOB(code int, obj interface{}) { func (c *Context) GOB(code int, obj any) {
c.Writer.Header().Set("Content-Type", "application/octet-stream") // 设置合适的 Content-Type c.Writer.Header().Set("Content-Type", "application/octet-stream") // 设置合适的 Content-Type
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
// GOB 编码 // GOB 编码
@ -307,11 +430,25 @@ func (c *Context) GOB(code int, obj interface{}) {
} }
} }
// WANF向响应写入WANF数据
// 设置 application/vnd.wjqserver.wanf; charset=utf-8
func (c *Context) WANF(code int, obj any) {
c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8")
c.Writer.WriteHeader(code)
// WANF 编码
encoder := wanf.NewStreamEncoder(c.Writer)
if err := encoder.Encode(obj); err != nil {
c.AddError(fmt.Errorf("failed to encode WANF: %w", err))
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to encode WANF: %w", err))
return
}
}
// HTML 渲染 HTML 模板 // HTML 渲染 HTML 模板
// 如果 Engine 配置了 HTMLRender则使用它进行渲染 // 如果 Engine 配置了 HTMLRender则使用它进行渲染
// 否则,会进行简单的字符串输出 // 否则,会进行简单的字符串输出
// 预留接口,可以扩展为支持多种模板引擎 // 预留接口,可以扩展为支持多种模板引擎
func (c *Context) HTML(code int, name string, obj interface{}) { func (c *Context) HTML(code int, name string, obj any) {
c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8") c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8")
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
@ -342,7 +479,7 @@ func (c *Context) Redirect(code int, location string) {
} }
// ShouldBindJSON 尝试将请求体绑定到 JSON 对象 // ShouldBindJSON 尝试将请求体绑定到 JSON 对象
func (c *Context) ShouldBindJSON(obj interface{}) error { func (c *Context) ShouldBindJSON(obj any) error {
if c.Request.Body == nil { if c.Request.Body == nil {
return errors.New("request body is empty") return errors.New("request body is empty")
} }
@ -353,10 +490,28 @@ func (c *Context) ShouldBindJSON(obj interface{}) error {
return nil return nil
} }
// ShouldBindWANF
func (c *Context) ShouldBindWANF(obj any) error {
if c.Request.Body == nil {
return errors.New("request body is empty")
}
decoder, err := wanf.NewStreamDecoder(c.Request.Body)
if err != nil {
return fmt.Errorf("failed to create WANF decoder: %w", err)
}
if err := decoder.Decode(obj); err != nil {
return fmt.Errorf("WANF binding error: %w", err)
}
return nil
}
// Deprecated: This function is a reserved placeholder for future API extensions
// and is not yet implemented. It will either be properly defined or removed in v2.0.0. Do not use.
// ShouldBind 尝试将请求体绑定到各种类型JSON, Form, XML 等) // ShouldBind 尝试将请求体绑定到各种类型JSON, Form, XML 等)
// 这是一个复杂的通用绑定接口,通常根据 Content-Type 或其他头部来判断绑定方式 // 这是一个复杂的通用绑定接口,通常根据 Content-Type 或其他头部来判断绑定方式
// 预留接口,可根据项目需求进行扩展 // 预留接口,可根据项目需求进行扩展
func (c *Context) ShouldBind(obj interface{}) error { func (c *Context) ShouldBind(obj any) error {
// TODO: 完整的通用绑定逻辑 // TODO: 完整的通用绑定逻辑
// 可以根据 c.Request.Header.Get("Content-Type") 来判断是 JSON, Form, XML 等 // 可以根据 c.Request.Header.Get("Content-Type") 来判断是 JSON, Form, XML 等
// 例如: // 例如:
@ -409,7 +564,7 @@ func (c *Context) Err() error {
// Value returns the value associated with this context for key, or nil if no // Value returns the value associated with this context for key, or nil if no
// value is associated with key. // value is associated with key.
// 可以用于从 Context 中获取与特定键关联的值,包括 Go 原生 Context 的值和 Touka Context 的 Keys // 可以用于从 Context 中获取与特定键关联的值,包括 Go 原生 Context 的值和 Touka Context 的 Keys
func (c *Context) Value(key interface{}) interface{} { func (c *Context) Value(key any) any {
if keyAsString, ok := key.(string); ok { if keyAsString, ok := key.(string); ok {
if val, exists := c.Get(keyAsString); exists { if val, exists := c.Get(keyAsString); exists {
return val return val
@ -724,7 +879,7 @@ func (c *Context) GetRequestURIPath() string {
// 将文件内容作为响应body // 将文件内容作为响应body
func (c *Context) SetRespBodyFile(code int, filePath string) { func (c *Context) SetRespBodyFile(code int, filePath string) {
// 清理path // 清理path
cleanPath := path.Clean(filePath) cleanPath := filepath.Clean(filePath)
// 打开文件 // 打开文件
file, err := os.Open(cleanPath) file, err := os.Open(cleanPath)
@ -744,7 +899,7 @@ func (c *Context) SetRespBodyFile(code int, filePath string) {
} }
// 尝试根据文件扩展名猜测 Content-Type // 尝试根据文件扩展名猜测 Content-Type
contentType := mime.TypeByExtension(path.Ext(cleanPath)) contentType := mime.TypeByExtension(filepath.Ext(cleanPath))
if contentType == "" { if contentType == "" {
// 如果无法猜测,则使用默认的二进制流类型 // 如果无法猜测,则使用默认的二进制流类型
contentType = "application/octet-stream" contentType = "application/octet-stream"

View file

@ -421,6 +421,41 @@ func getHandlerName(h HandlerFunc) string {
} }
const MaxSkippedNodesCap = 256
// TempSkippedNodesPool 存储 *[]skippedNode 以复用内存
var TempSkippedNodesPool = sync.Pool{
New: func() any {
// 返回一个指向容量为 256 的新切片的指针
s := make([]skippedNode, 0, MaxSkippedNodesCap)
return &s
},
}
// GetTempSkippedNodes 从 Pool 中获取一个 *[]skippedNode 指针
func GetTempSkippedNodes() *[]skippedNode {
// 直接返回 Pool 中存储的指针
return TempSkippedNodesPool.Get().(*[]skippedNode)
}
// PutTempSkippedNodes 将用完的 *[]skippedNode 指针放回 Pool
func PutTempSkippedNodes(skippedNodes *[]skippedNode) {
if skippedNodes == nil || *skippedNodes == nil {
return
}
// 检查容量是否符合预期。如果容量不足,则丢弃,不放回 Pool。
if cap(*skippedNodes) < MaxSkippedNodesCap {
return // 丢弃该对象,让 Pool 在下次 Get 时通过 New 重新分配
}
// 长度重置为 0保留容量实现复用
*skippedNodes = (*skippedNodes)[:0]
// 将指针存回 Pool
TempSkippedNodesPool.Put(skippedNodes)
}
// 405中间件 // 405中间件
func MethodNotAllowed() HandlerFunc { func MethodNotAllowed() HandlerFunc {
return func(c *Context) { return func(c *Context) {
@ -432,9 +467,10 @@ func MethodNotAllowed() HandlerFunc {
// 如果是 OPTIONS 请求,尝试查找所有允许的方法 // 如果是 OPTIONS 请求,尝试查找所有允许的方法
allowedMethods := []string{} allowedMethods := []string{}
for _, treeIter := range engine.methodTrees { for _, treeIter := range engine.methodTrees {
var tempSkippedNodes []skippedNode
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) tempSkippedNodes := GetTempSkippedNodes()
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
PutTempSkippedNodes(tempSkippedNodes)
if value.handlers != nil { if value.handlers != nil {
allowedMethods = append(allowedMethods, treeIter.method) allowedMethods = append(allowedMethods, treeIter.method)
} }
@ -451,9 +487,10 @@ func MethodNotAllowed() HandlerFunc {
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
continue continue
} }
var tempSkippedNodes []skippedNode // 用于临时查找,不影响主 Context
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) // 只查找是否存在,不需要参数 tempSkippedNodes := GetTempSkippedNodes()
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
PutTempSkippedNodes(tempSkippedNodes)
if value.handlers != nil { if value.handlers != nil {
// 使用定义的ErrorHandle处理 // 使用定义的ErrorHandle处理
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
@ -661,9 +698,8 @@ func (engine *Engine) handleRequest(c *Context) {
// 查找匹配的节点和处理函数 // 查找匹配的节点和处理函数
// 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量 // 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量
// skippedNodes 内部使用,因此无需从外部传入已分配的 slice // skippedNodes 内部使用,因此无需从外部传入已分配的 slice
var skippedNodes []skippedNode // 用于回溯的跳过节点
// 直接在 rootNode 上调用 getValue 方法 // 直接在 rootNode 上调用 getValue 方法
value := rootNode.getValue(requestPath, &c.Params, &skippedNodes, true) // unescape=true 对路径参数进行 URL 解码 value := rootNode.getValue(requestPath, &c.Params, &c.SkippedNodes, true) // unescape=true 对路径参数进行 URL 解码
if value.handlers != nil { if value.handlers != nil {
//c.handlers = engine.combineHandlers(engine.globalHandlers, value.handlers) // 组合全局中间件和路由处理函数 //c.handlers = engine.combineHandlers(engine.globalHandlers, value.handlers) // 组合全局中间件和路由处理函数

View file

@ -6,7 +6,6 @@ package touka
import ( import (
"errors" "errors"
"fmt"
"net/http" "net/http"
"path" "path"
"strings" "strings"
@ -19,13 +18,19 @@ var allowedFileServerMethods = map[string]struct{}{
http.MethodHead: {}, http.MethodHead: {},
} }
var (
ErrInputFSisNil = errors.New("input FS is nil")
ErrMethodNotAllowed = errors.New("method not allowed")
)
// FileServer方式, 返回一个HandleFunc, 统一化处理 // FileServer方式, 返回一个HandleFunc, 统一化处理
func FileServer(fs http.FileSystem) HandlerFunc { func FileServer(fs http.FileSystem) HandlerFunc {
if fs == nil { if fs == nil {
return func(c *Context) { return func(c *Context) {
c.ErrorUseHandle(500, errors.New("Input FileSystem is nil")) c.ErrorUseHandle(http.StatusInternalServerError, ErrInputFSisNil)
} }
} }
fileServerInstance := http.FileServer(fs) fileServerInstance := http.FileServer(fs)
return func(c *Context) { return func(c *Context) {
FileServerHandleServe(c, fileServerInstance) FileServerHandleServe(c, fileServerInstance)
@ -37,7 +42,6 @@ func FileServer(fs http.FileSystem) HandlerFunc {
func FileServerHandleServe(c *Context, fsHandle http.Handler) { func FileServerHandleServe(c *Context, fsHandle http.Handler) {
if fsHandle == nil { if fsHandle == nil {
ErrInputFSisNil := errors.New("Input FileSystem Handle is nil")
c.AddError(ErrInputFSisNil) c.AddError(ErrInputFSisNil)
// 500 // 500
c.ErrorUseHandle(http.StatusInternalServerError, ErrInputFSisNil) c.ErrorUseHandle(http.StatusInternalServerError, ErrInputFSisNil)
@ -59,7 +63,7 @@ func FileServerHandleServe(c *Context, fsHandle http.Handler) {
return return
} else { } else {
// 否则,返回 405 Method Not Allowed // 否则,返回 405 Method Not Allowed
c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, fmt.Errorf("Method %s is Not Allowed on FileServer", c.Request.Method)) c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, ErrMethodNotAllowed)
} }
} else { } else {
c.Next() c.Next()
@ -240,7 +244,7 @@ func (engine *Engine) StaticFS(relativePath string, fs http.FileSystem) {
relativePath += "/" relativePath += "/"
} }
fileServer := http.FileServer(fs) fileServer := http.StripPrefix(relativePath, http.FileServer(fs))
engine.ANY(relativePath+"*filepath", GetStaticFSHandleFunc(fileServer)) engine.ANY(relativePath+"*filepath", GetStaticFSHandleFunc(fileServer))
} }
@ -254,7 +258,7 @@ func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem) {
relativePath += "/" relativePath += "/"
} }
fileServer := http.FileServer(fs) fileServer := http.StripPrefix(relativePath, http.FileServer(fs))
group.ANY(relativePath+"*filepath", GetStaticFSHandleFunc(fileServer)) group.ANY(relativePath+"*filepath", GetStaticFSHandleFunc(fileServer))
} }

7
go.mod
View file

@ -1,15 +1,16 @@
module github.com/infinite-iroha/touka module github.com/infinite-iroha/touka
go 1.24.5 go 1.25.1
require ( require (
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2
github.com/WJQSERVER-STUDIO/httpc v0.8.2 github.com/WJQSERVER-STUDIO/httpc v0.8.2
github.com/WJQSERVER/wanf v0.0.3
github.com/fenthope/reco v0.0.4 github.com/fenthope/reco v0.0.4
github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2 github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e
) )
require ( require (
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
golang.org/x/net v0.42.0 // indirect golang.org/x/net v0.49.0 // indirect
) )

10
go.sum
View file

@ -2,11 +2,13 @@ github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbe
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok= github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok=
github.com/WJQSERVER-STUDIO/httpc v0.8.2 h1:PFPLodV0QAfGEP6915J57vIqoKu9cGuuiXG/7C9TNUk= github.com/WJQSERVER-STUDIO/httpc v0.8.2 h1:PFPLodV0QAfGEP6915J57vIqoKu9cGuuiXG/7C9TNUk=
github.com/WJQSERVER-STUDIO/httpc v0.8.2/go.mod h1:8WhHVRO+olDFBSvL5PC/bdMkb6U3vRdPJ4p4pnguV5Y= github.com/WJQSERVER-STUDIO/httpc v0.8.2/go.mod h1:8WhHVRO+olDFBSvL5PC/bdMkb6U3vRdPJ4p4pnguV5Y=
github.com/WJQSERVER/wanf v0.0.3 h1:OqhG7ETiR5Knqr0lmbb+iUMw9O7re2vEogjVf06QevM=
github.com/WJQSERVER/wanf v0.0.3/go.mod h1:q2Pyg+G+s1acMWxrbI4CwS/Yk76/BzLREEdZ8iFwUNE=
github.com/fenthope/reco v0.0.4 h1:yo2g3aWwdoMpaZWZX4SdZOW7mCK82RQIU/YI8ZUQThM= github.com/fenthope/reco v0.0.4 h1:yo2g3aWwdoMpaZWZX4SdZOW7mCK82RQIU/YI8ZUQThM=
github.com/fenthope/reco v0.0.4/go.mod h1:eMyS8HpdMVdJ/2WJt6Cvt8P1EH9Igzj5lSJrgc+0jeg= github.com/fenthope/reco v0.0.4/go.mod h1:eMyS8HpdMVdJ/2WJt6Cvt8P1EH9Igzj5lSJrgc+0jeg=
github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2 h1:iizUGZ9pEquQS5jTGkh4AqeeHCMbfbjeb0zMt0aEFzs= github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU=
github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M= github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=

184
sse.go Normal file
View file

@ -0,0 +1,184 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2025 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"bytes"
"io"
"net/http"
"strings"
)
// Event 代表一个服务器发送事件(SSE).
type Event struct {
// Event 是事件的名称.
Event string
// Data 是事件的内容, 可以是多行文本.
Data string
// Id 是事件的唯一标识符.
Id string
// Retry 是指定客户端在连接丢失后应等待多少毫秒后尝试重新连接.
Retry string
}
// Render 将事件格式化并写入给定的 writer.
// 通过逐行处理数据, 此方法可防止因数据中包含换行符而导致的CRLF注入问题.
// 为了性能, 它使用 bytes.Buffer 并通过 WriteTo 直接写入, 以避免不必要的内存分配.
func (e *Event) Render(w io.Writer) error {
var buf bytes.Buffer
if len(e.Id) > 0 {
buf.WriteString("id: ")
buf.WriteString(e.Id)
buf.WriteString("\n")
}
if len(e.Event) > 0 {
buf.WriteString("event: ")
buf.WriteString(e.Event)
buf.WriteString("\n")
}
if len(e.Data) > 0 {
lines := strings.Split(e.Data, "\n")
for _, line := range lines {
buf.WriteString("data: ")
buf.WriteString(line)
buf.WriteString("\n")
}
}
if len(e.Retry) > 0 {
buf.WriteString("retry: ")
buf.WriteString(e.Retry)
buf.WriteString("\n")
}
// 每个事件都以一个额外的换行符结尾.
buf.WriteString("\n")
// 直接将 buffer 的内容写入 writer, 避免生成中间字符串.
_, err := buf.WriteTo(w)
return err
}
// EventStream 启动一个 SSE 事件流.
// 这是推荐的、更简单安全的方式, 采用阻塞和回调的设计, 框架负责管理连接生命周期.
//
// 详细用法:
//
// r.GET("/sse/callback", func(c *touka.Context) {
// // streamer 回调函数会在一个循环中被调用.
// c.EventStream(func(w io.Writer) bool {
// event := touka.Event{
// Event: "time-tick",
// Data: time.Now().Format(time.RFC1123),
// }
//
// if err := event.Render(w); err != nil {
// // 发生写入错误, 停止发送.
// return false // 返回 false 结束事件流.
// }
//
// time.Sleep(2 * time.Second)
// return true // 返回 true 继续事件流.
// })
// // 当事件流结束后(例如客户端关闭页面), 这行代码会被执行.
// fmt.Println("Client disconnected from /sse/callback")
// })
func (c *Context) EventStream(streamer func(w io.Writer) bool) {
// 为现代网络协议优化头部.
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
c.Writer.Header().Set("Cache-Control", "no-cache, no-transform")
c.Writer.Header().Del("Connection")
c.Writer.Header().Del("Transfer-Encoding")
c.Writer.WriteHeader(http.StatusOK)
c.Writer.Flush() // 直接调用, ResponseWriter 接口保证了 Flush 方法的存在.
for {
select {
case <-c.Request.Context().Done():
return
default:
if !streamer(c.Writer) {
return
}
c.Writer.Flush()
}
}
}
// EventStreamChan 返回用于 SSE 事件流的 channel.
// 这是为高级并发场景设计的、更灵活的API.
//
// 重要:
// - 调用者必须 close(eventChan) 来结束事件流.
// - 调用者必须在独立的 goroutine 中消费 errChan 来处理错误和连接断开.
// - 为防止 goroutine 泄漏, 建议发送方在 select 中同时监听 c.Request.Context().Done().
//
// 详细用法:
//
// r.GET("/sse/channel", func(c *touka.Context) {
// eventChan, errChan := c.EventStreamChan()
//
// // 必须在独立的goroutine中处理错误和连接断开.
// go func() {
// if err := <-errChan; err != nil {
// c.Errorf("SSE channel error: %v", err)
// }
// }()
//
// // 在另一个goroutine中异步发送事件.
// go func() {
// // 重要: 必须在逻辑结束时关闭channel, 以通知框架.
// defer close(eventChan)
//
// for i := 1; i <= 5; i++ {
// select {
// case <-c.Request.Context().Done():
// return // 客户端已断开, 退出 goroutine.
// default:
// eventChan <- touka.Event{
// Id: fmt.Sprintf("%d", i),
// Data: "hello from channel",
// }
// time.Sleep(2 * time.Second)
// }
// }
// }()
// })
func (c *Context) EventStreamChan() (chan<- Event, <-chan error) {
eventChan := make(chan Event)
errChan := make(chan error, 1)
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
c.Writer.Header().Set("Cache-Control", "no-cache, no-transform")
c.Writer.Header().Del("Connection")
c.Writer.Header().Del("Transfer-Encoding")
c.Writer.WriteHeader(http.StatusOK)
c.Writer.Flush()
go func() {
defer close(errChan)
for {
select {
case event, ok := <-eventChan:
if !ok {
return
}
if err := event.Render(c.Writer); err != nil {
errChan <- err
return
}
c.Writer.Flush()
case <-c.Request.Context().Done():
errChan <- c.Request.Context().Err()
return
}
}
}()
return eventChan, errChan
}

26
tree.go
View file

@ -5,7 +5,6 @@
package touka package touka
import ( import (
"bytes"
"net/url" "net/url"
"strings" "strings"
"unicode" "unicode"
@ -27,12 +26,6 @@ func BytesToString(b []byte) string {
return unsafe.String(unsafe.SliceData(b), len(b)) return unsafe.String(unsafe.SliceData(b), len(b))
} }
var (
strColon = []byte(":") // 定义字节切片常量, 表示冒号, 用于路径参数识别
strStar = []byte("*") // 定义字节切片常量, 表示星号, 用于捕获所有路径识别
strSlash = []byte("/") // 定义字节切片常量, 表示斜杠, 用于路径分隔符识别
)
// Param 是单个 URL 参数, 由键和值组成. // Param 是单个 URL 参数, 由键和值组成.
type Param struct { type Param struct {
Key string // 参数的键名 Key string // 参数的键名
@ -106,17 +99,14 @@ func (n *node) addChild(child *node) {
// countParams 计算路径中参数(冒号)和捕获所有(星号)的数量. // countParams 计算路径中参数(冒号)和捕获所有(星号)的数量.
func countParams(path string) uint16 { func countParams(path string) uint16 {
var n uint16 colons := strings.Count(path, ":")
s := StringToBytes(path) // 将路径字符串转换为字节切片 stars := strings.Count(path, "*")
n += uint16(bytes.Count(s, strColon)) // 统计冒号的数量 return uint16(colons + stars)
n += uint16(bytes.Count(s, strStar)) // 统计星号的数量
return n
} }
// countSections 计算路径中斜杠('/')的数量, 即路径段的数量. // countSections 计算路径中斜杠('/')的数量, 即路径段的数量.
func countSections(path string) uint16 { func countSections(path string) uint16 {
s := StringToBytes(path) // 将路径字符串转换为字节切片 return uint16(strings.Count(path, "/"))
return uint16(bytes.Count(s, strSlash)) // 统计斜杠的数量
} }
// nodeType 定义了节点的类型. // nodeType 定义了节点的类型.
@ -418,10 +408,10 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain)
fullPath: fullPath, // 设置完整路径 fullPath: fullPath, // 设置完整路径
} }
n.addChild(child) // 添加子节点 n.addChild(child) // 添加子节点
n.indices = string('/') // 索引设置为 '/' n.indices = "/" // 索引设置为 '/'
n = child // 移动到新创建的 catchAll 节点 n = child // 移动到新创建的 catchAll 节点
n.priority++ // 增加优先级 n.priority++ // 增加优先级
// 第二个节点: 包含变量的节点 // 第二个节点: 包含变量的节点
child = &node{ child = &node{