From c4c0160b5f7d7f2c09bc04cc3ab92e71c718384e Mon Sep 17 00:00:00 2001 From: WJQSERVER <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 17 Mar 2026 12:02:49 +0800 Subject: [PATCH] refactor: improve binding, performance, and type safety - Implement ShouldBind with support for JSON, Form, WANF, and GOB formats - Add ShouldBindForm, ShouldBindGOB, and helper functions for form binding - Use fmt.Appendf instead of fmt.Sprintf for better performance - Replace interface{} with any for modern Go style - Use maps.Copy for cleaner header copying - Update strings.SplitSeq to use range over strings.Seq - Remove deprecated placeholder comments and add proper implementations - Fix reflect.Pointer usage for Go 1.22+ compatibility --- context.go | 192 ++++++++++++++++++++++++++++++++++++++++++++++------ ecw.go | 13 ++-- engine.go | 6 +- recovery.go | 6 +- sse.go | 4 +- touka.go | 2 +- 6 files changed, 187 insertions(+), 36 deletions(-) diff --git a/context.go b/context.go index 8c52b1f..c37371f 100644 --- a/context.go +++ b/context.go @@ -19,6 +19,8 @@ import ( "net/url" "os" "path/filepath" + "reflect" + "strconv" "strings" "sync" "time" @@ -286,7 +288,7 @@ func (c *Context) Raw(code int, contentType string, data []byte) { // String 向响应写入格式化的字符串 func (c *Context) String(code int, format string, values ...any) { c.Writer.WriteHeader(code) - c.Writer.Write([]byte(fmt.Sprintf(format, values...))) + c.Writer.Write(fmt.Appendf(nil, format, values...)) } // Text 向响应写入无需格式化的string @@ -341,7 +343,6 @@ func (c *Context) FileText(code int, filePath string) { } /* -// not fot work // FileTextSafeDir func (c *Context) FileTextSafeDir(code int, filePath string, safeDir string) { @@ -465,7 +466,7 @@ func (c *Context) HTML(code int, name string, obj any) { // 可以扩展支持其他渲染器接口 } // 默认简单输出,用于未配置 HTMLRender 的情况 - c.Writer.Write([]byte(fmt.Sprintf("\n
%v
", name, obj))) + c.Writer.Write(fmt.Appendf(nil, "\n
%v
", name, obj)) } // Redirect 执行 HTTP 重定向 @@ -490,7 +491,7 @@ func (c *Context) ShouldBindJSON(obj any) error { return nil } -// ShouldBindWANF +// ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象 func (c *Context) ShouldBindWANF(obj any) error { if c.Request.Body == nil { return errors.New("request body is empty") @@ -506,23 +507,174 @@ func (c *Context) ShouldBindWANF(obj any) error { return nil } -// Deprecated: This function is a reserved placeholder for future API extensions -// and is not yet implemented. It will either be properly defined or removed in v2.0.0. Do not use. -// ShouldBind 尝试将请求体绑定到各种类型(JSON, Form, XML 等) -// 这是一个复杂的通用绑定接口,通常根据 Content-Type 或其他头部来判断绑定方式 -// 预留接口,可根据项目需求进行扩展 +// ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象 +func (c *Context) ShouldBindGOB(obj any) error { + if c.Request.Body == nil { + return errors.New("request body is empty") + } + decoder := gob.NewDecoder(c.Request.Body) + if err := decoder.Decode(obj); err != nil { + return fmt.Errorf("GOB binding error: %w", err) + } + return nil +} + +// bindForm 将 url.Values 绑定到结构体 +// 支持 form tag 标签,如 `form:"field_name"` +func bindForm(values url.Values, obj any) error { + val := reflect.ValueOf(obj) + if val.Kind() != reflect.Pointer || val.Elem().Kind() != reflect.Struct { + return errors.New("obj must be a pointer to struct") + } + + val = val.Elem() + typ := val.Type() + + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + fieldType := typ.Field(i) + + if !field.CanSet() { + continue + } + + tag := fieldType.Tag.Get("form") + if tag == "" { + tag = fieldType.Name + } + if tag == "-" { + continue + } + + formValues := values[tag] + if len(formValues) == 0 { + continue + } + + if err := setFieldValue(field, formValues); err != nil { + return fmt.Errorf("field %s: %w", fieldType.Name, err) + } + } + return nil +} + +// setFieldValue 将字符串值设置到反射值 +func setFieldValue(field reflect.Value, values []string) error { + if !field.CanSet() { + return nil + } + + value := values[0] + + switch field.Kind() { + case reflect.String: + field.SetString(value) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if value == "" { + return nil + } + v, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + field.SetInt(v) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if value == "" { + return nil + } + v, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return err + } + field.SetUint(v) + case reflect.Float32, reflect.Float64: + if value == "" { + return nil + } + v, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + field.SetFloat(v) + case reflect.Bool: + if value == "" { + return nil + } + v, err := strconv.ParseBool(value) + if err != nil { + return err + } + field.SetBool(v) + case reflect.Pointer: + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + return setFieldValue(field.Elem(), values) + case reflect.Slice: + slice := reflect.MakeSlice(field.Type(), len(values), len(values)) + elemType := field.Type().Elem() + for i, v := range values { + if err := setFieldValue(slice.Index(i), []string{v}); err != nil { + return err + } + _ = elemType + } + field.Set(slice) + default: + return fmt.Errorf("unsupported type: %s", field.Kind()) + } + return nil +} + +// ShouldBindForm 尝试将表单数据绑定到结构体 +// 支持 application/x-www-form-urlencoded 和 multipart/form-data +func (c *Context) ShouldBindForm(obj any) error { + contentType := c.Request.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return fmt.Errorf("invalid content type: %w", err) + } + + switch mediaType { + case "multipart/form-data": + if err := c.Request.ParseMultipartForm(32 << 20); err != nil { + return fmt.Errorf("parse multipart form error: %w", err) + } + case "application/x-www-form-urlencoded": + if err := c.Request.ParseForm(); err != nil { + return fmt.Errorf("parse form error: %w", err) + } + default: + return fmt.Errorf("unsupported form content type: %s", mediaType) + } + + if err := bindForm(c.Request.Form, obj); err != nil { + return fmt.Errorf("form binding error: %w", err) + } + return nil +} + +// ShouldBind 尝试根据 Content-Type 将请求体绑定到结构体 +// 支持的类型:application/json, application/x-www-form-urlencoded, multipart/form-data, application/wanf, application/vnd.wjqserver.wanf, application/gob func (c *Context) ShouldBind(obj any) error { - // TODO: 完整的通用绑定逻辑 - // 可以根据 c.Request.Header.Get("Content-Type") 来判断是 JSON, Form, XML 等 - // 例如: - // contentType := c.Request.Header.Get("Content-Type") - // if strings.HasPrefix(contentType, "application/json") { - // return c.ShouldBindJSON(obj) - // } - // if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") || strings.HasPrefix(contentType, "multipart/form-data") { - // return c.ShouldBindForm(obj) // 需要实现 ShouldBindForm - // } - return errors.New("generic binding not fully implemented yet, implement based on Content-Type") + contentType := c.Request.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return fmt.Errorf("invalid content type: %w", err) + } + + switch mediaType { + case "application/json": + return c.ShouldBindJSON(obj) + case "application/x-www-form-urlencoded", "multipart/form-data": + return c.ShouldBindForm(obj) + case "application/wanf", "application/vnd.wjqserver.wanf": + return c.ShouldBindWANF(obj) + case "application/gob": + return c.ShouldBindGOB(obj) + default: + return fmt.Errorf("unsupported content type: %s", mediaType) + } } // AddError 添加一个错误到 Context diff --git a/ecw.go b/ecw.go index c87be28..754571f 100644 --- a/ecw.go +++ b/ecw.go @@ -7,6 +7,7 @@ package touka import ( "bufio" "errors" + "maps" "net" "net/http" "sync" @@ -27,7 +28,7 @@ type errorCapturingResponseWriter struct { // errorResponseWriterPool 是用于复用 errorCapturingResponseWriter 实例的对象池 var errorResponseWriterPool = sync.Pool{ - New: func() interface{} { + New: func() any { return &errorCapturingResponseWriter{ headerSnapshot: make(http.Header), // 预先初始化 map, 减少 reset 时的分配 } @@ -91,9 +92,8 @@ func (ecw *errorCapturingResponseWriter) WriteHeader(statusCode int) { // 是成功状态码 // 将 ecw.headerSnapshot 中(由 FileServer 在此之前通过 ecw.Header() 设置的) // 任何头部直接复制到原始的 w.Header(), 确保多值头部正确传递 - for k, v := range ecw.headerSnapshot { - ecw.w.Header()[k] = v // 直接赋值 []string, 保留所有值 - } + // 直接赋值 []string, 保留所有值 + maps.Copy(ecw.w.Header(), ecw.headerSnapshot) ecw.w.WriteHeader(statusCode) // 实际写入状态码到原始 ResponseWriter ecw.responseStarted = true // 标记成功响应已开始 } @@ -112,9 +112,8 @@ func (ecw *errorCapturingResponseWriter) Write(data []byte) (int, error) { ecw.statusCode = http.StatusOK // 隐式 200 OK } // 将 headerSnapshot 中的头部复制到原始 ResponseWriter 的 Header - for k, v := range ecw.headerSnapshot { - ecw.w.Header()[k] = v // 直接赋值 []string, 保留所有值 - } + // 直接赋值 []string, 保留所有值 + maps.Copy(ecw.w.Header(), ecw.headerSnapshot) ecw.w.WriteHeader(ecw.Status()) // 发送实际的状态码 (可能是 200 或之前设置的 2xx) ecw.responseStarted = true } diff --git a/engine.go b/engine.go index 1e7bb18..f236624 100644 --- a/engine.go +++ b/engine.go @@ -51,7 +51,7 @@ type Engine struct { LogReco *reco.Logger - HTMLRender interface{} // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 + HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 routesInfo []RouteInfo // 存储所有注册的路由信息 @@ -215,11 +215,11 @@ func New() *Engine { engine.SetDefaultProtocols() engine.SetLoggerCfg(defaultLogRecoConfig) // 初始化 Context Pool,为每个新 Context 实例提供一个构造函数 - engine.pool.New = func() interface{} { + engine.pool.New = func() any { return &Context{ Writer: newResponseWriter(nil), // 初始时可以传入nil,在ServeHTTP中会重新设置实际的 http.ResponseWriter Params: make(Params, 0, engine.maxParams), // 预分配 Params 切片以减少内存分配 - Keys: make(map[string]interface{}), + Keys: make(map[string]any), Errors: make([]error, 0), ctx: context.Background(), // 初始上下文,后续会被请求的 Context 覆盖 HTTPClient: engine.HTTPClient, diff --git a/recovery.go b/recovery.go index 5dfb837..dc4d892 100644 --- a/recovery.go +++ b/recovery.go @@ -18,7 +18,7 @@ import ( // PanicHandlerFunc 定义了用户自定义的 panic 处理函数类型 // 它接收当前的 Context 和 panic 的值 -type PanicHandlerFunc func(c *Context, panicInfo interface{}) +type PanicHandlerFunc func(c *Context, panicInfo any) // RecoveryWithOptions 返回一个可配置的 panic 恢复中间件 // @@ -50,7 +50,7 @@ func Recovery() HandlerFunc { } // defaultPanicHandler 是默认的 panic 处理逻辑 -func defaultPanicHandler(c *Context, r interface{}) { +func defaultPanicHandler(c *Context, r any) { // 检查连接是否已由客户端关闭 // 常见的错误类型包括 net.OpError (其内部错误可能是 os.SyscallError), // 以及在 HTTP/2 中可能出现的特定 stream 错误 @@ -107,7 +107,7 @@ func defaultPanicHandler(c *Context, r interface{}) { // isBrokenPipeError 检查 recover() 捕获的值是否表示一个由客户端断开连接引起的网络错误 // 这对于防止在已关闭的连接上写入响应至关重要 -func isBrokenPipeError(r interface{}) bool { +func isBrokenPipeError(r any) bool { // 将 recover() 的结果转换为 error 类型 err, ok := r.(error) if !ok { diff --git a/sse.go b/sse.go index 3b98800..ab6c226 100644 --- a/sse.go +++ b/sse.go @@ -40,8 +40,8 @@ func (e *Event) Render(w io.Writer) error { buf.WriteString("\n") } if len(e.Data) > 0 { - lines := strings.Split(e.Data, "\n") - for _, line := range lines { + lines := strings.SplitSeq(e.Data, "\n") + for line := range lines { buf.WriteString("data: ") buf.WriteString(line) buf.WriteString("\n") diff --git a/touka.go b/touka.go index 837d62d..dd529cb 100644 --- a/touka.go +++ b/touka.go @@ -12,7 +12,7 @@ const ( defaultMemory = 32 << 20 // 32 MB, Gin 的默认值,用于 ParseMultipartForm ) -type H map[string]interface{} // map简写, 类似gin.H +type H map[string]any // map简写, 类似gin.H type Handle func(http.ResponseWriter, *http.Request, Params)