From 91c50536c49c0a0eee61ca544cfd615695072e5c Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 31 Mar 2026 23:37:02 +0800 Subject: [PATCH] fix(maxreader): avoid hangs after reaching body limit --- context_bodylimit_test.go | 62 +++++++++++++++++++++++++++++++++++++++ maxreader.go | 53 +++++++++++++-------------------- 2 files changed, 83 insertions(+), 32 deletions(-) diff --git a/context_bodylimit_test.go b/context_bodylimit_test.go index 546f06e..37f5e46 100644 --- a/context_bodylimit_test.go +++ b/context_bodylimit_test.go @@ -11,6 +11,32 @@ import ( "testing" ) +type zeroNilThenEOFReader struct { + readCalls int +} + +func (r *zeroNilThenEOFReader) Read(_ []byte) (int, error) { + r.readCalls++ + if r.readCalls == 1 { + return 0, nil + } + return 0, io.EOF +} + +func (r *zeroNilThenEOFReader) Close() error { + return nil +} + +type zeroNilForeverReader struct{} + +func (r *zeroNilForeverReader) Read(_ []byte) (int, error) { + return 0, nil +} + +func (r *zeroNilForeverReader) Close() error { + return nil +} + func TestFileTextUsesProvidedStatusCode(t *testing.T) { t.Helper() @@ -63,6 +89,42 @@ func TestMaxBytesReaderRejectsOverLimit(t *testing.T) { } } +func TestMaxBytesReaderAllowsZeroNilThenEOFAtExactLimit(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(&zeroNilThenEOFReader{}, 1) + defer reader.Close() + + buf := make([]byte, 1) + n, err := reader.Read(buf) + if n != 0 || err != nil { + t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err) + } + + n, err = reader.Read(buf) + if n != 0 || !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF after retry, got n=%d err=%v", n, err) + } +} + +func TestMaxBytesReaderRejectsOverLimitWithoutProbeLoop(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(&zeroNilForeverReader{}, 0) + defer reader.Close() + + buf := make([]byte, 1) + n, err := reader.Read(buf) + if n != 0 || err != nil { + t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err) + } + + n, err = reader.Read(buf) + if n != 0 || !errors.Is(err, ErrBodyTooLarge) { + t.Fatalf("expected ErrBodyTooLarge after repeated zero,nil reads, got n=%d err=%v", n, err) + } +} + func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) { t.Helper() diff --git a/maxreader.go b/maxreader.go index 96e54c7..8191853 100644 --- a/maxreader.go +++ b/maxreader.go @@ -23,6 +23,8 @@ type maxBytesReader struct { n int64 // read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数. read atomic.Int64 + // emptyAtLimit 记录在达到上限后是否已经遇到过一次 0,nil 读. + emptyAtLimit atomic.Bool } // NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据, @@ -52,14 +54,11 @@ func (mbr *maxBytesReader) Read(p []byte) (int, error) { // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. readSoFar := mbr.read.Load() - - if readSoFar > mbr.n { + remaining := mbr.n - readSoFar + if remaining < 0 { return 0, ErrBodyTooLarge } - - // 当已恰好读满限制时, 需要探测底层是否还有额外数据. - // 如果下一次读取立即 EOF, 说明请求体大小恰好等于限制, 属于合法情况. - if readSoFar == mbr.n { + if remaining == 0 { var probe [1]byte n, err := mbr.r.Read(probe[:]) if n > 0 { @@ -69,43 +68,33 @@ func (mbr *maxBytesReader) Read(p []byte) (int, error) { if err != nil { return 0, err } - return 0, ErrBodyTooLarge + if mbr.emptyAtLimit.Swap(true) { + return 0, ErrBodyTooLarge + } + return 0, nil } + mbr.emptyAtLimit.Store(false) - // 计算当前还可以读取多少字节. - remaining := mbr.n - readSoFar - - // 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度. - // 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数. - if int64(len(p)) > remaining { - p = p[:remaining] + // 最多多读一个字节, 以区分“恰好到上限”和“已经超限”。 + if int64(len(p))-1 > remaining { + p = p[:remaining+1] } // 从底层 Reader 读取数据. n, err := mbr.r.Read(p) - // 如果实际读取到了数据, 更新原子计数器. - if n > 0 { - readSoFar = mbr.read.Add(int64(n)) - } - - // 如果底层 Read 返回错误 (例如 io.EOF). - if err != nil { - // 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束. - // 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了. + if int64(n) <= remaining { + if n > 0 { + mbr.read.Add(int64(n)) + } return n, err } - // 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误. - // 这是处理"跨越"限制情况的关键. - if readSoFar > mbr.n { - // 返回实际读取的字节数 n, 并附上超限错误. - // 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭. - return n, ErrBodyTooLarge + // 读取结果跨过了限制,只向上层暴露允许的部分。 + if remaining > 0 { + mbr.read.Add(remaining) } - - // 一切正常, 返回读取的字节数和 nil 错误. - return n, nil + return int(remaining), ErrBodyTooLarge } // Close 方法关闭底层的 ReadCloser, 保证资源释放.