mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
feat(http2): support OPTIONS * and extended CONNECT
This commit is contained in:
parent
ed44c592d3
commit
2165cc4114
8 changed files with 316 additions and 12 deletions
|
|
@ -3,6 +3,7 @@ package touka
|
|||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -15,6 +16,8 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
||||
|
|
@ -680,7 +683,7 @@ func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) {
|
||||
func TestEngineHandlesOptionsAsteriskLocally(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
engine := New()
|
||||
|
|
@ -695,9 +698,12 @@ func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) {
|
|||
rr := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNotFound {
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code)
|
||||
}
|
||||
if got := rr.Header().Get("Content-Length"); got != "0" {
|
||||
t.Fatalf("unexpected Content-Length header: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyConnectTunnel(t *testing.T) {
|
||||
|
|
@ -848,6 +854,119 @@ func TestReverseProxyConnectNeedsHijacker(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.ProtoMajor != 2 {
|
||||
errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get(":protocol"); got != "websocket" {
|
||||
errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.URL.Path; got != "/ws" {
|
||||
errCh <- fmt.Errorf("unexpected upstream path: %q", got)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(w)
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = controller.Flush()
|
||||
|
||||
line, err := bufio.NewReader(r.Body).ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, "echo:"+line); err != nil {
|
||||
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
|
||||
return
|
||||
}
|
||||
_ = controller.Flush()
|
||||
}))
|
||||
upstream.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
|
||||
t.Fatalf("configure upstream HTTP/2 server: %v", err)
|
||||
}
|
||||
upstream.StartTLS()
|
||||
defer upstream.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, upstream.URL),
|
||||
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
proxy := httptest.NewUnstartedServer(engine)
|
||||
proxy.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
||||
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
||||
}
|
||||
proxy.StartTLS()
|
||||
defer proxy.Close()
|
||||
|
||||
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
defer transport.CloseIdleConnections()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
||||
if err != nil {
|
||||
t.Fatalf("new CONNECT request: %v", err)
|
||||
}
|
||||
req.Header.Set(":protocol", "websocket")
|
||||
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("round trip extended CONNECT: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" {
|
||||
t.Fatalf("unexpected Via response header: %#v", gotVia)
|
||||
}
|
||||
|
||||
if _, err := io.WriteString(pw, "ping\n"); err != nil {
|
||||
t.Fatalf("write tunneled request body: %v", err)
|
||||
}
|
||||
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read tunneled response body: %v", err)
|
||||
}
|
||||
if message != "echo:ping\n" {
|
||||
t.Fatalf("unexpected tunneled response body: %q", message)
|
||||
}
|
||||
if err := pw.Close(); err != nil {
|
||||
t.Fatalf("close tunneled request body: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue