mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
refactor: improve TLS config handling and add bridge connection tests
This commit is contained in:
parent
dcdb1504a3
commit
d53693952a
3 changed files with 153 additions and 7 deletions
|
|
@ -51,9 +51,10 @@ func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripp
|
||||||
transport := cloneDefaultTransport()
|
transport := cloneDefaultTransport()
|
||||||
transport.Protocols = new(http.Protocols)
|
transport.Protocols = new(http.Protocols)
|
||||||
transport.Protocols.SetHTTP1(true)
|
transport.Protocols.SetHTTP1(true)
|
||||||
transport.TLSClientConfig = tlsConfig
|
if tlsConfig == nil {
|
||||||
if transport.TLSClientConfig == nil {
|
|
||||||
transport.TLSClientConfig = &tls.Config{}
|
transport.TLSClientConfig = &tls.Config{}
|
||||||
|
} else {
|
||||||
|
transport.TLSClientConfig = tlsConfig.Clone()
|
||||||
}
|
}
|
||||||
if len(transport.TLSClientConfig.NextProtos) == 0 {
|
if len(transport.TLSClientConfig.NextProtos) == 0 {
|
||||||
transport.TLSClientConfig.NextProtos = []string{"http/1.1"}
|
transport.TLSClientConfig.NextProtos = []string{"http/1.1"}
|
||||||
|
|
|
||||||
|
|
@ -1024,7 +1024,6 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r
|
||||||
}()
|
}()
|
||||||
defer close(backConnClosed)
|
defer close(backConnClosed)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
defer backConn.Close()
|
|
||||||
|
|
||||||
errc := make(chan error, 2)
|
errc := make(chan error, 2)
|
||||||
copyer := switchProtocolCopier{user: conn, backend: backConn}
|
copyer := switchProtocolCopier{user: conn, backend: backConn}
|
||||||
|
|
@ -1374,11 +1373,11 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) {
|
func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) {
|
||||||
protocol := reverseProxyExtendedConnectProtocol(req)
|
|
||||||
if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") {
|
|
||||||
if req == nil {
|
if req == nil {
|
||||||
return context.Background(), false, nil
|
return context.Background(), false, nil
|
||||||
}
|
}
|
||||||
|
protocol := reverseProxyExtendedConnectProtocol(req)
|
||||||
|
if req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") {
|
||||||
return req.Context(), false, nil
|
return req.Context(), false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -933,6 +933,29 @@ func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewHTTP1BridgeTransportWithTLSConfigClonesInput(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
||||||
|
rt := newHTTP1BridgeTransportWithTLSConfig(tlsConfig)
|
||||||
|
transport, ok := rt.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("unexpected transport type: %T", rt)
|
||||||
|
}
|
||||||
|
if transport.TLSClientConfig == nil {
|
||||||
|
t.Fatal("expected TLS client config")
|
||||||
|
}
|
||||||
|
if transport.TLSClientConfig == tlsConfig {
|
||||||
|
t.Fatal("expected bridge transport to clone TLS config")
|
||||||
|
}
|
||||||
|
if len(tlsConfig.NextProtos) != 0 {
|
||||||
|
t.Fatalf("input TLS config was mutated: %#v", tlsConfig.NextProtos)
|
||||||
|
}
|
||||||
|
if got := transport.TLSClientConfig.NextProtos; len(got) != 1 || got[0] != "http/1.1" {
|
||||||
|
t.Fatalf("unexpected transport NextProtos: %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
|
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
@ -1633,6 +1656,98 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
enableHTTP2ExtendedConnectProtocol()
|
||||||
|
|
||||||
|
closeCalls := atomic.Int32{}
|
||||||
|
transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
if req.Method != http.MethodGet {
|
||||||
|
return nil, fmt.Errorf("unexpected upstream method: %s", req.Method)
|
||||||
|
}
|
||||||
|
backend := &countingReadWriteCloser{
|
||||||
|
readData: []byte("echo:ping\n"),
|
||||||
|
closeCalls: &closeCalls,
|
||||||
|
closeWriteErr: http.ErrNotSupported,
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusSwitchingProtocols,
|
||||||
|
Header: http.Header{
|
||||||
|
"Connection": []string{"Upgrade"},
|
||||||
|
"Upgrade": []string{"websocket"},
|
||||||
|
"Sec-WebSocket-Accept": []string{"ignored"},
|
||||||
|
},
|
||||||
|
Body: backend,
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: mustParseURL(t, "http://example.com"),
|
||||||
|
Transport: transport,
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewUnstartedServer(engine)
|
||||||
|
proxy.EnableHTTP2 = true
|
||||||
|
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
||||||
|
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
||||||
|
}
|
||||||
|
proxy.StartTLS()
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
clientTransport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||||
|
defer clientTransport.CloseIdleConnections()
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new CONNECT request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set(":protocol", "websocket")
|
||||||
|
|
||||||
|
resp, err := clientTransport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("round trip extended CONNECT: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if _, err := io.WriteString(pw, "ping\n"); err != nil {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
t.Fatalf("write tunneled request body: %v", err)
|
||||||
|
}
|
||||||
|
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
t.Fatalf("read tunneled response body: %v", err)
|
||||||
|
}
|
||||||
|
if message != "echo:ping\n" {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
t.Fatalf("unexpected tunneled response body: %q", message)
|
||||||
|
}
|
||||||
|
if err := pw.Close(); err != nil {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
t.Fatalf("close tunneled request body: %v", err)
|
||||||
|
}
|
||||||
|
if err := resp.Body.Close(); err != nil {
|
||||||
|
t.Fatalf("close response body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if closeCalls.Load() > 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
if got := closeCalls.Load(); got != 1 {
|
||||||
|
t.Fatalf("expected backend connection to close exactly once, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) {
|
func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
@ -2311,6 +2426,37 @@ func (r errorReader) Read([]byte) (int, error) {
|
||||||
return 0, r.err
|
return 0, r.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type countingReadWriteCloser struct {
|
||||||
|
readData []byte
|
||||||
|
writeBuf bytes.Buffer
|
||||||
|
closeCalls *atomic.Int32
|
||||||
|
closeWriteErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *countingReadWriteCloser) Read(p []byte) (int, error) {
|
||||||
|
if len(r.readData) == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n := copy(p, r.readData)
|
||||||
|
r.readData = r.readData[n:]
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *countingReadWriteCloser) Write(p []byte) (int, error) {
|
||||||
|
return r.writeBuf.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *countingReadWriteCloser) Close() error {
|
||||||
|
if r.closeCalls != nil {
|
||||||
|
r.closeCalls.Add(1)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *countingReadWriteCloser) CloseWrite() error {
|
||||||
|
return r.closeWriteErr
|
||||||
|
}
|
||||||
|
|
||||||
func mustParseURL(t *testing.T, raw string) *url.URL {
|
func mustParseURL(t *testing.T, raw string) *url.URL {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
u, err := url.Parse(raw)
|
u, err := url.Parse(raw)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue