diff --git a/context.go b/context.go index c1e2bb8..8c52b1f 100644 --- a/context.go +++ b/context.go @@ -18,11 +18,12 @@ import ( "net/netip" "net/url" "os" - "path" + "path/filepath" "strings" "sync" "time" + "github.com/WJQSERVER/wanf" "github.com/fenthope/reco" "github.com/go-json-experiment/json" @@ -42,7 +43,7 @@ type Context struct { index int8 // 当前执行到处理链的哪个位置 mu sync.RWMutex - Keys map[string]interface{} // 用于在中间件之间传递数据 + Keys map[string]any // 用于在中间件之间传递数据 Errors []error // 用于收集处理过程中的错误 @@ -64,6 +65,10 @@ type Context struct { // 请求体Body大小限制 MaxRequestBodySize int64 + + // skippedNodes 用于记录跳过的节点信息,以便回溯 + // 通常在处理嵌套路由时使用 + SkippedNodes []skippedNode } // --- Context 相关方法实现 --- @@ -77,20 +82,30 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } else { c.Writer = newResponseWriter(w) } - //c.Writer = newResponseWriter(w) c.Request = req - c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组 + //c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组 + //避免params长度为0 + if cap(c.Params) > 0 { + c.Params = c.Params[:0] + } else { + c.Params = make(Params, 0, c.engine.maxParams) + } c.handlers = nil c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 - c.Keys = make(map[string]interface{}) // 每次请求重新创建 map,避免数据污染 + c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染 c.Errors = c.Errors[:0] // 清空 Errors 切片 c.queryCache = nil // 清空查询参数缓存 c.formCache = nil // 清空表单数据缓存 c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize - // c.HTTPClient 和 c.engine 保持不变,它们引用 Engine 实例的成员 + + if cap(c.SkippedNodes) > 0 { + c.SkippedNodes = c.SkippedNodes[:0] + } else { + c.SkippedNodes = make([]skippedNode, 0, 256) + } } // Next 在处理链中执行下一个处理函数 @@ -122,10 +137,10 @@ func (c *Context) AbortWithStatus(code int) { // Set 将一个键值对存储到 Context 中 // 这是一个线程安全的操作,用于在中间件之间传递数据 -func (c *Context) Set(key string, value interface{}) { +func (c *Context) Set(key string, value any) { c.mu.Lock() // 加写锁 if c.Keys == nil { - c.Keys = make(map[string]interface{}) + c.Keys = make(map[string]any) } c.Keys[key] = value c.mu.Unlock() // 解写锁 @@ -133,7 +148,7 @@ func (c *Context) Set(key string, value interface{}) { // Get 从 Context 中获取一个值 // 这是一个线程安全的操作 -func (c *Context) Get(key string) (value interface{}, exists bool) { +func (c *Context) Get(key string) (value any, exists bool) { c.mu.RLock() // 加读锁 value, exists = c.Keys[key] c.mu.RUnlock() // 解读锁 @@ -208,7 +223,7 @@ func (c *Context) GetDuration(key string) (value time.Duration, exists bool) { // MustGet 从 Context 中获取一个值,如果不存在则 panic // 适用于确定值一定存在的场景 -func (c *Context) MustGet(key string) interface{} { +func (c *Context) MustGet(key string) any { if value, exists := c.Get(key); exists { return value } @@ -269,7 +284,7 @@ func (c *Context) Raw(code int, contentType string, data []byte) { } // String 向响应写入格式化的字符串 -func (c *Context) String(code int, format string, values ...interface{}) { +func (c *Context) String(code int, format string, values ...any) { c.Writer.WriteHeader(code) c.Writer.Write([]byte(fmt.Sprintf(format, values...))) } @@ -281,13 +296,121 @@ func (c *Context) Text(code int, text string) { c.Writer.Write([]byte(text)) } +// FileText +func (c *Context) FileText(code int, filePath string) { + // 清理path + cleanPath := filepath.Clean(filePath) + if !filepath.IsAbs(cleanPath) { + c.AddError(fmt.Errorf("relative path not allowed: %s", cleanPath)) + c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("relative path not allowed")) + return + } + // 检查文件是否存在 + if _, err := os.Stat(cleanPath); os.IsNotExist(err) { + c.AddError(fmt.Errorf("file not found: %s", cleanPath)) + c.ErrorUseHandle(http.StatusNotFound, fmt.Errorf("file not found")) + return + } + + // 打开文件 + file, err := os.Open(cleanPath) + if err != nil { + c.AddError(fmt.Errorf("failed to open file %s: %w", cleanPath, err)) + c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to open file: %w", err)) + return + } + defer file.Close() + + // 获取文件信息以获取文件大小 + fileInfo, err := file.Stat() + if err != nil { + c.AddError(fmt.Errorf("failed to get file info for %s: %w", cleanPath, err)) + c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to get file info: %w", err)) + return + } + // 判断是否是dir + if fileInfo.IsDir() { + c.AddError(fmt.Errorf("path is a directory, not a file: %s", cleanPath)) + c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path is a directory")) + return + } + + c.SetHeader("Content-Type", "text/plain; charset=utf-8") + + c.SetBodyStream(file, int(fileInfo.Size())) +} + +/* +// not fot work +// FileTextSafeDir +func (c *Context) FileTextSafeDir(code int, filePath string, safeDir string) { + + // 清理path + cleanPath := path.Clean(filePath) + if !filepath.IsAbs(cleanPath) { + c.AddError(fmt.Errorf("relative path not allowed: %s", cleanPath)) + c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("relative path not allowed")) + return + } + if strings.Contains(cleanPath, "..") { + c.AddError(fmt.Errorf("path traversal attempt detected: %s", cleanPath)) + c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path traversal attempt detected")) + return + } + + // 判断filePath是否包含在safeDir内, 防止路径穿越 + relPath, err := filepath.Rel(safeDir, cleanPath) + if err != nil { + c.AddError(fmt.Errorf("failed to get relative path: %w", err)) + c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("failed to get relative path: %w", err)) + return + } + cleanPath = filepath.Join(safeDir, relPath) + + // 检查文件是否存在 + if _, err := os.Stat(cleanPath); os.IsNotExist(err) { + c.AddError(fmt.Errorf("file not found: %s", cleanPath)) + c.ErrorUseHandle(http.StatusNotFound, fmt.Errorf("file not found")) + return + } + + // 打开文件 + file, err := os.Open(cleanPath) + if err != nil { + c.AddError(fmt.Errorf("failed to open file %s: %w", cleanPath, err)) + c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to open file: %w", err)) + return + } + defer file.Close() + + // 获取文件信息以获取文件大小 + fileInfo, err := file.Stat() + if err != nil { + c.AddError(fmt.Errorf("failed to get file info for %s: %w", cleanPath, err)) + c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to get file info: %w", err)) + return + } + // 判断是否是dir + if fileInfo.IsDir() { + c.AddError(fmt.Errorf("path is a directory, not a file: %s", cleanPath)) + c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path is a directory")) + return + } + + c.SetHeader("Content-Type", "text/plain; charset=utf-8") + + c.SetBodyStream(file, int(fileInfo.Size())) +} +*/ + // JSON 向响应写入 JSON 数据 // 设置 Content-Type 为 application/json -func (c *Context) JSON(code int, obj interface{}) { +func (c *Context) JSON(code int, obj any) { c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") c.Writer.WriteHeader(code) if err := json.MarshalWrite(c.Writer, obj); err != nil { c.AddError(fmt.Errorf("failed to marshal JSON: %w", err)) + c.Errorf("failed to marshal JSON: %s", err) c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to marshal JSON: %w", err)) return } @@ -295,7 +418,7 @@ func (c *Context) JSON(code int, obj interface{}) { // GOB 向响应写入GOB数据 // 设置 Content-Type 为 application/octet-stream -func (c *Context) GOB(code int, obj interface{}) { +func (c *Context) GOB(code int, obj any) { c.Writer.Header().Set("Content-Type", "application/octet-stream") // 设置合适的 Content-Type c.Writer.WriteHeader(code) // GOB 编码 @@ -307,11 +430,25 @@ func (c *Context) GOB(code int, obj interface{}) { } } +// WANF向响应写入WANF数据 +// 设置 application/vnd.wjqserver.wanf; charset=utf-8 +func (c *Context) WANF(code int, obj any) { + c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8") + c.Writer.WriteHeader(code) + // WANF 编码 + encoder := wanf.NewStreamEncoder(c.Writer) + if err := encoder.Encode(obj); err != nil { + c.AddError(fmt.Errorf("failed to encode WANF: %w", err)) + c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to encode WANF: %w", err)) + return + } +} + // HTML 渲染 HTML 模板 // 如果 Engine 配置了 HTMLRender,则使用它进行渲染 // 否则,会进行简单的字符串输出 // 预留接口,可以扩展为支持多种模板引擎 -func (c *Context) HTML(code int, name string, obj interface{}) { +func (c *Context) HTML(code int, name string, obj any) { c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8") c.Writer.WriteHeader(code) @@ -342,7 +479,7 @@ func (c *Context) Redirect(code int, location string) { } // ShouldBindJSON 尝试将请求体绑定到 JSON 对象 -func (c *Context) ShouldBindJSON(obj interface{}) error { +func (c *Context) ShouldBindJSON(obj any) error { if c.Request.Body == nil { return errors.New("request body is empty") } @@ -353,10 +490,28 @@ func (c *Context) ShouldBindJSON(obj interface{}) error { return nil } +// ShouldBindWANF +func (c *Context) ShouldBindWANF(obj any) error { + if c.Request.Body == nil { + return errors.New("request body is empty") + } + decoder, err := wanf.NewStreamDecoder(c.Request.Body) + if err != nil { + return fmt.Errorf("failed to create WANF decoder: %w", err) + } + + if err := decoder.Decode(obj); err != nil { + return fmt.Errorf("WANF binding error: %w", err) + } + return nil +} + +// Deprecated: This function is a reserved placeholder for future API extensions +// and is not yet implemented. It will either be properly defined or removed in v2.0.0. Do not use. // ShouldBind 尝试将请求体绑定到各种类型(JSON, Form, XML 等) // 这是一个复杂的通用绑定接口,通常根据 Content-Type 或其他头部来判断绑定方式 // 预留接口,可根据项目需求进行扩展 -func (c *Context) ShouldBind(obj interface{}) error { +func (c *Context) ShouldBind(obj any) error { // TODO: 完整的通用绑定逻辑 // 可以根据 c.Request.Header.Get("Content-Type") 来判断是 JSON, Form, XML 等 // 例如: @@ -409,7 +564,7 @@ func (c *Context) Err() error { // Value returns the value associated with this context for key, or nil if no // value is associated with key. // 可以用于从 Context 中获取与特定键关联的值,包括 Go 原生 Context 的值和 Touka Context 的 Keys -func (c *Context) Value(key interface{}) interface{} { +func (c *Context) Value(key any) any { if keyAsString, ok := key.(string); ok { if val, exists := c.Get(keyAsString); exists { return val @@ -724,7 +879,7 @@ func (c *Context) GetRequestURIPath() string { // 将文件内容作为响应body func (c *Context) SetRespBodyFile(code int, filePath string) { // 清理path - cleanPath := path.Clean(filePath) + cleanPath := filepath.Clean(filePath) // 打开文件 file, err := os.Open(cleanPath) @@ -744,7 +899,7 @@ func (c *Context) SetRespBodyFile(code int, filePath string) { } // 尝试根据文件扩展名猜测 Content-Type - contentType := mime.TypeByExtension(path.Ext(cleanPath)) + contentType := mime.TypeByExtension(filepath.Ext(cleanPath)) if contentType == "" { // 如果无法猜测,则使用默认的二进制流类型 contentType = "application/octet-stream" diff --git a/engine.go b/engine.go index 581258c..0a95765 100644 --- a/engine.go +++ b/engine.go @@ -421,6 +421,41 @@ func getHandlerName(h HandlerFunc) string { } +const MaxSkippedNodesCap = 256 + +// TempSkippedNodesPool 存储 *[]skippedNode 以复用内存 +var TempSkippedNodesPool = sync.Pool{ + New: func() any { + // 返回一个指向容量为 256 的新切片的指针 + s := make([]skippedNode, 0, MaxSkippedNodesCap) + return &s + }, +} + +// GetTempSkippedNodes 从 Pool 中获取一个 *[]skippedNode 指针 +func GetTempSkippedNodes() *[]skippedNode { + // 直接返回 Pool 中存储的指针 + return TempSkippedNodesPool.Get().(*[]skippedNode) +} + +// PutTempSkippedNodes 将用完的 *[]skippedNode 指针放回 Pool +func PutTempSkippedNodes(skippedNodes *[]skippedNode) { + if skippedNodes == nil || *skippedNodes == nil { + return + } + + // 检查容量是否符合预期。如果容量不足,则丢弃,不放回 Pool。 + if cap(*skippedNodes) < MaxSkippedNodesCap { + return // 丢弃该对象,让 Pool 在下次 Get 时通过 New 重新分配 + } + + // 长度重置为 0,保留容量,实现复用 + *skippedNodes = (*skippedNodes)[:0] + + // 将指针存回 Pool + TempSkippedNodesPool.Put(skippedNodes) +} + // 405中间件 func MethodNotAllowed() HandlerFunc { return func(c *Context) { @@ -432,9 +467,10 @@ func MethodNotAllowed() HandlerFunc { // 如果是 OPTIONS 请求,尝试查找所有允许的方法 allowedMethods := []string{} for _, treeIter := range engine.methodTrees { - var tempSkippedNodes []skippedNode // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) + PutTempSkippedNodes(tempSkippedNodes) if value.handlers != nil { allowedMethods = append(allowedMethods, treeIter.method) } @@ -451,9 +487,10 @@ func MethodNotAllowed() HandlerFunc { if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 continue } - var tempSkippedNodes []skippedNode // 用于临时查找,不影响主 Context // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) // 只查找是否存在,不需要参数 + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 + PutTempSkippedNodes(tempSkippedNodes) if value.handlers != nil { // 使用定义的ErrorHandle处理 engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) @@ -661,9 +698,8 @@ func (engine *Engine) handleRequest(c *Context) { // 查找匹配的节点和处理函数 // 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量 // skippedNodes 内部使用,因此无需从外部传入已分配的 slice - var skippedNodes []skippedNode // 用于回溯的跳过节点 // 直接在 rootNode 上调用 getValue 方法 - value := rootNode.getValue(requestPath, &c.Params, &skippedNodes, true) // unescape=true 对路径参数进行 URL 解码 + value := rootNode.getValue(requestPath, &c.Params, &c.SkippedNodes, true) // unescape=true 对路径参数进行 URL 解码 if value.handlers != nil { //c.handlers = engine.combineHandlers(engine.globalHandlers, value.handlers) // 组合全局中间件和路由处理函数 diff --git a/fileserver.go b/fileserver.go index 5b7f248..1aa1aaf 100644 --- a/fileserver.go +++ b/fileserver.go @@ -6,7 +6,6 @@ package touka import ( "errors" - "fmt" "net/http" "path" "strings" @@ -19,13 +18,19 @@ var allowedFileServerMethods = map[string]struct{}{ http.MethodHead: {}, } +var ( + ErrInputFSisNil = errors.New("input FS is nil") + ErrMethodNotAllowed = errors.New("method not allowed") +) + // FileServer方式, 返回一个HandleFunc, 统一化处理 func FileServer(fs http.FileSystem) HandlerFunc { if fs == nil { return func(c *Context) { - c.ErrorUseHandle(500, errors.New("Input FileSystem is nil")) + c.ErrorUseHandle(http.StatusInternalServerError, ErrInputFSisNil) } } + fileServerInstance := http.FileServer(fs) return func(c *Context) { FileServerHandleServe(c, fileServerInstance) @@ -37,7 +42,6 @@ func FileServer(fs http.FileSystem) HandlerFunc { func FileServerHandleServe(c *Context, fsHandle http.Handler) { if fsHandle == nil { - ErrInputFSisNil := errors.New("Input FileSystem Handle is nil") c.AddError(ErrInputFSisNil) // 500 c.ErrorUseHandle(http.StatusInternalServerError, ErrInputFSisNil) @@ -59,7 +63,7 @@ func FileServerHandleServe(c *Context, fsHandle http.Handler) { return } else { // 否则,返回 405 Method Not Allowed - c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, fmt.Errorf("Method %s is Not Allowed on FileServer", c.Request.Method)) + c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, ErrMethodNotAllowed) } } else { c.Next() @@ -240,7 +244,7 @@ func (engine *Engine) StaticFS(relativePath string, fs http.FileSystem) { relativePath += "/" } - fileServer := http.FileServer(fs) + fileServer := http.StripPrefix(relativePath, http.FileServer(fs)) engine.ANY(relativePath+"*filepath", GetStaticFSHandleFunc(fileServer)) } @@ -254,7 +258,7 @@ func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem) { relativePath += "/" } - fileServer := http.FileServer(fs) + fileServer := http.StripPrefix(relativePath, http.FileServer(fs)) group.ANY(relativePath+"*filepath", GetStaticFSHandleFunc(fileServer)) } diff --git a/go.mod b/go.mod index e9d0304..f9d10a9 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,16 @@ module github.com/infinite-iroha/touka -go 1.24.5 +go 1.25.1 require ( github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 github.com/WJQSERVER-STUDIO/httpc v0.8.2 + github.com/WJQSERVER/wanf v0.0.3 github.com/fenthope/reco v0.0.4 - github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2 + github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e ) require ( github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/net v0.42.0 // indirect + golang.org/x/net v0.49.0 // indirect ) diff --git a/go.sum b/go.sum index d9a63e3..b75fec4 100644 --- a/go.sum +++ b/go.sum @@ -2,11 +2,13 @@ github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbe github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok= github.com/WJQSERVER-STUDIO/httpc v0.8.2 h1:PFPLodV0QAfGEP6915J57vIqoKu9cGuuiXG/7C9TNUk= github.com/WJQSERVER-STUDIO/httpc v0.8.2/go.mod h1:8WhHVRO+olDFBSvL5PC/bdMkb6U3vRdPJ4p4pnguV5Y= +github.com/WJQSERVER/wanf v0.0.3 h1:OqhG7ETiR5Knqr0lmbb+iUMw9O7re2vEogjVf06QevM= +github.com/WJQSERVER/wanf v0.0.3/go.mod h1:q2Pyg+G+s1acMWxrbI4CwS/Yk76/BzLREEdZ8iFwUNE= github.com/fenthope/reco v0.0.4 h1:yo2g3aWwdoMpaZWZX4SdZOW7mCK82RQIU/YI8ZUQThM= github.com/fenthope/reco v0.0.4/go.mod h1:eMyS8HpdMVdJ/2WJt6Cvt8P1EH9Igzj5lSJrgc+0jeg= -github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2 h1:iizUGZ9pEquQS5jTGkh4AqeeHCMbfbjeb0zMt0aEFzs= -github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M= +github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU= +github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= -golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= diff --git a/sse.go b/sse.go new file mode 100644 index 0000000..3b98800 --- /dev/null +++ b/sse.go @@ -0,0 +1,184 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2025 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "bytes" + "io" + "net/http" + "strings" +) + +// Event 代表一个服务器发送事件(SSE). +type Event struct { + // Event 是事件的名称. + Event string + // Data 是事件的内容, 可以是多行文本. + Data string + // Id 是事件的唯一标识符. + Id string + // Retry 是指定客户端在连接丢失后应等待多少毫秒后尝试重新连接. + Retry string +} + +// Render 将事件格式化并写入给定的 writer. +// 通过逐行处理数据, 此方法可防止因数据中包含换行符而导致的CRLF注入问题. +// 为了性能, 它使用 bytes.Buffer 并通过 WriteTo 直接写入, 以避免不必要的内存分配. +func (e *Event) Render(w io.Writer) error { + var buf bytes.Buffer + + if len(e.Id) > 0 { + buf.WriteString("id: ") + buf.WriteString(e.Id) + buf.WriteString("\n") + } + if len(e.Event) > 0 { + buf.WriteString("event: ") + buf.WriteString(e.Event) + buf.WriteString("\n") + } + if len(e.Data) > 0 { + lines := strings.Split(e.Data, "\n") + for _, line := range lines { + buf.WriteString("data: ") + buf.WriteString(line) + buf.WriteString("\n") + } + } + if len(e.Retry) > 0 { + buf.WriteString("retry: ") + buf.WriteString(e.Retry) + buf.WriteString("\n") + } + + // 每个事件都以一个额外的换行符结尾. + buf.WriteString("\n") + + // 直接将 buffer 的内容写入 writer, 避免生成中间字符串. + _, err := buf.WriteTo(w) + return err +} + +// EventStream 启动一个 SSE 事件流. +// 这是推荐的、更简单安全的方式, 采用阻塞和回调的设计, 框架负责管理连接生命周期. +// +// 详细用法: +// +// r.GET("/sse/callback", func(c *touka.Context) { +// // streamer 回调函数会在一个循环中被调用. +// c.EventStream(func(w io.Writer) bool { +// event := touka.Event{ +// Event: "time-tick", +// Data: time.Now().Format(time.RFC1123), +// } +// +// if err := event.Render(w); err != nil { +// // 发生写入错误, 停止发送. +// return false // 返回 false 结束事件流. +// } +// +// time.Sleep(2 * time.Second) +// return true // 返回 true 继续事件流. +// }) +// // 当事件流结束后(例如客户端关闭页面), 这行代码会被执行. +// fmt.Println("Client disconnected from /sse/callback") +// }) +func (c *Context) EventStream(streamer func(w io.Writer) bool) { + // 为现代网络协议优化头部. + c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + c.Writer.Header().Set("Cache-Control", "no-cache, no-transform") + c.Writer.Header().Del("Connection") + c.Writer.Header().Del("Transfer-Encoding") + + c.Writer.WriteHeader(http.StatusOK) + c.Writer.Flush() // 直接调用, ResponseWriter 接口保证了 Flush 方法的存在. + + for { + select { + case <-c.Request.Context().Done(): + return + default: + if !streamer(c.Writer) { + return + } + c.Writer.Flush() + } + } +} + +// EventStreamChan 返回用于 SSE 事件流的 channel. +// 这是为高级并发场景设计的、更灵活的API. +// +// 重要: +// - 调用者必须 close(eventChan) 来结束事件流. +// - 调用者必须在独立的 goroutine 中消费 errChan 来处理错误和连接断开. +// - 为防止 goroutine 泄漏, 建议发送方在 select 中同时监听 c.Request.Context().Done(). +// +// 详细用法: +// +// r.GET("/sse/channel", func(c *touka.Context) { +// eventChan, errChan := c.EventStreamChan() +// +// // 必须在独立的goroutine中处理错误和连接断开. +// go func() { +// if err := <-errChan; err != nil { +// c.Errorf("SSE channel error: %v", err) +// } +// }() +// +// // 在另一个goroutine中异步发送事件. +// go func() { +// // 重要: 必须在逻辑结束时关闭channel, 以通知框架. +// defer close(eventChan) +// +// for i := 1; i <= 5; i++ { +// select { +// case <-c.Request.Context().Done(): +// return // 客户端已断开, 退出 goroutine. +// default: +// eventChan <- touka.Event{ +// Id: fmt.Sprintf("%d", i), +// Data: "hello from channel", +// } +// time.Sleep(2 * time.Second) +// } +// } +// }() +// }) +func (c *Context) EventStreamChan() (chan<- Event, <-chan error) { + eventChan := make(chan Event) + errChan := make(chan error, 1) + + c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + c.Writer.Header().Set("Cache-Control", "no-cache, no-transform") + c.Writer.Header().Del("Connection") + c.Writer.Header().Del("Transfer-Encoding") + + c.Writer.WriteHeader(http.StatusOK) + c.Writer.Flush() + + go func() { + defer close(errChan) + + for { + select { + case event, ok := <-eventChan: + if !ok { + return + } + if err := event.Render(c.Writer); err != nil { + errChan <- err + return + } + c.Writer.Flush() + case <-c.Request.Context().Done(): + errChan <- c.Request.Context().Err() + return + } + } + }() + + return eventChan, errChan +} diff --git a/tree.go b/tree.go index 09711a1..31246a5 100644 --- a/tree.go +++ b/tree.go @@ -5,7 +5,6 @@ package touka import ( - "bytes" "net/url" "strings" "unicode" @@ -27,12 +26,6 @@ func BytesToString(b []byte) string { return unsafe.String(unsafe.SliceData(b), len(b)) } -var ( - strColon = []byte(":") // 定义字节切片常量, 表示冒号, 用于路径参数识别 - strStar = []byte("*") // 定义字节切片常量, 表示星号, 用于捕获所有路径识别 - strSlash = []byte("/") // 定义字节切片常量, 表示斜杠, 用于路径分隔符识别 -) - // Param 是单个 URL 参数, 由键和值组成. type Param struct { Key string // 参数的键名 @@ -106,17 +99,14 @@ func (n *node) addChild(child *node) { // countParams 计算路径中参数(冒号)和捕获所有(星号)的数量. func countParams(path string) uint16 { - var n uint16 - s := StringToBytes(path) // 将路径字符串转换为字节切片 - n += uint16(bytes.Count(s, strColon)) // 统计冒号的数量 - n += uint16(bytes.Count(s, strStar)) // 统计星号的数量 - return n + colons := strings.Count(path, ":") + stars := strings.Count(path, "*") + return uint16(colons + stars) } // countSections 计算路径中斜杠('/')的数量, 即路径段的数量. func countSections(path string) uint16 { - s := StringToBytes(path) // 将路径字符串转换为字节切片 - return uint16(bytes.Count(s, strSlash)) // 统计斜杠的数量 + return uint16(strings.Count(path, "/")) } // nodeType 定义了节点的类型. @@ -418,10 +408,10 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) fullPath: fullPath, // 设置完整路径 } - n.addChild(child) // 添加子节点 - n.indices = string('/') // 索引设置为 '/' - n = child // 移动到新创建的 catchAll 节点 - n.priority++ // 增加优先级 + n.addChild(child) // 添加子节点 + n.indices = "/" // 索引设置为 '/' + n = child // 移动到新创建的 catchAll 节点 + n.priority++ // 增加优先级 // 第二个节点: 包含变量的节点 child = &node{