mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
fix(maxreader): avoid hangs after reaching body limit
This commit is contained in:
parent
85cc9b5cf6
commit
91c50536c4
2 changed files with 83 additions and 32 deletions
|
|
@ -11,6 +11,32 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
type zeroNilThenEOFReader struct {
|
||||
readCalls int
|
||||
}
|
||||
|
||||
func (r *zeroNilThenEOFReader) Read(_ []byte) (int, error) {
|
||||
r.readCalls++
|
||||
if r.readCalls == 1 {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (r *zeroNilThenEOFReader) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type zeroNilForeverReader struct{}
|
||||
|
||||
func (r *zeroNilForeverReader) Read(_ []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *zeroNilForeverReader) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestFileTextUsesProvidedStatusCode(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -63,6 +89,42 @@ func TestMaxBytesReaderRejectsOverLimit(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestMaxBytesReaderAllowsZeroNilThenEOFAtExactLimit(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
reader := NewMaxBytesReader(&zeroNilThenEOFReader{}, 1)
|
||||
defer reader.Close()
|
||||
|
||||
buf := make([]byte, 1)
|
||||
n, err := reader.Read(buf)
|
||||
if n != 0 || err != nil {
|
||||
t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err)
|
||||
}
|
||||
|
||||
n, err = reader.Read(buf)
|
||||
if n != 0 || !errors.Is(err, io.EOF) {
|
||||
t.Fatalf("expected EOF after retry, got n=%d err=%v", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxBytesReaderRejectsOverLimitWithoutProbeLoop(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
reader := NewMaxBytesReader(&zeroNilForeverReader{}, 0)
|
||||
defer reader.Close()
|
||||
|
||||
buf := make([]byte, 1)
|
||||
n, err := reader.Read(buf)
|
||||
if n != 0 || err != nil {
|
||||
t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err)
|
||||
}
|
||||
|
||||
n, err = reader.Read(buf)
|
||||
if n != 0 || !errors.Is(err, ErrBodyTooLarge) {
|
||||
t.Fatalf("expected ErrBodyTooLarge after repeated zero,nil reads, got n=%d err=%v", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
|
|||
53
maxreader.go
53
maxreader.go
|
|
@ -23,6 +23,8 @@ type maxBytesReader struct {
|
|||
n int64
|
||||
// read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数.
|
||||
read atomic.Int64
|
||||
// emptyAtLimit 记录在达到上限后是否已经遇到过一次 0,nil 读.
|
||||
emptyAtLimit atomic.Bool
|
||||
}
|
||||
|
||||
// NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据,
|
||||
|
|
@ -52,14 +54,11 @@ func (mbr *maxBytesReader) Read(p []byte) (int, error) {
|
|||
|
||||
// 在函数开始时只加载一次原子变量, 减少后续的原子操作开销.
|
||||
readSoFar := mbr.read.Load()
|
||||
|
||||
if readSoFar > mbr.n {
|
||||
remaining := mbr.n - readSoFar
|
||||
if remaining < 0 {
|
||||
return 0, ErrBodyTooLarge
|
||||
}
|
||||
|
||||
// 当已恰好读满限制时, 需要探测底层是否还有额外数据.
|
||||
// 如果下一次读取立即 EOF, 说明请求体大小恰好等于限制, 属于合法情况.
|
||||
if readSoFar == mbr.n {
|
||||
if remaining == 0 {
|
||||
var probe [1]byte
|
||||
n, err := mbr.r.Read(probe[:])
|
||||
if n > 0 {
|
||||
|
|
@ -69,43 +68,33 @@ func (mbr *maxBytesReader) Read(p []byte) (int, error) {
|
|||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return 0, ErrBodyTooLarge
|
||||
if mbr.emptyAtLimit.Swap(true) {
|
||||
return 0, ErrBodyTooLarge
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
mbr.emptyAtLimit.Store(false)
|
||||
|
||||
// 计算当前还可以读取多少字节.
|
||||
remaining := mbr.n - readSoFar
|
||||
|
||||
// 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度.
|
||||
// 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数.
|
||||
if int64(len(p)) > remaining {
|
||||
p = p[:remaining]
|
||||
// 最多多读一个字节, 以区分“恰好到上限”和“已经超限”。
|
||||
if int64(len(p))-1 > remaining {
|
||||
p = p[:remaining+1]
|
||||
}
|
||||
|
||||
// 从底层 Reader 读取数据.
|
||||
n, err := mbr.r.Read(p)
|
||||
|
||||
// 如果实际读取到了数据, 更新原子计数器.
|
||||
if n > 0 {
|
||||
readSoFar = mbr.read.Add(int64(n))
|
||||
}
|
||||
|
||||
// 如果底层 Read 返回错误 (例如 io.EOF).
|
||||
if err != nil {
|
||||
// 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束.
|
||||
// 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了.
|
||||
if int64(n) <= remaining {
|
||||
if n > 0 {
|
||||
mbr.read.Add(int64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误.
|
||||
// 这是处理"跨越"限制情况的关键.
|
||||
if readSoFar > mbr.n {
|
||||
// 返回实际读取的字节数 n, 并附上超限错误.
|
||||
// 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭.
|
||||
return n, ErrBodyTooLarge
|
||||
// 读取结果跨过了限制,只向上层暴露允许的部分。
|
||||
if remaining > 0 {
|
||||
mbr.read.Add(remaining)
|
||||
}
|
||||
|
||||
// 一切正常, 返回读取的字节数和 nil 错误.
|
||||
return n, nil
|
||||
return int(remaining), ErrBodyTooLarge
|
||||
}
|
||||
|
||||
// Close 方法关闭底层的 ReadCloser, 保证资源释放.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue