feat(http2): support OPTIONS * and extended CONNECT

This commit is contained in:
wjqserver 2026-04-02 03:53:17 +08:00
parent ed44c592d3
commit 2165cc4114
8 changed files with 316 additions and 12 deletions

View file

@ -3,6 +3,7 @@ package touka
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
@ -15,6 +16,8 @@ import (
"strings"
"testing"
"time"
"golang.org/x/net/http2"
)
func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
@ -680,7 +683,7 @@ func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) {
}
}
func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) {
func TestEngineHandlesOptionsAsteriskLocally(t *testing.T) {
t.Helper()
engine := New()
@ -695,9 +698,12 @@ func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) {
rr := httptest.NewRecorder()
engine.ServeHTTP(rr, req)
if rr.Code != http.StatusNotFound {
if rr.Code != http.StatusOK {
t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code)
}
if got := rr.Header().Get("Content-Length"); got != "0" {
t.Fatalf("unexpected Content-Length header: %q", got)
}
}
func TestReverseProxyConnectTunnel(t *testing.T) {
@ -848,6 +854,119 @@ func TestReverseProxyConnectNeedsHijacker(t *testing.T) {
}
}
func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
t.Helper()
enableHTTP2ExtendedConnectProtocol()
errCh := make(chan error, 4)
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if r.ProtoMajor != 2 {
errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto)
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.Header.Get(":protocol"); got != "websocket" {
errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got)
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.URL.Path; got != "/ws" {
errCh <- fmt.Errorf("unexpected upstream path: %q", got)
w.WriteHeader(http.StatusBadRequest)
return
}
controller := http.NewResponseController(w)
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
return
}
w.WriteHeader(http.StatusOK)
_ = controller.Flush()
line, err := bufio.NewReader(r.Body).ReadString('\n')
if err != nil {
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
return
}
if _, err := io.WriteString(w, "echo:"+line); err != nil {
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
return
}
_ = controller.Flush()
}))
upstream.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
t.Fatalf("configure upstream HTTP/2 server: %v", err)
}
upstream.StartTLS()
defer upstream.Close()
engine := New()
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, upstream.URL),
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
Via: "proxy.test",
}))
proxy := httptest.NewUnstartedServer(engine)
proxy.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
t.Fatalf("configure proxy HTTP/2 server: %v", err)
}
proxy.StartTLS()
defer proxy.Close()
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
defer transport.CloseIdleConnections()
pr, pw := io.Pipe()
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
if err != nil {
t.Fatalf("new CONNECT request: %v", err)
}
req.Header.Set(":protocol", "websocket")
resp, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("round trip extended CONNECT: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" {
t.Fatalf("unexpected Via response header: %#v", gotVia)
}
if _, err := io.WriteString(pw, "ping\n"); err != nil {
t.Fatalf("write tunneled request body: %v", err)
}
message, err := bufio.NewReader(resp.Body).ReadString('\n')
if err != nil {
t.Fatalf("read tunneled response body: %v", err)
}
if message != "echo:ping\n" {
t.Fatalf("unexpected tunneled response body: %q", message)
}
if err := pw.Close(); err != nil {
t.Fatalf("close tunneled request body: %v", err)
}
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}
func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) {
t.Helper()