refactor: improve binding, performance, and type safety

- Implement ShouldBind with support for JSON, Form, WANF, and GOB formats
- Add ShouldBindForm, ShouldBindGOB, and helper functions for form binding
- Use fmt.Appendf instead of fmt.Sprintf for better performance
- Replace interface{} with any for modern Go style
- Use maps.Copy for cleaner header copying
- Update strings.SplitSeq to use range over strings.Seq
- Remove deprecated placeholder comments and add proper implementations
- Fix reflect.Pointer usage for Go 1.22+ compatibility
This commit is contained in:
WJQSERVER 2026-03-17 12:02:49 +08:00
parent 9e3e43bf88
commit c4c0160b5f
6 changed files with 187 additions and 36 deletions

View file

@ -19,6 +19,8 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -286,7 +288,7 @@ func (c *Context) Raw(code int, contentType string, data []byte) {
// String 向响应写入格式化的字符串 // String 向响应写入格式化的字符串
func (c *Context) String(code int, format string, values ...any) { func (c *Context) String(code int, format string, values ...any) {
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.Writer.Write([]byte(fmt.Sprintf(format, values...))) c.Writer.Write(fmt.Appendf(nil, format, values...))
} }
// Text 向响应写入无需格式化的string // Text 向响应写入无需格式化的string
@ -341,7 +343,6 @@ func (c *Context) FileText(code int, filePath string) {
} }
/* /*
// not fot work
// FileTextSafeDir // FileTextSafeDir
func (c *Context) FileTextSafeDir(code int, filePath string, safeDir string) { func (c *Context) FileTextSafeDir(code int, filePath string, safeDir string) {
@ -465,7 +466,7 @@ func (c *Context) HTML(code int, name string, obj any) {
// 可以扩展支持其他渲染器接口 // 可以扩展支持其他渲染器接口
} }
// 默认简单输出,用于未配置 HTMLRender 的情况 // 默认简单输出,用于未配置 HTMLRender 的情况
c.Writer.Write([]byte(fmt.Sprintf("<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj))) c.Writer.Write(fmt.Appendf(nil, "<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj))
} }
// Redirect 执行 HTTP 重定向 // Redirect 执行 HTTP 重定向
@ -490,7 +491,7 @@ func (c *Context) ShouldBindJSON(obj any) error {
return nil return nil
} }
// ShouldBindWANF // ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象
func (c *Context) ShouldBindWANF(obj any) error { func (c *Context) ShouldBindWANF(obj any) error {
if c.Request.Body == nil { if c.Request.Body == nil {
return errors.New("request body is empty") return errors.New("request body is empty")
@ -506,23 +507,174 @@ func (c *Context) ShouldBindWANF(obj any) error {
return nil return nil
} }
// Deprecated: This function is a reserved placeholder for future API extensions // ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象
// and is not yet implemented. It will either be properly defined or removed in v2.0.0. Do not use. func (c *Context) ShouldBindGOB(obj any) error {
// ShouldBind 尝试将请求体绑定到各种类型JSON, Form, XML 等) if c.Request.Body == nil {
// 这是一个复杂的通用绑定接口,通常根据 Content-Type 或其他头部来判断绑定方式 return errors.New("request body is empty")
// 预留接口,可根据项目需求进行扩展 }
decoder := gob.NewDecoder(c.Request.Body)
if err := decoder.Decode(obj); err != nil {
return fmt.Errorf("GOB binding error: %w", err)
}
return nil
}
// bindForm 将 url.Values 绑定到结构体
// 支持 form tag 标签,如 `form:"field_name"`
func bindForm(values url.Values, obj any) error {
val := reflect.ValueOf(obj)
if val.Kind() != reflect.Pointer || val.Elem().Kind() != reflect.Struct {
return errors.New("obj must be a pointer to struct")
}
val = val.Elem()
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)
if !field.CanSet() {
continue
}
tag := fieldType.Tag.Get("form")
if tag == "" {
tag = fieldType.Name
}
if tag == "-" {
continue
}
formValues := values[tag]
if len(formValues) == 0 {
continue
}
if err := setFieldValue(field, formValues); err != nil {
return fmt.Errorf("field %s: %w", fieldType.Name, err)
}
}
return nil
}
// setFieldValue 将字符串值设置到反射值
func setFieldValue(field reflect.Value, values []string) error {
if !field.CanSet() {
return nil
}
value := values[0]
switch field.Kind() {
case reflect.String:
field.SetString(value)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if value == "" {
return nil
}
v, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return err
}
field.SetInt(v)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if value == "" {
return nil
}
v, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return err
}
field.SetUint(v)
case reflect.Float32, reflect.Float64:
if value == "" {
return nil
}
v, err := strconv.ParseFloat(value, 64)
if err != nil {
return err
}
field.SetFloat(v)
case reflect.Bool:
if value == "" {
return nil
}
v, err := strconv.ParseBool(value)
if err != nil {
return err
}
field.SetBool(v)
case reflect.Pointer:
if field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
return setFieldValue(field.Elem(), values)
case reflect.Slice:
slice := reflect.MakeSlice(field.Type(), len(values), len(values))
elemType := field.Type().Elem()
for i, v := range values {
if err := setFieldValue(slice.Index(i), []string{v}); err != nil {
return err
}
_ = elemType
}
field.Set(slice)
default:
return fmt.Errorf("unsupported type: %s", field.Kind())
}
return nil
}
// ShouldBindForm 尝试将表单数据绑定到结构体
// 支持 application/x-www-form-urlencoded 和 multipart/form-data
func (c *Context) ShouldBindForm(obj any) error {
contentType := c.Request.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil {
return fmt.Errorf("invalid content type: %w", err)
}
switch mediaType {
case "multipart/form-data":
if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
return fmt.Errorf("parse multipart form error: %w", err)
}
case "application/x-www-form-urlencoded":
if err := c.Request.ParseForm(); err != nil {
return fmt.Errorf("parse form error: %w", err)
}
default:
return fmt.Errorf("unsupported form content type: %s", mediaType)
}
if err := bindForm(c.Request.Form, obj); err != nil {
return fmt.Errorf("form binding error: %w", err)
}
return nil
}
// ShouldBind 尝试根据 Content-Type 将请求体绑定到结构体
// 支持的类型application/json, application/x-www-form-urlencoded, multipart/form-data, application/wanf, application/vnd.wjqserver.wanf, application/gob
func (c *Context) ShouldBind(obj any) error { func (c *Context) ShouldBind(obj any) error {
// TODO: 完整的通用绑定逻辑 contentType := c.Request.Header.Get("Content-Type")
// 可以根据 c.Request.Header.Get("Content-Type") 来判断是 JSON, Form, XML 等 mediaType, _, err := mime.ParseMediaType(contentType)
// 例如: if err != nil {
// contentType := c.Request.Header.Get("Content-Type") return fmt.Errorf("invalid content type: %w", err)
// if strings.HasPrefix(contentType, "application/json") { }
// return c.ShouldBindJSON(obj)
// } switch mediaType {
// if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") || strings.HasPrefix(contentType, "multipart/form-data") { case "application/json":
// return c.ShouldBindForm(obj) // 需要实现 ShouldBindForm return c.ShouldBindJSON(obj)
// } case "application/x-www-form-urlencoded", "multipart/form-data":
return errors.New("generic binding not fully implemented yet, implement based on Content-Type") return c.ShouldBindForm(obj)
case "application/wanf", "application/vnd.wjqserver.wanf":
return c.ShouldBindWANF(obj)
case "application/gob":
return c.ShouldBindGOB(obj)
default:
return fmt.Errorf("unsupported content type: %s", mediaType)
}
} }
// AddError 添加一个错误到 Context // AddError 添加一个错误到 Context

13
ecw.go
View file

@ -7,6 +7,7 @@ package touka
import ( import (
"bufio" "bufio"
"errors" "errors"
"maps"
"net" "net"
"net/http" "net/http"
"sync" "sync"
@ -27,7 +28,7 @@ type errorCapturingResponseWriter struct {
// errorResponseWriterPool 是用于复用 errorCapturingResponseWriter 实例的对象池 // errorResponseWriterPool 是用于复用 errorCapturingResponseWriter 实例的对象池
var errorResponseWriterPool = sync.Pool{ var errorResponseWriterPool = sync.Pool{
New: func() interface{} { New: func() any {
return &errorCapturingResponseWriter{ return &errorCapturingResponseWriter{
headerSnapshot: make(http.Header), // 预先初始化 map, 减少 reset 时的分配 headerSnapshot: make(http.Header), // 预先初始化 map, 减少 reset 时的分配
} }
@ -91,9 +92,8 @@ func (ecw *errorCapturingResponseWriter) WriteHeader(statusCode int) {
// 是成功状态码 // 是成功状态码
// 将 ecw.headerSnapshot 中(由 FileServer 在此之前通过 ecw.Header() 设置的) // 将 ecw.headerSnapshot 中(由 FileServer 在此之前通过 ecw.Header() 设置的)
// 任何头部直接复制到原始的 w.Header(), 确保多值头部正确传递 // 任何头部直接复制到原始的 w.Header(), 确保多值头部正确传递
for k, v := range ecw.headerSnapshot { // 直接赋值 []string, 保留所有值
ecw.w.Header()[k] = v // 直接赋值 []string, 保留所有值 maps.Copy(ecw.w.Header(), ecw.headerSnapshot)
}
ecw.w.WriteHeader(statusCode) // 实际写入状态码到原始 ResponseWriter ecw.w.WriteHeader(statusCode) // 实际写入状态码到原始 ResponseWriter
ecw.responseStarted = true // 标记成功响应已开始 ecw.responseStarted = true // 标记成功响应已开始
} }
@ -112,9 +112,8 @@ func (ecw *errorCapturingResponseWriter) Write(data []byte) (int, error) {
ecw.statusCode = http.StatusOK // 隐式 200 OK ecw.statusCode = http.StatusOK // 隐式 200 OK
} }
// 将 headerSnapshot 中的头部复制到原始 ResponseWriter 的 Header // 将 headerSnapshot 中的头部复制到原始 ResponseWriter 的 Header
for k, v := range ecw.headerSnapshot { // 直接赋值 []string, 保留所有值
ecw.w.Header()[k] = v // 直接赋值 []string, 保留所有值 maps.Copy(ecw.w.Header(), ecw.headerSnapshot)
}
ecw.w.WriteHeader(ecw.Status()) // 发送实际的状态码 (可能是 200 或之前设置的 2xx) ecw.w.WriteHeader(ecw.Status()) // 发送实际的状态码 (可能是 200 或之前设置的 2xx)
ecw.responseStarted = true ecw.responseStarted = true
} }

View file

@ -51,7 +51,7 @@ type Engine struct {
LogReco *reco.Logger LogReco *reco.Logger
HTMLRender interface{} // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口
routesInfo []RouteInfo // 存储所有注册的路由信息 routesInfo []RouteInfo // 存储所有注册的路由信息
@ -215,11 +215,11 @@ func New() *Engine {
engine.SetDefaultProtocols() engine.SetDefaultProtocols()
engine.SetLoggerCfg(defaultLogRecoConfig) engine.SetLoggerCfg(defaultLogRecoConfig)
// 初始化 Context Pool,为每个新 Context 实例提供一个构造函数 // 初始化 Context Pool,为每个新 Context 实例提供一个构造函数
engine.pool.New = func() interface{} { engine.pool.New = func() any {
return &Context{ return &Context{
Writer: newResponseWriter(nil), // 初始时可以传入nil,在ServeHTTP中会重新设置实际的 http.ResponseWriter Writer: newResponseWriter(nil), // 初始时可以传入nil,在ServeHTTP中会重新设置实际的 http.ResponseWriter
Params: make(Params, 0, engine.maxParams), // 预分配 Params 切片以减少内存分配 Params: make(Params, 0, engine.maxParams), // 预分配 Params 切片以减少内存分配
Keys: make(map[string]interface{}), Keys: make(map[string]any),
Errors: make([]error, 0), Errors: make([]error, 0),
ctx: context.Background(), // 初始上下文,后续会被请求的 Context 覆盖 ctx: context.Background(), // 初始上下文,后续会被请求的 Context 覆盖
HTTPClient: engine.HTTPClient, HTTPClient: engine.HTTPClient,

View file

@ -18,7 +18,7 @@ import (
// PanicHandlerFunc 定义了用户自定义的 panic 处理函数类型 // PanicHandlerFunc 定义了用户自定义的 panic 处理函数类型
// 它接收当前的 Context 和 panic 的值 // 它接收当前的 Context 和 panic 的值
type PanicHandlerFunc func(c *Context, panicInfo interface{}) type PanicHandlerFunc func(c *Context, panicInfo any)
// RecoveryWithOptions 返回一个可配置的 panic 恢复中间件 // RecoveryWithOptions 返回一个可配置的 panic 恢复中间件
// //
@ -50,7 +50,7 @@ func Recovery() HandlerFunc {
} }
// defaultPanicHandler 是默认的 panic 处理逻辑 // defaultPanicHandler 是默认的 panic 处理逻辑
func defaultPanicHandler(c *Context, r interface{}) { func defaultPanicHandler(c *Context, r any) {
// 检查连接是否已由客户端关闭 // 检查连接是否已由客户端关闭
// 常见的错误类型包括 net.OpError (其内部错误可能是 os.SyscallError) // 常见的错误类型包括 net.OpError (其内部错误可能是 os.SyscallError)
// 以及在 HTTP/2 中可能出现的特定 stream 错误 // 以及在 HTTP/2 中可能出现的特定 stream 错误
@ -107,7 +107,7 @@ func defaultPanicHandler(c *Context, r interface{}) {
// isBrokenPipeError 检查 recover() 捕获的值是否表示一个由客户端断开连接引起的网络错误 // isBrokenPipeError 检查 recover() 捕获的值是否表示一个由客户端断开连接引起的网络错误
// 这对于防止在已关闭的连接上写入响应至关重要 // 这对于防止在已关闭的连接上写入响应至关重要
func isBrokenPipeError(r interface{}) bool { func isBrokenPipeError(r any) bool {
// 将 recover() 的结果转换为 error 类型 // 将 recover() 的结果转换为 error 类型
err, ok := r.(error) err, ok := r.(error)
if !ok { if !ok {

4
sse.go
View file

@ -40,8 +40,8 @@ func (e *Event) Render(w io.Writer) error {
buf.WriteString("\n") buf.WriteString("\n")
} }
if len(e.Data) > 0 { if len(e.Data) > 0 {
lines := strings.Split(e.Data, "\n") lines := strings.SplitSeq(e.Data, "\n")
for _, line := range lines { for line := range lines {
buf.WriteString("data: ") buf.WriteString("data: ")
buf.WriteString(line) buf.WriteString(line)
buf.WriteString("\n") buf.WriteString("\n")

View file

@ -12,7 +12,7 @@ const (
defaultMemory = 32 << 20 // 32 MB, Gin 的默认值,用于 ParseMultipartForm defaultMemory = 32 << 20 // 32 MB, Gin 的默认值,用于 ParseMultipartForm
) )
type H map[string]interface{} // map简写, 类似gin.H type H map[string]any // map简写, 类似gin.H
type Handle func(http.ResponseWriter, *http.Request, Params) type Handle func(http.ResponseWriter, *http.Request, Params)