mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-02-03 00:41:10 +08:00
Compare commits
19 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a6e278d458 | ||
|
|
7b536ac137 | ||
|
|
b348d7d41f | ||
|
|
60b2936eff | ||
|
|
9cfc82a347 | ||
|
|
904aea5df8 | ||
|
|
ee0ebc986c | ||
|
|
e4aaaa1583 | ||
|
|
1361f6e237 | ||
|
|
a6458cca16 | ||
|
|
76a89800a2 | ||
|
|
4955fb9d03 | ||
|
|
5b98310de5 | ||
|
|
f1ac0dd6ff | ||
|
|
38ff5126e3 | ||
|
|
b4e073ae2f | ||
|
|
af0a99acda | ||
|
|
3ffde5742c | ||
|
|
016df0efe4 |
7 changed files with 428 additions and 56 deletions
193
context.go
193
context.go
|
|
@ -18,11 +18,12 @@ import (
|
|||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/WJQSERVER/wanf"
|
||||
"github.com/fenthope/reco"
|
||||
"github.com/go-json-experiment/json"
|
||||
|
||||
|
|
@ -42,7 +43,7 @@ type Context struct {
|
|||
index int8 // 当前执行到处理链的哪个位置
|
||||
|
||||
mu sync.RWMutex
|
||||
Keys map[string]interface{} // 用于在中间件之间传递数据
|
||||
Keys map[string]any // 用于在中间件之间传递数据
|
||||
|
||||
Errors []error // 用于收集处理过程中的错误
|
||||
|
||||
|
|
@ -64,6 +65,10 @@ type Context struct {
|
|||
|
||||
// 请求体Body大小限制
|
||||
MaxRequestBodySize int64
|
||||
|
||||
// skippedNodes 用于记录跳过的节点信息,以便回溯
|
||||
// 通常在处理嵌套路由时使用
|
||||
SkippedNodes []skippedNode
|
||||
}
|
||||
|
||||
// --- Context 相关方法实现 ---
|
||||
|
|
@ -77,20 +82,30 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
|
|||
} else {
|
||||
c.Writer = newResponseWriter(w)
|
||||
}
|
||||
//c.Writer = newResponseWriter(w)
|
||||
|
||||
c.Request = req
|
||||
c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组
|
||||
//c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组
|
||||
//避免params长度为0
|
||||
if cap(c.Params) > 0 {
|
||||
c.Params = c.Params[:0]
|
||||
} else {
|
||||
c.Params = make(Params, 0, c.engine.maxParams)
|
||||
}
|
||||
c.handlers = nil
|
||||
c.index = -1 // 初始为 -1,`Next()` 将其设置为 0
|
||||
c.Keys = make(map[string]interface{}) // 每次请求重新创建 map,避免数据污染
|
||||
c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染
|
||||
c.Errors = c.Errors[:0] // 清空 Errors 切片
|
||||
c.queryCache = nil // 清空查询参数缓存
|
||||
c.formCache = nil // 清空表单数据缓存
|
||||
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
|
||||
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
|
||||
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
|
||||
// c.HTTPClient 和 c.engine 保持不变,它们引用 Engine 实例的成员
|
||||
|
||||
if cap(c.SkippedNodes) > 0 {
|
||||
c.SkippedNodes = c.SkippedNodes[:0]
|
||||
} else {
|
||||
c.SkippedNodes = make([]skippedNode, 0, 256)
|
||||
}
|
||||
}
|
||||
|
||||
// Next 在处理链中执行下一个处理函数
|
||||
|
|
@ -122,10 +137,10 @@ func (c *Context) AbortWithStatus(code int) {
|
|||
|
||||
// Set 将一个键值对存储到 Context 中
|
||||
// 这是一个线程安全的操作,用于在中间件之间传递数据
|
||||
func (c *Context) Set(key string, value interface{}) {
|
||||
func (c *Context) Set(key string, value any) {
|
||||
c.mu.Lock() // 加写锁
|
||||
if c.Keys == nil {
|
||||
c.Keys = make(map[string]interface{})
|
||||
c.Keys = make(map[string]any)
|
||||
}
|
||||
c.Keys[key] = value
|
||||
c.mu.Unlock() // 解写锁
|
||||
|
|
@ -133,7 +148,7 @@ func (c *Context) Set(key string, value interface{}) {
|
|||
|
||||
// Get 从 Context 中获取一个值
|
||||
// 这是一个线程安全的操作
|
||||
func (c *Context) Get(key string) (value interface{}, exists bool) {
|
||||
func (c *Context) Get(key string) (value any, exists bool) {
|
||||
c.mu.RLock() // 加读锁
|
||||
value, exists = c.Keys[key]
|
||||
c.mu.RUnlock() // 解读锁
|
||||
|
|
@ -208,7 +223,7 @@ func (c *Context) GetDuration(key string) (value time.Duration, exists bool) {
|
|||
|
||||
// MustGet 从 Context 中获取一个值,如果不存在则 panic
|
||||
// 适用于确定值一定存在的场景
|
||||
func (c *Context) MustGet(key string) interface{} {
|
||||
func (c *Context) MustGet(key string) any {
|
||||
if value, exists := c.Get(key); exists {
|
||||
return value
|
||||
}
|
||||
|
|
@ -269,7 +284,7 @@ func (c *Context) Raw(code int, contentType string, data []byte) {
|
|||
}
|
||||
|
||||
// String 向响应写入格式化的字符串
|
||||
func (c *Context) String(code int, format string, values ...interface{}) {
|
||||
func (c *Context) String(code int, format string, values ...any) {
|
||||
c.Writer.WriteHeader(code)
|
||||
c.Writer.Write([]byte(fmt.Sprintf(format, values...)))
|
||||
}
|
||||
|
|
@ -281,13 +296,121 @@ func (c *Context) Text(code int, text string) {
|
|||
c.Writer.Write([]byte(text))
|
||||
}
|
||||
|
||||
// FileText
|
||||
func (c *Context) FileText(code int, filePath string) {
|
||||
// 清理path
|
||||
cleanPath := filepath.Clean(filePath)
|
||||
if !filepath.IsAbs(cleanPath) {
|
||||
c.AddError(fmt.Errorf("relative path not allowed: %s", cleanPath))
|
||||
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("relative path not allowed"))
|
||||
return
|
||||
}
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(cleanPath); os.IsNotExist(err) {
|
||||
c.AddError(fmt.Errorf("file not found: %s", cleanPath))
|
||||
c.ErrorUseHandle(http.StatusNotFound, fmt.Errorf("file not found"))
|
||||
return
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
file, err := os.Open(cleanPath)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to open file %s: %w", cleanPath, err))
|
||||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to open file: %w", err))
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 获取文件信息以获取文件大小
|
||||
fileInfo, err := file.Stat()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to get file info for %s: %w", cleanPath, err))
|
||||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to get file info: %w", err))
|
||||
return
|
||||
}
|
||||
// 判断是否是dir
|
||||
if fileInfo.IsDir() {
|
||||
c.AddError(fmt.Errorf("path is a directory, not a file: %s", cleanPath))
|
||||
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path is a directory"))
|
||||
return
|
||||
}
|
||||
|
||||
c.SetHeader("Content-Type", "text/plain; charset=utf-8")
|
||||
|
||||
c.SetBodyStream(file, int(fileInfo.Size()))
|
||||
}
|
||||
|
||||
/*
|
||||
// not fot work
|
||||
// FileTextSafeDir
|
||||
func (c *Context) FileTextSafeDir(code int, filePath string, safeDir string) {
|
||||
|
||||
// 清理path
|
||||
cleanPath := path.Clean(filePath)
|
||||
if !filepath.IsAbs(cleanPath) {
|
||||
c.AddError(fmt.Errorf("relative path not allowed: %s", cleanPath))
|
||||
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("relative path not allowed"))
|
||||
return
|
||||
}
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
c.AddError(fmt.Errorf("path traversal attempt detected: %s", cleanPath))
|
||||
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path traversal attempt detected"))
|
||||
return
|
||||
}
|
||||
|
||||
// 判断filePath是否包含在safeDir内, 防止路径穿越
|
||||
relPath, err := filepath.Rel(safeDir, cleanPath)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to get relative path: %w", err))
|
||||
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("failed to get relative path: %w", err))
|
||||
return
|
||||
}
|
||||
cleanPath = filepath.Join(safeDir, relPath)
|
||||
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(cleanPath); os.IsNotExist(err) {
|
||||
c.AddError(fmt.Errorf("file not found: %s", cleanPath))
|
||||
c.ErrorUseHandle(http.StatusNotFound, fmt.Errorf("file not found"))
|
||||
return
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
file, err := os.Open(cleanPath)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to open file %s: %w", cleanPath, err))
|
||||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to open file: %w", err))
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 获取文件信息以获取文件大小
|
||||
fileInfo, err := file.Stat()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to get file info for %s: %w", cleanPath, err))
|
||||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to get file info: %w", err))
|
||||
return
|
||||
}
|
||||
// 判断是否是dir
|
||||
if fileInfo.IsDir() {
|
||||
c.AddError(fmt.Errorf("path is a directory, not a file: %s", cleanPath))
|
||||
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path is a directory"))
|
||||
return
|
||||
}
|
||||
|
||||
c.SetHeader("Content-Type", "text/plain; charset=utf-8")
|
||||
|
||||
c.SetBodyStream(file, int(fileInfo.Size()))
|
||||
}
|
||||
*/
|
||||
|
||||
// JSON 向响应写入 JSON 数据
|
||||
// 设置 Content-Type 为 application/json
|
||||
func (c *Context) JSON(code int, obj interface{}) {
|
||||
func (c *Context) JSON(code int, obj any) {
|
||||
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
c.Writer.WriteHeader(code)
|
||||
if err := json.MarshalWrite(c.Writer, obj); err != nil {
|
||||
c.AddError(fmt.Errorf("failed to marshal JSON: %w", err))
|
||||
c.Errorf("failed to marshal JSON: %s", err)
|
||||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to marshal JSON: %w", err))
|
||||
return
|
||||
}
|
||||
|
|
@ -295,7 +418,7 @@ func (c *Context) JSON(code int, obj interface{}) {
|
|||
|
||||
// GOB 向响应写入GOB数据
|
||||
// 设置 Content-Type 为 application/octet-stream
|
||||
func (c *Context) GOB(code int, obj interface{}) {
|
||||
func (c *Context) GOB(code int, obj any) {
|
||||
c.Writer.Header().Set("Content-Type", "application/octet-stream") // 设置合适的 Content-Type
|
||||
c.Writer.WriteHeader(code)
|
||||
// GOB 编码
|
||||
|
|
@ -307,11 +430,25 @@ func (c *Context) GOB(code int, obj interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
// WANF向响应写入WANF数据
|
||||
// 设置 application/vnd.wjqserver.wanf; charset=utf-8
|
||||
func (c *Context) WANF(code int, obj any) {
|
||||
c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8")
|
||||
c.Writer.WriteHeader(code)
|
||||
// WANF 编码
|
||||
encoder := wanf.NewStreamEncoder(c.Writer)
|
||||
if err := encoder.Encode(obj); err != nil {
|
||||
c.AddError(fmt.Errorf("failed to encode WANF: %w", err))
|
||||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to encode WANF: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// HTML 渲染 HTML 模板
|
||||
// 如果 Engine 配置了 HTMLRender,则使用它进行渲染
|
||||
// 否则,会进行简单的字符串输出
|
||||
// 预留接口,可以扩展为支持多种模板引擎
|
||||
func (c *Context) HTML(code int, name string, obj interface{}) {
|
||||
func (c *Context) HTML(code int, name string, obj any) {
|
||||
c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
c.Writer.WriteHeader(code)
|
||||
|
||||
|
|
@ -342,7 +479,7 @@ func (c *Context) Redirect(code int, location string) {
|
|||
}
|
||||
|
||||
// ShouldBindJSON 尝试将请求体绑定到 JSON 对象
|
||||
func (c *Context) ShouldBindJSON(obj interface{}) error {
|
||||
func (c *Context) ShouldBindJSON(obj any) error {
|
||||
if c.Request.Body == nil {
|
||||
return errors.New("request body is empty")
|
||||
}
|
||||
|
|
@ -353,10 +490,28 @@ func (c *Context) ShouldBindJSON(obj interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// ShouldBindWANF
|
||||
func (c *Context) ShouldBindWANF(obj any) error {
|
||||
if c.Request.Body == nil {
|
||||
return errors.New("request body is empty")
|
||||
}
|
||||
decoder, err := wanf.NewStreamDecoder(c.Request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create WANF decoder: %w", err)
|
||||
}
|
||||
|
||||
if err := decoder.Decode(obj); err != nil {
|
||||
return fmt.Errorf("WANF binding error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deprecated: This function is a reserved placeholder for future API extensions
|
||||
// and is not yet implemented. It will either be properly defined or removed in v2.0.0. Do not use.
|
||||
// ShouldBind 尝试将请求体绑定到各种类型(JSON, Form, XML 等)
|
||||
// 这是一个复杂的通用绑定接口,通常根据 Content-Type 或其他头部来判断绑定方式
|
||||
// 预留接口,可根据项目需求进行扩展
|
||||
func (c *Context) ShouldBind(obj interface{}) error {
|
||||
func (c *Context) ShouldBind(obj any) error {
|
||||
// TODO: 完整的通用绑定逻辑
|
||||
// 可以根据 c.Request.Header.Get("Content-Type") 来判断是 JSON, Form, XML 等
|
||||
// 例如:
|
||||
|
|
@ -409,7 +564,7 @@ func (c *Context) Err() error {
|
|||
// Value returns the value associated with this context for key, or nil if no
|
||||
// value is associated with key.
|
||||
// 可以用于从 Context 中获取与特定键关联的值,包括 Go 原生 Context 的值和 Touka Context 的 Keys
|
||||
func (c *Context) Value(key interface{}) interface{} {
|
||||
func (c *Context) Value(key any) any {
|
||||
if keyAsString, ok := key.(string); ok {
|
||||
if val, exists := c.Get(keyAsString); exists {
|
||||
return val
|
||||
|
|
@ -724,7 +879,7 @@ func (c *Context) GetRequestURIPath() string {
|
|||
// 将文件内容作为响应body
|
||||
func (c *Context) SetRespBodyFile(code int, filePath string) {
|
||||
// 清理path
|
||||
cleanPath := path.Clean(filePath)
|
||||
cleanPath := filepath.Clean(filePath)
|
||||
|
||||
// 打开文件
|
||||
file, err := os.Open(cleanPath)
|
||||
|
|
@ -744,7 +899,7 @@ func (c *Context) SetRespBodyFile(code int, filePath string) {
|
|||
}
|
||||
|
||||
// 尝试根据文件扩展名猜测 Content-Type
|
||||
contentType := mime.TypeByExtension(path.Ext(cleanPath))
|
||||
contentType := mime.TypeByExtension(filepath.Ext(cleanPath))
|
||||
if contentType == "" {
|
||||
// 如果无法猜测,则使用默认的二进制流类型
|
||||
contentType = "application/octet-stream"
|
||||
|
|
|
|||
48
engine.go
48
engine.go
|
|
@ -421,6 +421,41 @@ func getHandlerName(h HandlerFunc) string {
|
|||
|
||||
}
|
||||
|
||||
const MaxSkippedNodesCap = 256
|
||||
|
||||
// TempSkippedNodesPool 存储 *[]skippedNode 以复用内存
|
||||
var TempSkippedNodesPool = sync.Pool{
|
||||
New: func() any {
|
||||
// 返回一个指向容量为 256 的新切片的指针
|
||||
s := make([]skippedNode, 0, MaxSkippedNodesCap)
|
||||
return &s
|
||||
},
|
||||
}
|
||||
|
||||
// GetTempSkippedNodes 从 Pool 中获取一个 *[]skippedNode 指针
|
||||
func GetTempSkippedNodes() *[]skippedNode {
|
||||
// 直接返回 Pool 中存储的指针
|
||||
return TempSkippedNodesPool.Get().(*[]skippedNode)
|
||||
}
|
||||
|
||||
// PutTempSkippedNodes 将用完的 *[]skippedNode 指针放回 Pool
|
||||
func PutTempSkippedNodes(skippedNodes *[]skippedNode) {
|
||||
if skippedNodes == nil || *skippedNodes == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查容量是否符合预期。如果容量不足,则丢弃,不放回 Pool。
|
||||
if cap(*skippedNodes) < MaxSkippedNodesCap {
|
||||
return // 丢弃该对象,让 Pool 在下次 Get 时通过 New 重新分配
|
||||
}
|
||||
|
||||
// 长度重置为 0,保留容量,实现复用
|
||||
*skippedNodes = (*skippedNodes)[:0]
|
||||
|
||||
// 将指针存回 Pool
|
||||
TempSkippedNodesPool.Put(skippedNodes)
|
||||
}
|
||||
|
||||
// 405中间件
|
||||
func MethodNotAllowed() HandlerFunc {
|
||||
return func(c *Context) {
|
||||
|
|
@ -432,9 +467,10 @@ func MethodNotAllowed() HandlerFunc {
|
|||
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
|
||||
allowedMethods := []string{}
|
||||
for _, treeIter := range engine.methodTrees {
|
||||
var tempSkippedNodes []skippedNode
|
||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||
value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false)
|
||||
tempSkippedNodes := GetTempSkippedNodes()
|
||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
|
||||
PutTempSkippedNodes(tempSkippedNodes)
|
||||
if value.handlers != nil {
|
||||
allowedMethods = append(allowedMethods, treeIter.method)
|
||||
}
|
||||
|
|
@ -451,9 +487,10 @@ func MethodNotAllowed() HandlerFunc {
|
|||
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
|
||||
continue
|
||||
}
|
||||
var tempSkippedNodes []skippedNode // 用于临时查找,不影响主 Context
|
||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||
value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) // 只查找是否存在,不需要参数
|
||||
tempSkippedNodes := GetTempSkippedNodes()
|
||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
|
||||
PutTempSkippedNodes(tempSkippedNodes)
|
||||
if value.handlers != nil {
|
||||
// 使用定义的ErrorHandle处理
|
||||
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
|
||||
|
|
@ -661,9 +698,8 @@ func (engine *Engine) handleRequest(c *Context) {
|
|||
// 查找匹配的节点和处理函数
|
||||
// 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量
|
||||
// skippedNodes 内部使用,因此无需从外部传入已分配的 slice
|
||||
var skippedNodes []skippedNode // 用于回溯的跳过节点
|
||||
// 直接在 rootNode 上调用 getValue 方法
|
||||
value := rootNode.getValue(requestPath, &c.Params, &skippedNodes, true) // unescape=true 对路径参数进行 URL 解码
|
||||
value := rootNode.getValue(requestPath, &c.Params, &c.SkippedNodes, true) // unescape=true 对路径参数进行 URL 解码
|
||||
|
||||
if value.handlers != nil {
|
||||
//c.handlers = engine.combineHandlers(engine.globalHandlers, value.handlers) // 组合全局中间件和路由处理函数
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ package touka
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
|
@ -19,13 +18,19 @@ var allowedFileServerMethods = map[string]struct{}{
|
|||
http.MethodHead: {},
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInputFSisNil = errors.New("input FS is nil")
|
||||
ErrMethodNotAllowed = errors.New("method not allowed")
|
||||
)
|
||||
|
||||
// FileServer方式, 返回一个HandleFunc, 统一化处理
|
||||
func FileServer(fs http.FileSystem) HandlerFunc {
|
||||
if fs == nil {
|
||||
return func(c *Context) {
|
||||
c.ErrorUseHandle(500, errors.New("Input FileSystem is nil"))
|
||||
c.ErrorUseHandle(http.StatusInternalServerError, ErrInputFSisNil)
|
||||
}
|
||||
}
|
||||
|
||||
fileServerInstance := http.FileServer(fs)
|
||||
return func(c *Context) {
|
||||
FileServerHandleServe(c, fileServerInstance)
|
||||
|
|
@ -37,7 +42,6 @@ func FileServer(fs http.FileSystem) HandlerFunc {
|
|||
|
||||
func FileServerHandleServe(c *Context, fsHandle http.Handler) {
|
||||
if fsHandle == nil {
|
||||
ErrInputFSisNil := errors.New("Input FileSystem Handle is nil")
|
||||
c.AddError(ErrInputFSisNil)
|
||||
// 500
|
||||
c.ErrorUseHandle(http.StatusInternalServerError, ErrInputFSisNil)
|
||||
|
|
@ -59,7 +63,7 @@ func FileServerHandleServe(c *Context, fsHandle http.Handler) {
|
|||
return
|
||||
} else {
|
||||
// 否则,返回 405 Method Not Allowed
|
||||
c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, fmt.Errorf("Method %s is Not Allowed on FileServer", c.Request.Method))
|
||||
c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, ErrMethodNotAllowed)
|
||||
}
|
||||
} else {
|
||||
c.Next()
|
||||
|
|
@ -240,7 +244,7 @@ func (engine *Engine) StaticFS(relativePath string, fs http.FileSystem) {
|
|||
relativePath += "/"
|
||||
}
|
||||
|
||||
fileServer := http.FileServer(fs)
|
||||
fileServer := http.StripPrefix(relativePath, http.FileServer(fs))
|
||||
engine.ANY(relativePath+"*filepath", GetStaticFSHandleFunc(fileServer))
|
||||
}
|
||||
|
||||
|
|
@ -254,7 +258,7 @@ func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem) {
|
|||
relativePath += "/"
|
||||
}
|
||||
|
||||
fileServer := http.FileServer(fs)
|
||||
fileServer := http.StripPrefix(relativePath, http.FileServer(fs))
|
||||
group.ANY(relativePath+"*filepath", GetStaticFSHandleFunc(fileServer))
|
||||
}
|
||||
|
||||
|
|
|
|||
7
go.mod
7
go.mod
|
|
@ -1,15 +1,16 @@
|
|||
module github.com/infinite-iroha/touka
|
||||
|
||||
go 1.24.5
|
||||
go 1.25.1
|
||||
|
||||
require (
|
||||
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2
|
||||
github.com/WJQSERVER-STUDIO/httpc v0.8.2
|
||||
github.com/WJQSERVER/wanf v0.0.3
|
||||
github.com/fenthope/reco v0.0.4
|
||||
github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2
|
||||
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
golang.org/x/net v0.42.0 // indirect
|
||||
golang.org/x/net v0.49.0 // indirect
|
||||
)
|
||||
|
|
|
|||
10
go.sum
10
go.sum
|
|
@ -2,11 +2,13 @@ github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbe
|
|||
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok=
|
||||
github.com/WJQSERVER-STUDIO/httpc v0.8.2 h1:PFPLodV0QAfGEP6915J57vIqoKu9cGuuiXG/7C9TNUk=
|
||||
github.com/WJQSERVER-STUDIO/httpc v0.8.2/go.mod h1:8WhHVRO+olDFBSvL5PC/bdMkb6U3vRdPJ4p4pnguV5Y=
|
||||
github.com/WJQSERVER/wanf v0.0.3 h1:OqhG7ETiR5Knqr0lmbb+iUMw9O7re2vEogjVf06QevM=
|
||||
github.com/WJQSERVER/wanf v0.0.3/go.mod h1:q2Pyg+G+s1acMWxrbI4CwS/Yk76/BzLREEdZ8iFwUNE=
|
||||
github.com/fenthope/reco v0.0.4 h1:yo2g3aWwdoMpaZWZX4SdZOW7mCK82RQIU/YI8ZUQThM=
|
||||
github.com/fenthope/reco v0.0.4/go.mod h1:eMyS8HpdMVdJ/2WJt6Cvt8P1EH9Igzj5lSJrgc+0jeg=
|
||||
github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2 h1:iizUGZ9pEquQS5jTGkh4AqeeHCMbfbjeb0zMt0aEFzs=
|
||||
github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M=
|
||||
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU=
|
||||
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
||||
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||
|
|
|
|||
184
sse.go
Normal file
184
sse.go
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
// Copyright 2025 WJQSERVER. All rights reserved.
|
||||
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||
package touka
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Event 代表一个服务器发送事件(SSE).
|
||||
type Event struct {
|
||||
// Event 是事件的名称.
|
||||
Event string
|
||||
// Data 是事件的内容, 可以是多行文本.
|
||||
Data string
|
||||
// Id 是事件的唯一标识符.
|
||||
Id string
|
||||
// Retry 是指定客户端在连接丢失后应等待多少毫秒后尝试重新连接.
|
||||
Retry string
|
||||
}
|
||||
|
||||
// Render 将事件格式化并写入给定的 writer.
|
||||
// 通过逐行处理数据, 此方法可防止因数据中包含换行符而导致的CRLF注入问题.
|
||||
// 为了性能, 它使用 bytes.Buffer 并通过 WriteTo 直接写入, 以避免不必要的内存分配.
|
||||
func (e *Event) Render(w io.Writer) error {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if len(e.Id) > 0 {
|
||||
buf.WriteString("id: ")
|
||||
buf.WriteString(e.Id)
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
if len(e.Event) > 0 {
|
||||
buf.WriteString("event: ")
|
||||
buf.WriteString(e.Event)
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
if len(e.Data) > 0 {
|
||||
lines := strings.Split(e.Data, "\n")
|
||||
for _, line := range lines {
|
||||
buf.WriteString("data: ")
|
||||
buf.WriteString(line)
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
}
|
||||
if len(e.Retry) > 0 {
|
||||
buf.WriteString("retry: ")
|
||||
buf.WriteString(e.Retry)
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
|
||||
// 每个事件都以一个额外的换行符结尾.
|
||||
buf.WriteString("\n")
|
||||
|
||||
// 直接将 buffer 的内容写入 writer, 避免生成中间字符串.
|
||||
_, err := buf.WriteTo(w)
|
||||
return err
|
||||
}
|
||||
|
||||
// EventStream 启动一个 SSE 事件流.
|
||||
// 这是推荐的、更简单安全的方式, 采用阻塞和回调的设计, 框架负责管理连接生命周期.
|
||||
//
|
||||
// 详细用法:
|
||||
//
|
||||
// r.GET("/sse/callback", func(c *touka.Context) {
|
||||
// // streamer 回调函数会在一个循环中被调用.
|
||||
// c.EventStream(func(w io.Writer) bool {
|
||||
// event := touka.Event{
|
||||
// Event: "time-tick",
|
||||
// Data: time.Now().Format(time.RFC1123),
|
||||
// }
|
||||
//
|
||||
// if err := event.Render(w); err != nil {
|
||||
// // 发生写入错误, 停止发送.
|
||||
// return false // 返回 false 结束事件流.
|
||||
// }
|
||||
//
|
||||
// time.Sleep(2 * time.Second)
|
||||
// return true // 返回 true 继续事件流.
|
||||
// })
|
||||
// // 当事件流结束后(例如客户端关闭页面), 这行代码会被执行.
|
||||
// fmt.Println("Client disconnected from /sse/callback")
|
||||
// })
|
||||
func (c *Context) EventStream(streamer func(w io.Writer) bool) {
|
||||
// 为现代网络协议优化头部.
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
c.Writer.Header().Del("Connection")
|
||||
c.Writer.Header().Del("Transfer-Encoding")
|
||||
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
c.Writer.Flush() // 直接调用, ResponseWriter 接口保证了 Flush 方法的存在.
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
default:
|
||||
if !streamer(c.Writer) {
|
||||
return
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EventStreamChan 返回用于 SSE 事件流的 channel.
|
||||
// 这是为高级并发场景设计的、更灵活的API.
|
||||
//
|
||||
// 重要:
|
||||
// - 调用者必须 close(eventChan) 来结束事件流.
|
||||
// - 调用者必须在独立的 goroutine 中消费 errChan 来处理错误和连接断开.
|
||||
// - 为防止 goroutine 泄漏, 建议发送方在 select 中同时监听 c.Request.Context().Done().
|
||||
//
|
||||
// 详细用法:
|
||||
//
|
||||
// r.GET("/sse/channel", func(c *touka.Context) {
|
||||
// eventChan, errChan := c.EventStreamChan()
|
||||
//
|
||||
// // 必须在独立的goroutine中处理错误和连接断开.
|
||||
// go func() {
|
||||
// if err := <-errChan; err != nil {
|
||||
// c.Errorf("SSE channel error: %v", err)
|
||||
// }
|
||||
// }()
|
||||
//
|
||||
// // 在另一个goroutine中异步发送事件.
|
||||
// go func() {
|
||||
// // 重要: 必须在逻辑结束时关闭channel, 以通知框架.
|
||||
// defer close(eventChan)
|
||||
//
|
||||
// for i := 1; i <= 5; i++ {
|
||||
// select {
|
||||
// case <-c.Request.Context().Done():
|
||||
// return // 客户端已断开, 退出 goroutine.
|
||||
// default:
|
||||
// eventChan <- touka.Event{
|
||||
// Id: fmt.Sprintf("%d", i),
|
||||
// Data: "hello from channel",
|
||||
// }
|
||||
// time.Sleep(2 * time.Second)
|
||||
// }
|
||||
// }
|
||||
// }()
|
||||
// })
|
||||
func (c *Context) EventStreamChan() (chan<- Event, <-chan error) {
|
||||
eventChan := make(chan Event)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
c.Writer.Header().Del("Connection")
|
||||
c.Writer.Header().Del("Transfer-Encoding")
|
||||
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
c.Writer.Flush()
|
||||
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-eventChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := event.Render(c.Writer); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
c.Writer.Flush()
|
||||
case <-c.Request.Context().Done():
|
||||
errChan <- c.Request.Context().Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return eventChan, errChan
|
||||
}
|
||||
20
tree.go
20
tree.go
|
|
@ -5,7 +5,6 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/url"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
|
@ -27,12 +26,6 @@ func BytesToString(b []byte) string {
|
|||
return unsafe.String(unsafe.SliceData(b), len(b))
|
||||
}
|
||||
|
||||
var (
|
||||
strColon = []byte(":") // 定义字节切片常量, 表示冒号, 用于路径参数识别
|
||||
strStar = []byte("*") // 定义字节切片常量, 表示星号, 用于捕获所有路径识别
|
||||
strSlash = []byte("/") // 定义字节切片常量, 表示斜杠, 用于路径分隔符识别
|
||||
)
|
||||
|
||||
// Param 是单个 URL 参数, 由键和值组成.
|
||||
type Param struct {
|
||||
Key string // 参数的键名
|
||||
|
|
@ -106,17 +99,14 @@ func (n *node) addChild(child *node) {
|
|||
|
||||
// countParams 计算路径中参数(冒号)和捕获所有(星号)的数量.
|
||||
func countParams(path string) uint16 {
|
||||
var n uint16
|
||||
s := StringToBytes(path) // 将路径字符串转换为字节切片
|
||||
n += uint16(bytes.Count(s, strColon)) // 统计冒号的数量
|
||||
n += uint16(bytes.Count(s, strStar)) // 统计星号的数量
|
||||
return n
|
||||
colons := strings.Count(path, ":")
|
||||
stars := strings.Count(path, "*")
|
||||
return uint16(colons + stars)
|
||||
}
|
||||
|
||||
// countSections 计算路径中斜杠('/')的数量, 即路径段的数量.
|
||||
func countSections(path string) uint16 {
|
||||
s := StringToBytes(path) // 将路径字符串转换为字节切片
|
||||
return uint16(bytes.Count(s, strSlash)) // 统计斜杠的数量
|
||||
return uint16(strings.Count(path, "/"))
|
||||
}
|
||||
|
||||
// nodeType 定义了节点的类型.
|
||||
|
|
@ -419,7 +409,7 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain)
|
|||
}
|
||||
|
||||
n.addChild(child) // 添加子节点
|
||||
n.indices = string('/') // 索引设置为 '/'
|
||||
n.indices = "/" // 索引设置为 '/'
|
||||
n = child // 移动到新创建的 catchAll 节点
|
||||
n.priority++ // 增加优先级
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue