mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-02-03 00:41:10 +08:00
Compare commits
3 commits
96154fff78
...
896182417f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
896182417f | ||
|
|
9a2aeef0d0 | ||
|
|
bb822599b9 |
5 changed files with 230 additions and 12 deletions
127
context.go
127
context.go
|
|
@ -1,6 +1,7 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
|
|
@ -14,6 +15,7 @@ import (
|
|||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fenthope/reco"
|
||||
"github.com/go-json-experiment/json"
|
||||
|
|
@ -128,6 +130,72 @@ func (c *Context) Get(key string) (value interface{}, exists bool) {
|
|||
return
|
||||
}
|
||||
|
||||
// GetString 从 Context 中获取一个字符串值
|
||||
// 这是一个线程安全的操作
|
||||
func (c *Context) GetString(key string) (value string, exists bool) {
|
||||
if val, exists := c.Get(key); exists {
|
||||
if str, ok := val.(string); ok {
|
||||
return str, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetInt 从 Context 中获取一个 int 值
|
||||
// 这是一个线程安全的操作
|
||||
func (c *Context) GetInt(key string) (value int, exists bool) {
|
||||
if val, exists := c.Get(key); exists {
|
||||
if i, ok := val.(int); ok {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetBool 从 Context 中获取一个 bool 值
|
||||
// 这是一个线程安全的操作
|
||||
func (c *Context) GetBool(key string) (value bool, exists bool) {
|
||||
if val, exists := c.Get(key); exists {
|
||||
if b, ok := val.(bool); ok {
|
||||
return b, true
|
||||
}
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
// GetFloat64 从 Context 中获取一个 float64 值
|
||||
// 这是一个线程安全的操作
|
||||
func (c *Context) GetFloat64(key string) (value float64, exists bool) {
|
||||
if val, exists := c.Get(key); exists {
|
||||
if f, ok := val.(float64); ok {
|
||||
return f, true
|
||||
}
|
||||
}
|
||||
return 0.0, false
|
||||
}
|
||||
|
||||
// GetTime 从 Context 中获取一个 time.Time 值
|
||||
// 这是一个线程安全的操作
|
||||
func (c *Context) GetTime(key string) (value time.Time, exists bool) {
|
||||
if val, exists := c.Get(key); exists {
|
||||
if t, ok := val.(time.Time); ok {
|
||||
return t, true
|
||||
}
|
||||
}
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
// GetDuration 从 Context 中获取一个 time.Duration 值
|
||||
// 这是一个线程安全的操作
|
||||
func (c *Context) GetDuration(key string) (value time.Duration, exists bool) {
|
||||
if val, exists := c.Get(key); exists {
|
||||
if d, ok := val.(time.Duration); ok {
|
||||
return d, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// MustGet 从 Context 中获取一个值,如果不存在则 panic
|
||||
// 适用于确定值一定存在的场景
|
||||
func (c *Context) MustGet(key string) interface{} {
|
||||
|
|
@ -369,7 +437,6 @@ func (c *Context) GetReqBody() io.ReadCloser {
|
|||
return c.Request.Body
|
||||
}
|
||||
|
||||
// GetReqBodyFull
|
||||
// GetReqBodyFull 读取并返回请求体的所有内容
|
||||
// 注意:请求体只能读取一次
|
||||
func (c *Context) GetReqBodyFull() ([]byte, error) {
|
||||
|
|
@ -385,6 +452,20 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
|
|||
return data, nil
|
||||
}
|
||||
|
||||
// 类似 GetReqBodyFull, 返回 *bytes.Buffer
|
||||
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
|
||||
if c.Request.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer c.Request.Body.Close() // 确保请求体被关闭
|
||||
data, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
}
|
||||
return bytes.NewBuffer(data), nil
|
||||
}
|
||||
|
||||
// RequestIP 返回客户端的 IP 地址
|
||||
// 它会根据 Engine 的配置 (ForwardByClientIP) 尝试从 X-Forwarded-For 或 X-Real-IP 等头部获取,
|
||||
// 否则回退到 Request.RemoteAddr
|
||||
|
|
@ -512,6 +593,50 @@ func (c *Context) GetLogger() *reco.Logger {
|
|||
return c.engine.LogReco
|
||||
}
|
||||
|
||||
// GetReqQueryString
|
||||
// GetReqQueryString 返回请求的原始查询字符串
|
||||
func (c *Context) GetReqQueryString() string {
|
||||
return c.Request.URL.RawQuery
|
||||
}
|
||||
|
||||
// SetBodyStream 设置响应体为一个 io.Reader,并指定内容长度
|
||||
// 如果 contentSize 为 -1,则表示内容长度未知,将使用 Transfer-Encoding: chunked
|
||||
func (c *Context) SetBodyStream(reader io.Reader, contentSize int) {
|
||||
// 如果指定了内容长度且大于等于 0,则设置 Content-Length 头部
|
||||
if contentSize >= 0 {
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", contentSize))
|
||||
} else {
|
||||
// 如果内容长度未知,移除 Content-Length 头部,通常会使用 Transfer-Encoding: chunked
|
||||
c.Writer.Header().Del("Content-Length")
|
||||
}
|
||||
|
||||
// 确保在写入数据前设置状态码
|
||||
if !c.Writer.Written() {
|
||||
c.Writer.WriteHeader(http.StatusOK) // 默认 200 OK
|
||||
}
|
||||
|
||||
// 将 reader 的内容直接复制到 ResponseWriter
|
||||
// ResponseWriter 实现了 io.Writer 接口
|
||||
_, err := copyb.Copy(c.Writer, reader)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to write stream: %w", err))
|
||||
// 注意:这里可能无法设置错误状态码,因为头部可能已经发送
|
||||
// 可以在调用 SetBodyStream 之前检查错误,或者在中间件中处理 Context.Errors
|
||||
}
|
||||
}
|
||||
|
||||
// GetRequestURI 返回请求的原始 URI
|
||||
func (c *Context) GetRequestURI() string {
|
||||
return c.Request.RequestURI
|
||||
}
|
||||
|
||||
// GetRequestURIPath 返回请求的原始 URI 的路径部分
|
||||
func (c *Context) GetRequestURIPath() string {
|
||||
return c.Request.URL.Path
|
||||
}
|
||||
|
||||
// == cookie ===
|
||||
|
||||
// SetSameSite 设置响应的 SameSite cookie 属性
|
||||
func (c *Context) SetSameSite(samesite http.SameSite) {
|
||||
c.sameSite = samesite
|
||||
|
|
|
|||
103
engine.go
103
engine.go
|
|
@ -3,6 +3,7 @@ package touka
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"runtime"
|
||||
|
|
@ -147,7 +148,7 @@ func New() *Engine {
|
|||
}
|
||||
//engine.SetProtocols(GetDefaultProtocolsConfig())
|
||||
engine.SetDefaultProtocols()
|
||||
engine.SetLogger(defaultLogRecoConfig)
|
||||
engine.SetLoggerCfg(defaultLogRecoConfig)
|
||||
// 初始化 Context Pool,为每个新 Context 实例提供一个构造函数
|
||||
engine.pool.New = func() interface{} {
|
||||
return &Context{
|
||||
|
|
@ -173,8 +174,13 @@ func Default() *Engine {
|
|||
|
||||
// === 外部操作方法 ===
|
||||
|
||||
// 配置日志Logger
|
||||
func (engine *Engine) SetLogger(logcfg reco.Config) {
|
||||
// SetLogger传入实例
|
||||
func (engine *Engine) SetLogger(logger *reco.Logger) {
|
||||
engine.LogReco = logger
|
||||
}
|
||||
|
||||
// 配置日志LoggerCfg
|
||||
func (engine *Engine) SetLoggerCfg(logcfg reco.Config) {
|
||||
engine.LogReco = NewLogger(logcfg)
|
||||
}
|
||||
|
||||
|
|
@ -659,8 +665,9 @@ func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IR
|
|||
|
||||
// == 其他操作方式 ===
|
||||
|
||||
// Static FileServer 传入一个文件夹路径, 使用FileServer进行处理
|
||||
func (engine *Engine) Static(relativePath, rootPath string) {
|
||||
// StaticDir 传入一个文件夹路径, 使用FileServer进行处理
|
||||
// r.StaticDir("/test/*filepath", "/var/www/test")
|
||||
func (engine *Engine) StaticDir(relativePath, rootPath string) {
|
||||
// 清理路径
|
||||
relativePath = path.Clean(relativePath)
|
||||
rootPath = path.Clean(rootPath)
|
||||
|
|
@ -723,8 +730,8 @@ func (engine *Engine) Static(relativePath, rootPath string) {
|
|||
})
|
||||
}
|
||||
|
||||
// Group的Static
|
||||
func (group *RouterGroup) Static(relativePath, rootPath string) {
|
||||
// Group的StaticDir方式
|
||||
func (group *RouterGroup) StaticDir(relativePath, rootPath string) {
|
||||
// 清理路径
|
||||
relativePath = path.Clean(relativePath)
|
||||
rootPath = path.Clean(rootPath)
|
||||
|
|
@ -896,5 +903,85 @@ func (group *RouterGroup) StaticFile(relativePath, filePath string) {
|
|||
group.GET(relativePath, FileHandle)
|
||||
group.HEAD(relativePath, FileHandle)
|
||||
group.OPTIONS(relativePath, FileHandle)
|
||||
|
||||
}
|
||||
|
||||
// 维护一个Methods列表
|
||||
var (
|
||||
MethodGet = "GET"
|
||||
MethodHead = "HEAD"
|
||||
MethodPost = "POST"
|
||||
MethodPut = "PUT"
|
||||
MethodPatch = "PATCH"
|
||||
MethodDelete = "DELETE"
|
||||
MethodConnect = "CONNECT"
|
||||
MethodOptions = "OPTIONS"
|
||||
MethodTrace = "TRACE"
|
||||
)
|
||||
|
||||
var MethodsSet = map[string]struct{}{
|
||||
MethodGet: {},
|
||||
MethodHead: {},
|
||||
MethodPost: {},
|
||||
MethodPut: {},
|
||||
MethodPatch: {},
|
||||
MethodDelete: {},
|
||||
MethodConnect: {},
|
||||
MethodOptions: {},
|
||||
MethodTrace: {},
|
||||
}
|
||||
|
||||
// HandleFunc 注册一个或多个 HTTP 方法的路由
|
||||
// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"})
|
||||
// relativePath 是相对于当前组或 Engine 的路径
|
||||
// handlers 是处理函数链
|
||||
func (engine *Engine) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) {
|
||||
for _, method := range methods {
|
||||
if _, ok := MethodsSet[method]; !ok {
|
||||
panic("invalid method: " + method)
|
||||
}
|
||||
engine.Handle(method, relativePath, handlers...)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleFunc 注册一个或多个 HTTP 方法的路由
|
||||
// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"})
|
||||
// relativePath 是相对于当前组或 Engine 的路径
|
||||
// handlers 是处理函数链
|
||||
func (group *RouterGroup) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) {
|
||||
for _, method := range methods {
|
||||
if _, ok := MethodsSet[method]; !ok {
|
||||
panic("invalid method: " + method)
|
||||
}
|
||||
group.Handle(method, relativePath, handlers...)
|
||||
}
|
||||
}
|
||||
|
||||
// FileServer方式, 返回一个HandleFunc, 统一化处理
|
||||
func FileServer(fs http.FileSystem) HandlerFunc {
|
||||
return func(c *Context) {
|
||||
// 检查是否是 GET 或 HEAD 方法
|
||||
if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead {
|
||||
// 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件
|
||||
if c.engine.HandleMethodNotAllowed {
|
||||
c.Next()
|
||||
} else {
|
||||
// 否则,返回 405 Method Not Allowed
|
||||
c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, fmt.Errorf("Method %s is Not Allowed on FileServer", c.Request.Method))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码
|
||||
ecw := AcquireErrorCapturingResponseWriter(c)
|
||||
defer ReleaseErrorCapturingResponseWriter(ecw)
|
||||
|
||||
// 调用 http.FileServer 处理请求
|
||||
http.FileServer(fs).ServeHTTP(ecw, c.Request)
|
||||
|
||||
// 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler
|
||||
ecw.processAfterFileServer()
|
||||
|
||||
// 中止处理链,因为 FileServer 已经处理了响应
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
2
go.mod
2
go.mod
|
|
@ -5,7 +5,7 @@ go 1.24.4
|
|||
require (
|
||||
github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4
|
||||
github.com/WJQSERVER-STUDIO/httpc v0.7.0
|
||||
github.com/fenthope/reco v0.0.1
|
||||
github.com/fenthope/reco v0.0.2
|
||||
github.com/go-json-experiment/json v0.0.0-20250517221953-25912455fbc8
|
||||
)
|
||||
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -2,8 +2,8 @@ github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 h1:JLtFd00AdFg/TP+dtvIzLkdHwKU
|
|||
github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4/go.mod h1:FZ6XE+4TKy4MOfX1xWKe6Rwsg0ucYFCdNh1KLvyKTfc=
|
||||
github.com/WJQSERVER-STUDIO/httpc v0.7.0 h1:iHhqlxppJBjlmvsIjvLZKRbWXqSdbeSGGofjHGmqGJc=
|
||||
github.com/WJQSERVER-STUDIO/httpc v0.7.0/go.mod h1:M7KNUZjjhCkzzcg9lBPs9YfkImI+7vqjAyjdA19+joE=
|
||||
github.com/fenthope/reco v0.0.1 h1:GYcuXCEKYoctD0dFkiBC+t0RMTOyOiujBCin8bbLR3Y=
|
||||
github.com/fenthope/reco v0.0.1/go.mod h1:mDkGLHte5udWTIcjQTxrABRcf56SSdxBOCLgrRDwI/Y=
|
||||
github.com/fenthope/reco v0.0.2 h1:9+RdpZlQYuwQh00XGn8uFu7vcHeldlgufGnszXkyFwg=
|
||||
github.com/fenthope/reco v0.0.2/go.mod h1:mDkGLHte5udWTIcjQTxrABRcf56SSdxBOCLgrRDwI/Y=
|
||||
github.com/go-json-experiment/json v0.0.0-20250517221953-25912455fbc8 h1:o8UqXPI6SVwQt04RGsqKp3qqmbOfTNMqDrWsc4O47kk=
|
||||
github.com/go-json-experiment/json v0.0.0-20250517221953-25912455fbc8/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
|
|
|
|||
|
|
@ -34,3 +34,9 @@ func CloseLogger(logger *reco.Logger) {
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (engine *Engine) CloseLogger() {
|
||||
if engine.LogReco != nil {
|
||||
CloseLogger(engine.LogReco)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue