diff --git a/context.go b/context.go index 0620c41..c9586dd 100644 --- a/context.go +++ b/context.go @@ -9,10 +9,13 @@ import ( "html/template" "io" "math" + "mime" "net" "net/http" "net/netip" "net/url" + "os" + "path" "strings" "sync" "time" @@ -649,6 +652,56 @@ func (c *Context) GetRequestURIPath() string { return c.Request.URL.Path } +// === 文件操作 === + +// 将文件内容作为响应body +func (c *Context) SetRespBodyFile(code int, filePath string) { + // 清理path + cleanPath := path.Clean(filePath) + + // 打开文件 + file, err := os.Open(cleanPath) + if err != nil { + c.AddError(fmt.Errorf("failed to open file %s: %w", cleanPath, err)) + c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to open file: %w", err)) + return + } + defer file.Close() + + // 获取文件信息以获取文件大小和MIME类型 + fileInfo, err := file.Stat() + if err != nil { + c.AddError(fmt.Errorf("failed to get file info for %s: %w", cleanPath, err)) + c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to get file info: %w", err)) + return + } + + // 尝试根据文件扩展名猜测 Content-Type + contentType := mime.TypeByExtension(path.Ext(cleanPath)) + if contentType == "" { + // 如果无法猜测,则使用默认的二进制流类型 + contentType = "application/octet-stream" + } + + // 设置响应头 + c.Writer.Header().Set("Content-Type", contentType) + c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size())) + // 还可以设置 Content-Disposition 来控制浏览器是下载还是直接显示 + // c.Writer.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, path.Base(cleanPath))) + + // 设置状态码 + c.Writer.WriteHeader(code) + + // 将文件内容写入响应体 + _, err = copyb.Copy(c.Writer, file) + if err != nil { + c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err)) + // 注意:这里可能无法设置错误状态码,因为头部可能已经发送 + // 可以在调用 SetRespBodyFile 之前检查错误,或者在中间件中处理 Context.Errors + } + c.Abort() // 文件发送后中止后续处理 +} + // == cookie === // SetSameSite 设置响应的 SameSite cookie 属性 diff --git a/engine.go b/engine.go index fb18274..babe365 100644 --- a/engine.go +++ b/engine.go @@ -121,6 +121,21 @@ func defaultErrorWarp(handler ErrorHandler) ErrorHandler { return } } + // 查看context内有没有收集到error + if len(c.Errors) > 0 { + c.Errorf("errpage: context errors: %v, current error: %v", errors.Join(c.Errors...), err) + if err == nil { + err = errors.Join(c.Errors...) + } + } + // 如果客户端已经断开连接,则不尝试写入响应 + // 避免在客户端已关闭连接后写入响应导致的问题 + // 检查 context.Context 是否已取消 + if errors.Is(c.Request.Context().Err(), context.Canceled) { + log.Printf("errpage: client disconnected, skipping error page rendering for status %d, err: %v", code, err) + return + } + handler(c, code, err) } }