mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
Make Context keys lazy so requests that never call Set stop allocating on reset. Reuse stable 404 and 405 handlers and add focused benchmarks so ServeHTTP miss paths stay measurable.
1293 lines
38 KiB
Go
1293 lines
38 KiB
Go
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||
// Copyright 2024 WJQSERVER. All rights reserved.
|
||
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||
package touka
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/gob"
|
||
"errors"
|
||
"fmt"
|
||
"html/template"
|
||
"io"
|
||
"math"
|
||
"mime"
|
||
"net/http"
|
||
"net/netip"
|
||
"net/url"
|
||
"os"
|
||
"path/filepath"
|
||
"reflect"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/WJQSERVER/wanf"
|
||
"github.com/fenthope/reco"
|
||
"github.com/go-json-experiment/json"
|
||
|
||
"github.com/WJQSERVER-STUDIO/go-utils/iox"
|
||
"github.com/WJQSERVER-STUDIO/httpc"
|
||
)
|
||
|
||
const abortIndex int8 = math.MaxInt8 >> 1
|
||
|
||
// Context 是每个请求的上下文,封装了请求和响应,并提供了很多便捷方法
|
||
// 它在中间件和最终处理函数之间传递
|
||
type Context struct {
|
||
Writer ResponseWriter // 包装的 http.ResponseWriter
|
||
Request *http.Request
|
||
Params Params // 从 httprouter 获取的路径参数
|
||
handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler)
|
||
index int8 // 当前执行到处理链的哪个位置
|
||
|
||
requestBodyPrepared bool
|
||
|
||
mu sync.RWMutex
|
||
Keys map[string]any // 用于在中间件之间传递数据
|
||
|
||
Errors []error // 用于收集处理过程中的错误
|
||
|
||
// 缓存查询参数和表单数据
|
||
queryCache url.Values
|
||
formCache url.Values
|
||
|
||
// 携带ctx以实现关闭逻辑
|
||
ctx context.Context
|
||
|
||
// HTTPClient 用于在此上下文中执行出站 HTTP 请求
|
||
// 它由 Engine 提供
|
||
HTTPClient *httpc.Client
|
||
|
||
// 引用所属的 Engine 实例,方便访问 Engine 的配置(如 HTMLRender)
|
||
engine *Engine
|
||
|
||
sameSite http.SameSite
|
||
|
||
// 请求体Body大小限制
|
||
MaxRequestBodySize int64
|
||
|
||
// skippedNodes 用于记录跳过的节点信息,以便回溯
|
||
// 通常在处理嵌套路由时使用
|
||
SkippedNodes []skippedNode
|
||
}
|
||
|
||
// --- Context 相关方法实现 ---
|
||
|
||
// reset 重置 Context 对象以供复用
|
||
// 每次从 sync.Pool 中获取 Context 后,都需要调用此方法进行初始化
|
||
func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
|
||
|
||
if rw, ok := c.Writer.(*responseWriterImpl); ok && !rw.IsHijacked() {
|
||
rw.reset(w)
|
||
} else {
|
||
c.Writer = newResponseWriter(w)
|
||
}
|
||
|
||
c.Request = req
|
||
//c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组
|
||
//避免params长度为0
|
||
if cap(c.Params) > 0 {
|
||
c.Params = c.Params[:0]
|
||
} else {
|
||
c.Params = make(Params, 0, c.engine.maxParams)
|
||
}
|
||
c.handlers = nil
|
||
c.index = -1 // 初始为 -1,`Next()` 将其设置为 0
|
||
c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map
|
||
c.Errors = c.Errors[:0] // 清空 Errors 切片
|
||
c.queryCache = nil // 清空查询参数缓存
|
||
c.formCache = nil // 清空表单数据缓存
|
||
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
|
||
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
|
||
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
|
||
c.requestBodyPrepared = false
|
||
|
||
if cap(c.SkippedNodes) > 0 {
|
||
c.SkippedNodes = c.SkippedNodes[:0]
|
||
} else {
|
||
c.SkippedNodes = make([]skippedNode, 0, 256)
|
||
}
|
||
}
|
||
|
||
// Next 在处理链中执行下一个处理函数
|
||
// 这是中间件模式的核心,允许请求依次经过多个处理函数
|
||
func (c *Context) Next() {
|
||
c.index++
|
||
for c.index < int8(len(c.handlers)) {
|
||
c.handlers[c.index](c) // 执行当前索引处的处理函数
|
||
c.index++ // 移动到下一个处理函数
|
||
}
|
||
}
|
||
|
||
// Abort 停止处理链的后续执行
|
||
// 通常在中间件中,当遇到错误或需要提前终止请求时调用
|
||
func (c *Context) Abort() {
|
||
c.index = abortIndex // 将 index 设置为一个很大的值,使后续 Next() 调用跳过所有处理函数
|
||
}
|
||
|
||
// IsAborted 返回处理链是否已被中止
|
||
func (c *Context) IsAborted() bool {
|
||
return c.index >= abortIndex
|
||
}
|
||
|
||
// AbortWithStatus 中止处理链并设置 HTTP 状态码
|
||
func (c *Context) AbortWithStatus(code int) {
|
||
c.Writer.WriteHeader(code) // 设置响应状态码
|
||
c.Abort() // 中止处理链
|
||
}
|
||
|
||
// Set 将一个键值对存储到 Context 中
|
||
// 这是一个线程安全的操作,用于在中间件之间传递数据
|
||
func (c *Context) Set(key string, value any) {
|
||
c.mu.Lock() // 加写锁
|
||
if c.Keys == nil {
|
||
c.Keys = make(map[string]any)
|
||
}
|
||
c.Keys[key] = value
|
||
c.mu.Unlock() // 解写锁
|
||
}
|
||
|
||
// Get 从 Context 中获取一个值
|
||
// 这是一个线程安全的操作
|
||
func (c *Context) Get(key string) (value any, exists bool) {
|
||
c.mu.RLock() // 加读锁
|
||
value, exists = c.Keys[key]
|
||
c.mu.RUnlock() // 解读锁
|
||
return
|
||
}
|
||
|
||
// GetString 从 Context 中获取一个字符串值
|
||
// 这是一个线程安全的操作
|
||
func (c *Context) GetString(key string) (value string, exists bool) {
|
||
if val, exists := c.Get(key); exists {
|
||
if str, ok := val.(string); ok {
|
||
return str, true
|
||
}
|
||
}
|
||
return "", false
|
||
}
|
||
|
||
// GetInt 从 Context 中获取一个 int 值
|
||
// 这是一个线程安全的操作
|
||
func (c *Context) GetInt(key string) (value int, exists bool) {
|
||
if val, exists := c.Get(key); exists {
|
||
if i, ok := val.(int); ok {
|
||
return i, true
|
||
}
|
||
}
|
||
return 0, false
|
||
}
|
||
|
||
// GetBool 从 Context 中获取一个 bool 值
|
||
// 这是一个线程安全的操作
|
||
func (c *Context) GetBool(key string) (value bool, exists bool) {
|
||
if val, exists := c.Get(key); exists {
|
||
if b, ok := val.(bool); ok {
|
||
return b, true
|
||
}
|
||
}
|
||
return false, false
|
||
}
|
||
|
||
// GetFloat64 从 Context 中获取一个 float64 值
|
||
// 这是一个线程安全的操作
|
||
func (c *Context) GetFloat64(key string) (value float64, exists bool) {
|
||
if val, exists := c.Get(key); exists {
|
||
if f, ok := val.(float64); ok {
|
||
return f, true
|
||
}
|
||
}
|
||
return 0.0, false
|
||
}
|
||
|
||
// GetTime 从 Context 中获取一个 time.Time 值
|
||
// 这是一个线程安全的操作
|
||
func (c *Context) GetTime(key string) (value time.Time, exists bool) {
|
||
if val, exists := c.Get(key); exists {
|
||
if t, ok := val.(time.Time); ok {
|
||
return t, true
|
||
}
|
||
}
|
||
return time.Time{}, false
|
||
}
|
||
|
||
// GetDuration 从 Context 中获取一个 time.Duration 值
|
||
// 这是一个线程安全的操作
|
||
func (c *Context) GetDuration(key string) (value time.Duration, exists bool) {
|
||
if val, exists := c.Get(key); exists {
|
||
if d, ok := val.(time.Duration); ok {
|
||
return d, true
|
||
}
|
||
}
|
||
return 0, false
|
||
}
|
||
|
||
// MustGet 从 Context 中获取一个值,如果不存在则 panic
|
||
// 适用于确定值一定存在的场景
|
||
func (c *Context) MustGet(key string) any {
|
||
if value, exists := c.Get(key); exists {
|
||
return value
|
||
}
|
||
panic("Key \"" + key + "\" does not exist in context.")
|
||
}
|
||
|
||
// SetMaxRequestBodySize
|
||
func (c *Context) SetMaxRequestBodySize(size int64) {
|
||
c.MaxRequestBodySize = size
|
||
}
|
||
|
||
func (c *Context) prepareRequestBody() io.ReadCloser {
|
||
if c.Request == nil || c.Request.Body == nil {
|
||
return nil
|
||
}
|
||
if c.requestBodyPrepared || c.MaxRequestBodySize <= 0 {
|
||
return c.Request.Body
|
||
}
|
||
c.Request.Body = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||
c.requestBodyPrepared = true
|
||
return c.Request.Body
|
||
}
|
||
|
||
// Query 从 URL 查询参数中获取值
|
||
// 懒加载解析查询参数,并进行缓存
|
||
func (c *Context) Query(key string) string {
|
||
if c.queryCache == nil {
|
||
c.queryCache = c.Request.URL.Query() // 首次访问时解析并缓存
|
||
}
|
||
return c.queryCache.Get(key)
|
||
}
|
||
|
||
// DefaultQuery 从 URL 查询参数中获取值,如果不存在则返回默认值
|
||
func (c *Context) DefaultQuery(key, defaultValue string) string {
|
||
if value := c.Query(key); value != "" {
|
||
return value
|
||
}
|
||
return defaultValue
|
||
}
|
||
|
||
// PostForm 从 POST 请求体中获取表单值
|
||
// 懒加载解析表单数据,并进行缓存
|
||
func (c *Context) PostForm(key string) string {
|
||
if c.formCache == nil {
|
||
if c.MaxRequestBodySize > 0 {
|
||
c.prepareRequestBody()
|
||
}
|
||
contentType := c.Request.Header.Get("Content-Type")
|
||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||
c.formCache = make(url.Values)
|
||
return ""
|
||
}
|
||
|
||
switch mediaType {
|
||
case "multipart/form-data":
|
||
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||
c.formCache = make(url.Values)
|
||
return ""
|
||
}
|
||
case "application/x-www-form-urlencoded":
|
||
if err := c.Request.ParseForm(); err != nil {
|
||
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||
c.formCache = make(url.Values)
|
||
return ""
|
||
}
|
||
default:
|
||
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||
if !errors.Is(err, http.ErrNotMultipart) {
|
||
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||
c.formCache = make(url.Values)
|
||
return ""
|
||
}
|
||
}
|
||
}
|
||
c.formCache = c.Request.PostForm
|
||
}
|
||
return c.formCache.Get(key)
|
||
}
|
||
|
||
// DefaultPostForm 从 POST 请求体中获取表单值,如果不存在则返回默认值
|
||
func (c *Context) DefaultPostForm(key, defaultValue string) string {
|
||
if value := c.PostForm(key); value != "" {
|
||
return value
|
||
}
|
||
return defaultValue
|
||
}
|
||
|
||
// Param 从 URL 路径参数中获取值
|
||
// 例如,对于路由 /users/:id,c.Param("id") 可以获取 id 的值
|
||
func (c *Context) Param(key string) string {
|
||
return c.Params.ByName(key)
|
||
}
|
||
|
||
// Raw 向响应写入bytes
|
||
func (c *Context) Raw(code int, contentType string, data []byte) {
|
||
c.Writer.Header().Set("Content-Type", contentType)
|
||
c.Writer.WriteHeader(code)
|
||
c.Writer.Write(data)
|
||
}
|
||
|
||
// String 向响应写入格式化的字符串
|
||
func (c *Context) String(code int, format string, values ...any) {
|
||
c.Writer.WriteHeader(code)
|
||
c.Writer.Write(fmt.Appendf(nil, format, values...))
|
||
}
|
||
|
||
// Text 向响应写入无需格式化的string
|
||
func (c *Context) Text(code int, text string) {
|
||
c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||
c.Writer.WriteHeader(code)
|
||
c.Writer.Write([]byte(text))
|
||
}
|
||
|
||
// FileText
|
||
func (c *Context) FileText(code int, filePath string) {
|
||
// 清理path
|
||
cleanPath := filepath.Clean(filePath)
|
||
if !filepath.IsAbs(cleanPath) {
|
||
c.AddError(fmt.Errorf("relative path not allowed: %s", cleanPath))
|
||
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("relative path not allowed"))
|
||
return
|
||
}
|
||
// 检查文件是否存在
|
||
if _, err := os.Stat(cleanPath); os.IsNotExist(err) {
|
||
c.AddError(fmt.Errorf("file not found: %s", cleanPath))
|
||
c.ErrorUseHandle(http.StatusNotFound, fmt.Errorf("file not found"))
|
||
return
|
||
}
|
||
|
||
// 打开文件
|
||
file, err := os.Open(cleanPath)
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to open file %s: %w", cleanPath, err))
|
||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to open file: %w", err))
|
||
return
|
||
}
|
||
defer file.Close()
|
||
|
||
// 获取文件信息以获取文件大小
|
||
fileInfo, err := file.Stat()
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to get file info for %s: %w", cleanPath, err))
|
||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to get file info: %w", err))
|
||
return
|
||
}
|
||
// 判断是否是dir
|
||
if fileInfo.IsDir() {
|
||
c.AddError(fmt.Errorf("path is a directory, not a file: %s", cleanPath))
|
||
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path is a directory"))
|
||
return
|
||
}
|
||
|
||
c.SetHeader("Content-Type", "text/plain; charset=utf-8")
|
||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size()))
|
||
c.Writer.WriteHeader(code)
|
||
if _, err := iox.Copy(c.Writer, file); err != nil {
|
||
c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err))
|
||
}
|
||
}
|
||
|
||
/*
|
||
// FileTextSafeDir
|
||
func (c *Context) FileTextSafeDir(code int, filePath string, safeDir string) {
|
||
|
||
// 清理path
|
||
cleanPath := path.Clean(filePath)
|
||
if !filepath.IsAbs(cleanPath) {
|
||
c.AddError(fmt.Errorf("relative path not allowed: %s", cleanPath))
|
||
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("relative path not allowed"))
|
||
return
|
||
}
|
||
if strings.Contains(cleanPath, "..") {
|
||
c.AddError(fmt.Errorf("path traversal attempt detected: %s", cleanPath))
|
||
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path traversal attempt detected"))
|
||
return
|
||
}
|
||
|
||
// 判断filePath是否包含在safeDir内, 防止路径穿越
|
||
relPath, err := filepath.Rel(safeDir, cleanPath)
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to get relative path: %w", err))
|
||
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("failed to get relative path: %w", err))
|
||
return
|
||
}
|
||
cleanPath = filepath.Join(safeDir, relPath)
|
||
|
||
// 检查文件是否存在
|
||
if _, err := os.Stat(cleanPath); os.IsNotExist(err) {
|
||
c.AddError(fmt.Errorf("file not found: %s", cleanPath))
|
||
c.ErrorUseHandle(http.StatusNotFound, fmt.Errorf("file not found"))
|
||
return
|
||
}
|
||
|
||
// 打开文件
|
||
file, err := os.Open(cleanPath)
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to open file %s: %w", cleanPath, err))
|
||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to open file: %w", err))
|
||
return
|
||
}
|
||
defer file.Close()
|
||
|
||
// 获取文件信息以获取文件大小
|
||
fileInfo, err := file.Stat()
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to get file info for %s: %w", cleanPath, err))
|
||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to get file info: %w", err))
|
||
return
|
||
}
|
||
// 判断是否是dir
|
||
if fileInfo.IsDir() {
|
||
c.AddError(fmt.Errorf("path is a directory, not a file: %s", cleanPath))
|
||
c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path is a directory"))
|
||
return
|
||
}
|
||
|
||
c.SetHeader("Content-Type", "text/plain; charset=utf-8")
|
||
|
||
c.SetBodyStream(file, int(fileInfo.Size()))
|
||
}
|
||
*/
|
||
|
||
// JSON 向响应写入 JSON 数据
|
||
// 设置 Content-Type 为 application/json
|
||
func (c *Context) JSON(code int, obj any) {
|
||
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||
c.Writer.WriteHeader(code)
|
||
if err := json.MarshalWrite(c.Writer, obj); err != nil {
|
||
c.AddError(fmt.Errorf("failed to marshal JSON: %w", err))
|
||
c.Errorf("failed to marshal JSON: %s", err)
|
||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to marshal JSON: %w", err))
|
||
return
|
||
}
|
||
}
|
||
|
||
// JSONBuf 先将 JSON 编码到 buffer, 成功后再写入状态码和响应体.
|
||
// 与 JSON 相比,编码失败时可以正确返回 500 状态码,代价是多一次内存分配.
|
||
func (c *Context) JSONBuf(code int, obj any) {
|
||
var buf bytes.Buffer
|
||
if err := json.MarshalWrite(&buf, obj); err != nil {
|
||
errMsg := fmt.Errorf("failed to marshal JSON: %w", err)
|
||
c.AddError(errMsg)
|
||
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
|
||
return
|
||
}
|
||
|
||
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||
c.Writer.WriteHeader(code)
|
||
c.Writer.Write(buf.Bytes())
|
||
}
|
||
|
||
// GOB 向响应写入GOB数据
|
||
// 设置 Content-Type 为 application/octet-stream
|
||
func (c *Context) GOB(code int, obj any) {
|
||
c.Writer.Header().Set("Content-Type", "application/octet-stream") // 设置合适的 Content-Type
|
||
c.Writer.WriteHeader(code)
|
||
// GOB 编码
|
||
encoder := gob.NewEncoder(c.Writer)
|
||
if err := encoder.Encode(obj); err != nil {
|
||
c.AddError(fmt.Errorf("failed to encode GOB: %w", err))
|
||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to encode GOB: %w", err))
|
||
return
|
||
}
|
||
}
|
||
|
||
// GOBBuf 先将 GOB 编码到 buffer, 成功后再写入状态码和响应体.
|
||
func (c *Context) GOBBuf(code int, obj any) {
|
||
var buf bytes.Buffer
|
||
encoder := gob.NewEncoder(&buf)
|
||
if err := encoder.Encode(obj); err != nil {
|
||
errMsg := fmt.Errorf("failed to encode GOB: %w", err)
|
||
c.AddError(errMsg)
|
||
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
|
||
return
|
||
}
|
||
c.Writer.Header().Set("Content-Type", "application/octet-stream")
|
||
c.Writer.WriteHeader(code)
|
||
c.Writer.Write(buf.Bytes())
|
||
}
|
||
|
||
// WANF向响应写入WANF数据
|
||
// 设置 application/vnd.wjqserver.wanf; charset=utf-8
|
||
func (c *Context) WANF(code int, obj any) {
|
||
c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8")
|
||
c.Writer.WriteHeader(code)
|
||
// WANF 编码
|
||
encoder := wanf.NewStreamEncoder(c.Writer)
|
||
if err := encoder.Encode(obj); err != nil {
|
||
c.AddError(fmt.Errorf("failed to encode WANF: %w", err))
|
||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to encode WANF: %w", err))
|
||
return
|
||
}
|
||
}
|
||
|
||
// WANFBuf 先将 WANF 编码到 buffer, 成功后再写入状态码和响应体.
|
||
func (c *Context) WANFBuf(code int, obj any) {
|
||
var buf bytes.Buffer
|
||
encoder := wanf.NewStreamEncoder(&buf)
|
||
if err := encoder.Encode(obj); err != nil {
|
||
errMsg := fmt.Errorf("failed to encode WANF: %w", err)
|
||
c.AddError(errMsg)
|
||
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
|
||
return
|
||
}
|
||
c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8")
|
||
c.Writer.WriteHeader(code)
|
||
c.Writer.Write(buf.Bytes())
|
||
}
|
||
|
||
// HTML 渲染 HTML 模板
|
||
// 如果 Engine 配置了 HTMLRender,则使用它进行渲染
|
||
// 否则,会进行简单的字符串输出
|
||
// 预留接口,可以扩展为支持多种模板引擎
|
||
func (c *Context) HTML(code int, name string, obj any) {
|
||
c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||
c.Writer.WriteHeader(code)
|
||
|
||
if c.engine != nil && c.engine.HTMLRender != nil {
|
||
// 假设 HTMLRender 是一个 *template.Template 实例
|
||
if tpl, ok := c.engine.HTMLRender.(*template.Template); ok {
|
||
err := tpl.ExecuteTemplate(c.Writer, name, obj)
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to render HTML template '%s': %w", name, err))
|
||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to render HTML template '%s': %w", name, err))
|
||
}
|
||
return
|
||
}
|
||
// 可以扩展支持其他渲染器接口
|
||
}
|
||
// 默认简单输出,用于未配置 HTMLRender 的情况
|
||
c.Writer.Write(fmt.Appendf(nil, "<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj))
|
||
}
|
||
|
||
// HTMLBuf 先将 HTML 模板渲染到 buffer, 成功后再写入状态码和响应体.
|
||
// 如果模板渲染失败,则返回 500 错误且不写入任何内容.
|
||
func (c *Context) HTMLBuf(code int, name string, obj any) {
|
||
if c.engine == nil || c.engine.HTMLRender == nil {
|
||
// 没有渲染器,回退到简单输出
|
||
c.HTML(code, name, obj)
|
||
return
|
||
}
|
||
|
||
if tpl, ok := c.engine.HTMLRender.(*template.Template); ok {
|
||
var buf bytes.Buffer
|
||
err := tpl.ExecuteTemplate(&buf, name, obj)
|
||
if err != nil {
|
||
// 渲染失败,记录错误并返回 500,不写入任何内容
|
||
errMsg := fmt.Errorf("failed to render HTML template '%s': %w", name, err)
|
||
c.AddError(errMsg)
|
||
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
|
||
return
|
||
}
|
||
// 渲染成功,写入响应
|
||
c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||
c.Writer.WriteHeader(code)
|
||
c.Writer.Write(buf.Bytes())
|
||
return
|
||
}
|
||
|
||
// 不支持的渲染器类型,回退到简单输出
|
||
c.HTML(code, name, obj)
|
||
}
|
||
|
||
// Redirect 执行 HTTP 重定向
|
||
// code 应为 3xx 状态码 (如 http.StatusMovedPermanently, http.StatusFound)
|
||
func (c *Context) Redirect(code int, location string) {
|
||
http.Redirect(c.Writer, c.Request, location, code)
|
||
c.Abort()
|
||
if fl, ok := c.Writer.(http.Flusher); ok {
|
||
fl.Flush()
|
||
}
|
||
}
|
||
|
||
// ShouldBindJSON 尝试将请求体绑定到 JSON 对象
|
||
func (c *Context) ShouldBindJSON(obj any) error {
|
||
var body io.ReadCloser
|
||
if c.MaxRequestBodySize > 0 {
|
||
body = c.prepareRequestBody()
|
||
} else {
|
||
body = c.Request.Body
|
||
}
|
||
if body == nil {
|
||
return errors.New("request body is empty")
|
||
}
|
||
err := json.UnmarshalRead(body, obj)
|
||
if err != nil {
|
||
return fmt.Errorf("json binding error: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象
|
||
func (c *Context) ShouldBindWANF(obj any) error {
|
||
var body io.ReadCloser
|
||
if c.MaxRequestBodySize > 0 {
|
||
body = c.prepareRequestBody()
|
||
} else {
|
||
body = c.Request.Body
|
||
}
|
||
if body == nil {
|
||
return errors.New("request body is empty")
|
||
}
|
||
decoder, err := wanf.NewStreamDecoder(body)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create WANF decoder: %w", err)
|
||
}
|
||
|
||
if err := decoder.Decode(obj); err != nil {
|
||
return fmt.Errorf("WANF binding error: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象
|
||
func (c *Context) ShouldBindGOB(obj any) error {
|
||
var body io.ReadCloser
|
||
if c.MaxRequestBodySize > 0 {
|
||
body = c.prepareRequestBody()
|
||
} else {
|
||
body = c.Request.Body
|
||
}
|
||
if body == nil {
|
||
return errors.New("request body is empty")
|
||
}
|
||
decoder := gob.NewDecoder(body)
|
||
if err := decoder.Decode(obj); err != nil {
|
||
return fmt.Errorf("GOB binding error: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// bindForm 将 url.Values 绑定到结构体
|
||
// 支持 form tag 标签,如 `form:"field_name"`
|
||
func bindForm(values url.Values, obj any) error {
|
||
val := reflect.ValueOf(obj)
|
||
if val.Kind() != reflect.Pointer || val.Elem().Kind() != reflect.Struct {
|
||
return errors.New("obj must be a pointer to struct")
|
||
}
|
||
|
||
val = val.Elem()
|
||
typ := val.Type()
|
||
|
||
for i := 0; i < val.NumField(); i++ {
|
||
field := val.Field(i)
|
||
fieldType := typ.Field(i)
|
||
|
||
if !field.CanSet() {
|
||
continue
|
||
}
|
||
|
||
tag := fieldType.Tag.Get("form")
|
||
if tag == "" {
|
||
tag = fieldType.Name
|
||
}
|
||
if tag == "-" {
|
||
continue
|
||
}
|
||
|
||
formValues := values[tag]
|
||
if len(formValues) == 0 {
|
||
continue
|
||
}
|
||
|
||
if err := setFieldValue(field, formValues); err != nil {
|
||
return fmt.Errorf("field %s: %w", fieldType.Name, err)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// setFieldValue 将字符串值设置到反射值
|
||
func setFieldValue(field reflect.Value, values []string) error {
|
||
if !field.CanSet() {
|
||
return nil
|
||
}
|
||
|
||
value := values[0]
|
||
|
||
switch field.Kind() {
|
||
case reflect.String:
|
||
field.SetString(value)
|
||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||
if value == "" {
|
||
return nil
|
||
}
|
||
v, err := strconv.ParseInt(value, 10, 64)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
field.SetInt(v)
|
||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||
if value == "" {
|
||
return nil
|
||
}
|
||
v, err := strconv.ParseUint(value, 10, 64)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
field.SetUint(v)
|
||
case reflect.Float32, reflect.Float64:
|
||
if value == "" {
|
||
return nil
|
||
}
|
||
v, err := strconv.ParseFloat(value, 64)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
field.SetFloat(v)
|
||
case reflect.Bool:
|
||
if value == "" {
|
||
return nil
|
||
}
|
||
v, err := strconv.ParseBool(value)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
field.SetBool(v)
|
||
case reflect.Pointer:
|
||
if field.IsNil() {
|
||
field.Set(reflect.New(field.Type().Elem()))
|
||
}
|
||
return setFieldValue(field.Elem(), values)
|
||
case reflect.Slice:
|
||
slice := reflect.MakeSlice(field.Type(), len(values), len(values))
|
||
elemType := field.Type().Elem()
|
||
for i, v := range values {
|
||
if err := setFieldValue(slice.Index(i), []string{v}); err != nil {
|
||
return err
|
||
}
|
||
_ = elemType
|
||
}
|
||
field.Set(slice)
|
||
default:
|
||
return fmt.Errorf("unsupported type: %s", field.Kind())
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ShouldBindForm 尝试将表单数据绑定到结构体
|
||
// 支持 application/x-www-form-urlencoded 和 multipart/form-data
|
||
func (c *Context) ShouldBindForm(obj any) error {
|
||
if c.MaxRequestBodySize > 0 {
|
||
c.prepareRequestBody()
|
||
}
|
||
|
||
contentType := c.Request.Header.Get("Content-Type")
|
||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||
if err != nil {
|
||
return fmt.Errorf("invalid content type: %w", err)
|
||
}
|
||
|
||
switch mediaType {
|
||
case "multipart/form-data":
|
||
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||
return fmt.Errorf("parse multipart form error: %w", err)
|
||
}
|
||
case "application/x-www-form-urlencoded":
|
||
if err := c.Request.ParseForm(); err != nil {
|
||
return fmt.Errorf("parse form error: %w", err)
|
||
}
|
||
default:
|
||
return fmt.Errorf("unsupported form content type: %s", mediaType)
|
||
}
|
||
|
||
if err := bindForm(c.Request.Form, obj); err != nil {
|
||
return fmt.Errorf("form binding error: %w", err)
|
||
}
|
||
c.formCache = c.Request.PostForm
|
||
return nil
|
||
}
|
||
|
||
// ShouldBind 尝试根据 Content-Type 将请求体绑定到结构体
|
||
// 支持的类型:application/json, application/x-www-form-urlencoded, multipart/form-data, application/wanf, application/vnd.wjqserver.wanf, application/gob
|
||
func (c *Context) ShouldBind(obj any) error {
|
||
contentType := c.Request.Header.Get("Content-Type")
|
||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||
if err != nil {
|
||
return fmt.Errorf("invalid content type: %w", err)
|
||
}
|
||
|
||
switch mediaType {
|
||
case "application/json":
|
||
return c.ShouldBindJSON(obj)
|
||
case "application/x-www-form-urlencoded", "multipart/form-data":
|
||
return c.ShouldBindForm(obj)
|
||
case "application/wanf", "application/vnd.wjqserver.wanf":
|
||
return c.ShouldBindWANF(obj)
|
||
case "application/gob":
|
||
return c.ShouldBindGOB(obj)
|
||
default:
|
||
return fmt.Errorf("unsupported content type: %s", mediaType)
|
||
}
|
||
}
|
||
|
||
// AddError 添加一个错误到 Context
|
||
// 允许在处理请求过程中收集多个错误
|
||
func (c *Context) AddError(err error) {
|
||
c.Errors = append(c.Errors, err)
|
||
}
|
||
|
||
// Errors 返回 Context 中收集的所有错误
|
||
func (c *Context) GetErrors() []error {
|
||
return c.Errors
|
||
}
|
||
|
||
// Client 返回 Engine 提供的 HTTPClient
|
||
// 方便在请求处理函数中进行出站 HTTP 请求
|
||
func (c *Context) Client() *httpc.Client {
|
||
return c.HTTPClient
|
||
}
|
||
|
||
// Context() 返回请求的上下文,用于取消操作
|
||
// 这是 Go 标准库的 `context.Context`,用于请求的取消和超时管理
|
||
func (c *Context) Context() context.Context {
|
||
return c.ctx
|
||
}
|
||
|
||
// Done returns a channel that is closed when the request context is cancelled or times out.
|
||
// 继承自 `context.Context`
|
||
func (c *Context) Done() <-chan struct{} {
|
||
return c.ctx.Done()
|
||
}
|
||
|
||
// Err returns the error, if any, that caused the context to be canceled or to
|
||
// time out.
|
||
// 继承自 `context.Context`
|
||
func (c *Context) Err() error {
|
||
return c.ctx.Err()
|
||
}
|
||
|
||
// Value returns the value associated with this context for key, or nil if no
|
||
// value is associated with key.
|
||
// 可以用于从 Context 中获取与特定键关联的值,包括 Go 原生 Context 的值和 Touka Context 的 Keys
|
||
func (c *Context) Value(key any) any {
|
||
if keyAsString, ok := key.(string); ok {
|
||
if val, exists := c.Get(keyAsString); exists {
|
||
return val
|
||
}
|
||
}
|
||
return c.ctx.Value(key) // 尝试从 Go 原生 Context 中获取值
|
||
}
|
||
|
||
// GetWriter 获得一个 io.Writer 接口,可以直接向响应体写入数据
|
||
// 这对于需要自定义流式写入或与其他需要 io.Writer 的库集成非常有用
|
||
func (c *Context) GetWriter() io.Writer {
|
||
return c.Writer // ResponseWriter 接口嵌入了 http.ResponseWriter,而 http.ResponseWriter 实现了 io.Writer
|
||
}
|
||
|
||
// WriteStream 接受一个 io.Reader 并将其内容流式传输到响应体
|
||
// 返回写入的字节数和可能遇到的错误
|
||
// 该方法在开始写入之前,会确保设置 HTTP 状态码为 200 OK
|
||
func (c *Context) WriteStream(reader io.Reader) (written int64, err error) {
|
||
// 确保在写入数据前设置状态码
|
||
// WriteHeader 会在第一次写入时被 Write 方法隐式调用,但显式调用可以确保状态码的预期
|
||
if !c.Writer.Written() {
|
||
c.Writer.WriteHeader(http.StatusOK) // 默认 200 OK
|
||
}
|
||
|
||
written, err = iox.Copy(c.Writer, reader) // 从 reader 读取并写入 ResponseWriter
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to write stream: %w", err))
|
||
}
|
||
return written, err
|
||
}
|
||
|
||
// GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体
|
||
// 注意:请求体只能读取一次
|
||
func (c *Context) GetReqBody() io.ReadCloser {
|
||
if c.MaxRequestBodySize > 0 {
|
||
return c.prepareRequestBody()
|
||
}
|
||
if c.Request == nil || c.Request.Body == nil {
|
||
return nil
|
||
}
|
||
return c.Request.Body
|
||
}
|
||
|
||
// GetReqBodyFull 读取并返回请求体的所有内容
|
||
// 注意:请求体只能读取一次
|
||
func (c *Context) GetReqBodyFull() ([]byte, error) {
|
||
body := c.GetReqBody()
|
||
if body == nil {
|
||
return nil, nil
|
||
}
|
||
defer func() {
|
||
err := body.Close()
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||
}
|
||
}()
|
||
|
||
data, err := iox.ReadAll(body)
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||
}
|
||
return data, nil
|
||
}
|
||
|
||
// 类似 GetReqBodyFull, 返回 *bytes.Buffer
|
||
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
|
||
body := c.GetReqBody()
|
||
if body == nil {
|
||
return nil, nil
|
||
}
|
||
defer func() {
|
||
err := body.Close()
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||
}
|
||
}()
|
||
|
||
data, err := iox.ReadAll(body)
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||
}
|
||
return bytes.NewBuffer(data), nil
|
||
}
|
||
|
||
// RequestIP 返回客户端的 IP 地址
|
||
// 它会根据 Engine 的配置 (ForwardByClientIP) 尝试从 X-Forwarded-For 或 X-Real-IP 等头部获取,
|
||
// 否则回退到 Request.RemoteAddr
|
||
func (c *Context) RequestIP() string {
|
||
if c.engine.ForwardByClientIP {
|
||
for _, headerName := range c.engine.RemoteIPHeaders {
|
||
ipValue := c.Request.Header.Get(headerName)
|
||
if ipValue == "" {
|
||
continue // 头部为空, 继续检查下一个
|
||
}
|
||
|
||
// 使用索引高效遍历逗号分隔的 IP 列表, 避免 strings.Split 的内存分配
|
||
currentPos := 0
|
||
for currentPos < len(ipValue) {
|
||
nextComma := strings.IndexByte(ipValue[currentPos:], ',')
|
||
|
||
var ipSegment string
|
||
if nextComma == -1 {
|
||
// 这是列表中的最后一个 IP
|
||
ipSegment = ipValue[currentPos:]
|
||
currentPos = len(ipValue) // 结束循环
|
||
} else {
|
||
// 截取当前 IP 段
|
||
ipSegment = ipValue[currentPos : currentPos+nextComma]
|
||
currentPos += nextComma + 1 // 移动到下一个 IP 段的起始位置
|
||
}
|
||
|
||
// 去除空格并检查是否为空 (例如 "ip1,,ip2")
|
||
trimmedIP := strings.TrimSpace(ipSegment)
|
||
if trimmedIP == "" {
|
||
continue
|
||
}
|
||
|
||
// 使用 netip.ParseAddr 进行 IP 地址的解析和验证
|
||
addr, err := netip.ParseAddr(trimmedIP)
|
||
if err == nil {
|
||
// 成功解析到合法的 IP, 立即返回
|
||
return addr.String()
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 回退到 Request.RemoteAddr 的处理
|
||
// 优先使用 netip.ParseAddrPort, 它比 net.SplitHostPort 更高效且分配更少
|
||
addrp, err := netip.ParseAddrPort(c.Request.RemoteAddr)
|
||
if err == nil {
|
||
// 成功从 "ip:port" 格式中解析出 IP
|
||
return addrp.Addr().String()
|
||
}
|
||
|
||
// 如果上面的解析失败 (例如 RemoteAddr 只有 IP, 没有端口),
|
||
// 则尝试将整个字符串作为 IP 地址进行解析
|
||
addr, err := netip.ParseAddr(c.Request.RemoteAddr)
|
||
if err == nil {
|
||
return addr.String()
|
||
}
|
||
|
||
// 所有方法都失败, 返回空字符串
|
||
return ""
|
||
}
|
||
|
||
// ClientIP 返回客户端的 IP 地址
|
||
// 这是一个别名,与 RequestIP 功能相同
|
||
func (c *Context) ClientIP() string {
|
||
return c.RequestIP()
|
||
}
|
||
|
||
// ContentType 返回请求的 Content-Type 头部
|
||
func (c *Context) ContentType() string {
|
||
return c.GetReqHeader("Content-Type")
|
||
}
|
||
|
||
// UserAgent 返回请求的 User-Agent 头部
|
||
func (c *Context) UserAgent() string {
|
||
return c.GetReqHeader("User-Agent")
|
||
}
|
||
|
||
// Status 设置响应状态码
|
||
func (c *Context) Status(code int) {
|
||
c.Writer.WriteHeader(code)
|
||
}
|
||
|
||
// File 将指定路径的文件作为响应发送
|
||
// 它会设置 Content-Type 和 Content-Disposition 头部
|
||
func (c *Context) File(filepath string) {
|
||
http.ServeFile(c.Writer, c.Request, filepath)
|
||
c.Abort() // 发送文件后中止后续处理
|
||
}
|
||
|
||
// SetHeader 设置响应头部
|
||
func (c *Context) SetHeader(key, value string) {
|
||
c.Writer.Header().Set(key, value)
|
||
}
|
||
|
||
// AddHeader 添加响应头部
|
||
func (c *Context) AddHeader(key, value string) {
|
||
c.Writer.Header().Add(key, value)
|
||
}
|
||
|
||
// Header 作为SetHeader的别名
|
||
func (c *Context) Header(key, value string) {
|
||
c.SetHeader(key, value)
|
||
}
|
||
|
||
// DelHeader 删除响应头部
|
||
func (c *Context) DelHeader(key string) {
|
||
c.Writer.Header().Del(key)
|
||
}
|
||
|
||
// GetReqHeader 获取请求头部的值
|
||
func (c *Context) GetReqHeader(key string) string {
|
||
return c.Request.Header.Get(key)
|
||
}
|
||
|
||
// SetHeaders 接受headers列表
|
||
func (c *Context) SetHeaders(headers map[string][]string) {
|
||
for key, values := range headers {
|
||
for _, value := range values {
|
||
c.Writer.Header().Add(key, value)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 获取所有resp Headers
|
||
func (c *Context) GetAllRespHeader() http.Header {
|
||
return c.Writer.Header()
|
||
}
|
||
|
||
// GetAllReqHeader 获取所有请求头部
|
||
func (c *Context) GetAllReqHeader() http.Header {
|
||
return c.Request.Header
|
||
}
|
||
|
||
// 使用定义的errorHandle来处理error并结束当前handle
|
||
func (c *Context) ErrorUseHandle(code int, err error) {
|
||
if c.engine != nil && c.engine.errorHandle.handler != nil {
|
||
c.engine.errorHandle.handler(c, code, err)
|
||
c.Abort()
|
||
return
|
||
} else {
|
||
c.String(code, "%s", http.StatusText(code))
|
||
c.Abort()
|
||
}
|
||
}
|
||
|
||
// GetProtocol 获取当前连接版本
|
||
func (c *Context) GetProtocol() string {
|
||
return c.Request.Proto
|
||
}
|
||
|
||
// GetHTTPC 获取框架自带传递的httpc
|
||
func (c *Context) GetHTTPC() *httpc.Client {
|
||
return c.HTTPClient
|
||
}
|
||
|
||
// GetLogger 获取engine的Logger
|
||
func (c *Context) GetLogger() *reco.Logger {
|
||
return c.engine.LogReco
|
||
}
|
||
|
||
// GetReqQueryString
|
||
// GetReqQueryString 返回请求的原始查询字符串
|
||
func (c *Context) GetReqQueryString() string {
|
||
return c.Request.URL.RawQuery
|
||
}
|
||
|
||
// SetBodyStream 设置响应体为一个 io.Reader,并指定内容长度
|
||
// 如果 contentSize 为 -1,则表示内容长度未知,将使用 Transfer-Encoding: chunked
|
||
func (c *Context) SetBodyStream(reader io.Reader, contentSize int) {
|
||
// 如果指定了内容长度且大于等于 0,则设置 Content-Length 头部
|
||
if contentSize >= 0 {
|
||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", contentSize))
|
||
} else {
|
||
// 如果内容长度未知,移除 Content-Length 头部,通常会使用 Transfer-Encoding: chunked
|
||
c.Writer.Header().Del("Content-Length")
|
||
}
|
||
|
||
// 确保在写入数据前设置状态码
|
||
if !c.Writer.Written() {
|
||
c.Writer.WriteHeader(http.StatusOK) // 默认 200 OK
|
||
}
|
||
|
||
// 将 reader 的内容直接复制到 ResponseWriter
|
||
// ResponseWriter 实现了 io.Writer 接口
|
||
_, err := iox.Copy(c.Writer, reader)
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to write stream: %w", err))
|
||
// 注意:这里可能无法设置错误状态码,因为头部可能已经发送
|
||
// 可以在调用 SetBodyStream 之前检查错误,或者在中间件中处理 Context.Errors
|
||
}
|
||
}
|
||
|
||
// GetRequestURI 返回请求的原始 URI
|
||
func (c *Context) GetRequestURI() string {
|
||
return c.Request.RequestURI
|
||
}
|
||
|
||
// GetRequestURIPath 返回请求的原始 URI 的路径部分
|
||
func (c *Context) GetRequestURIPath() string {
|
||
return c.Request.URL.Path
|
||
}
|
||
|
||
// === 文件操作 ===
|
||
|
||
// 将文件内容作为响应body
|
||
func (c *Context) SetRespBodyFile(code int, filePath string) {
|
||
// 清理path
|
||
cleanPath := filepath.Clean(filePath)
|
||
|
||
// 打开文件
|
||
file, err := os.Open(cleanPath)
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to open file %s: %w", cleanPath, err))
|
||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to open file: %w", err))
|
||
return
|
||
}
|
||
defer file.Close()
|
||
|
||
// 获取文件信息以获取文件大小和MIME类型
|
||
fileInfo, err := file.Stat()
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to get file info for %s: %w", cleanPath, err))
|
||
c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to get file info: %w", err))
|
||
return
|
||
}
|
||
|
||
// 尝试根据文件扩展名猜测 Content-Type
|
||
contentType := mime.TypeByExtension(filepath.Ext(cleanPath))
|
||
if contentType == "" {
|
||
// 如果无法猜测,则使用默认的二进制流类型
|
||
contentType = "application/octet-stream"
|
||
}
|
||
|
||
// 设置响应头
|
||
c.Writer.Header().Set("Content-Type", contentType)
|
||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size()))
|
||
// 还可以设置 Content-Disposition 来控制浏览器是下载还是直接显示
|
||
// c.Writer.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, path.Base(cleanPath)))
|
||
|
||
// 设置状态码
|
||
c.Writer.WriteHeader(code)
|
||
|
||
// 将文件内容写入响应体
|
||
_, err = iox.Copy(c.Writer, file)
|
||
if err != nil {
|
||
c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err))
|
||
// 注意:这里可能无法设置错误状态码,因为头部可能已经发送
|
||
// 可以在调用 SetRespBodyFile 之前检查错误,或者在中间件中处理 Context.Errors
|
||
}
|
||
c.Abort() // 文件发送后中止后续处理
|
||
}
|
||
|
||
// == cookie ===
|
||
|
||
// SetSameSite 设置响应的 SameSite cookie 属性
|
||
func (c *Context) SetSameSite(samesite http.SameSite) {
|
||
c.sameSite = samesite
|
||
}
|
||
|
||
// SetCookie 设置一个 HTTP cookie
|
||
// sameSite 参数是可选的,如果不提供则使用通过 SetSameSite 设置的值
|
||
func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool, sameSite ...http.SameSite) {
|
||
if path == "" {
|
||
path = "/"
|
||
}
|
||
site := c.sameSite
|
||
if len(sameSite) > 0 {
|
||
if len(sameSite) > 1 {
|
||
c.Warnf("SetCookie: only the first SameSite value will be used, got %d values", len(sameSite))
|
||
}
|
||
site = sameSite[0]
|
||
}
|
||
http.SetCookie(c.Writer, &http.Cookie{
|
||
Name: name,
|
||
Value: url.QueryEscape(value),
|
||
MaxAge: maxAge,
|
||
Path: path,
|
||
Domain: domain,
|
||
SameSite: site,
|
||
Secure: secure,
|
||
HttpOnly: httpOnly,
|
||
})
|
||
}
|
||
|
||
func (c *Context) SetCookieData(cookie *http.Cookie) {
|
||
if cookie.Path == "" {
|
||
cookie.Path = "/"
|
||
}
|
||
if cookie.SameSite == http.SameSiteDefaultMode {
|
||
cookie.SameSite = c.sameSite
|
||
}
|
||
http.SetCookie(c.Writer, cookie)
|
||
}
|
||
|
||
// GetCookie 获取指定名称的 cookie 值
|
||
func (c *Context) GetCookie(name string) (string, error) {
|
||
cookie, err := c.Request.Cookie(name)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
// 对 cookie 值进行 URL 解码
|
||
value, err := url.QueryUnescape(cookie.Value)
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to unescape cookie value: %w", err)
|
||
}
|
||
return value, nil
|
||
}
|
||
|
||
// DeleteCookie 删除指定名称的 cookie
|
||
// 通过设置 MaxAge 为 -1 来删除 cookie
|
||
func (c *Context) DeleteCookie(name string) {
|
||
c.SetCookie(name, "", -1, "/", "", false, false) // 设置 MaxAge 为 -1 删除 cookie
|
||
}
|
||
|
||
// === 日志记录 ===
|
||
func (c *Context) Debugf(format string, args ...any) {
|
||
c.engine.LogReco.Debugf(format, args...)
|
||
}
|
||
|
||
func (c *Context) Infof(format string, args ...any) {
|
||
c.engine.LogReco.Infof(format, args...)
|
||
}
|
||
|
||
func (c *Context) Warnf(format string, args ...any) {
|
||
c.engine.LogReco.Warnf(format, args...)
|
||
}
|
||
|
||
func (c *Context) Errorf(format string, args ...any) {
|
||
c.engine.LogReco.Errorf(format, args...)
|
||
}
|
||
|
||
func (c *Context) Fatalf(format string, args ...any) {
|
||
c.engine.LogReco.Fatalf(format, args...)
|
||
}
|
||
|
||
func (c *Context) Panicf(format string, args ...any) {
|
||
c.engine.LogReco.Panicf(format, args...)
|
||
}
|