Эх сурвалжийг харах

Remove upstreamproxy.transportConn

- upstreamproxy.transportConn is not necessary, as ProxyAuthTransport
  is actually capable of handling NTLM/connection-based authentication
  itself -- with some practical limitations (now documented).

- Mitigates all upstreamproxy.transportConn issues noted here:
  https://github.com/Psiphon-Labs/psiphon-tunnel-core/blob/48611aff343c3dd4f98c5e0d5d43c87af990afaf/psiphon/upstreamproxy/transport_proxy_auth.go#L227

- Clean up/fix some issues in ProxyAuthTransport, including properly
  closing all Response.Body instances when taking the 407 code path,
  and reusing a buffer for request body caching to avoid excessive
  memory allocation.
Rod Hynes 8 жил өмнө
parent
commit
407edebf04

+ 78 - 253
psiphon/upstreamproxy/transport_proxy_auth.go

@@ -20,61 +20,41 @@
 package upstreamproxy
 
 import (
-	"bufio"
 	"bytes"
-	"context"
 	"fmt"
-	"io"
 	"io/ioutil"
-	"net"
 	"net/http"
-	"strings"
-	"sync"
 )
 
-const HTTP_STAT_LINE_LENGTH = 12
-
 // ProxyAuthTransport provides support for proxy authentication when doing plain HTTP
 // by tapping into HTTP conversation and adding authentication headers to the requests
 // when requested by server
+//
+// Limitation: in violation of https://golang.org/pkg/net/http/#RoundTripper,
+// ProxyAuthTransport is _not_ safe for concurrent RoundTrip calls. This is acceptable
+// for its use in Psiphon to provide upstream proxy support for meek, which makes only
+// serial RoundTrip calls. Concurrent RoundTrip calls will result in data race conditions
+// and undefined behavior during an authentication handshake.
 type ProxyAuthTransport struct {
 	*http.Transport
-	Username      string
-	Password      string
-	Authenticator HttpAuthenticator
-	mu            sync.Mutex
-	CustomHeaders http.Header
+	username         string
+	password         string
+	authenticator    HttpAuthenticator
+	customHeaders    http.Header
+	clonedBodyBuffer bytes.Buffer
 }
 
 func NewProxyAuthTransport(
 	rawTransport *http.Transport,
 	customHeaders http.Header) (*ProxyAuthTransport, error) {
 
-	if rawTransport.DialContext == nil {
-		return nil, fmt.Errorf("rawTransport must have DialContext")
-	}
-
 	if rawTransport.Proxy == nil {
 		return nil, fmt.Errorf("rawTransport must have Proxy")
 	}
 
 	tr := &ProxyAuthTransport{
 		Transport:     rawTransport,
-		CustomHeaders: customHeaders,
-	}
-
-	// Wrap the original transport's custom dialed conns in transportConns,
-	// which handle connection-based authentication.
-	originalDialContext := rawTransport.DialContext
-	rawTransport.DialContext = func(
-		ctx context.Context, network, addr string) (net.Conn, error) {
-		conn, err := originalDialContext(ctx, "tcp", addr)
-		if err != nil {
-			return nil, err
-		}
-		// Any additional dials made by transportConn are within
-		// the original dial context.
-		return newTransportConn(ctx, conn, tr), nil
+		customHeaders: customHeaders,
 	}
 
 	proxyUrl, err := rawTransport.Proxy(nil)
@@ -85,8 +65,8 @@ func NewProxyAuthTransport(
 		return nil, fmt.Errorf("%s unsupported", proxyUrl.Scheme)
 	}
 	if proxyUrl.User != nil {
-		tr.Username = proxyUrl.User.Username()
-		tr.Password, _ = proxyUrl.User.Password()
+		tr.username = proxyUrl.User.Username()
+		tr.password, _ = proxyUrl.User.Password()
 	}
 	// strip username and password from the proxyURL because
 	// we do not want the wrapped transport to handle authentication
@@ -96,75 +76,97 @@ func NewProxyAuthTransport(
 	return tr, nil
 }
 
-func (tr *ProxyAuthTransport) preAuthenticateRequest(req *http.Request) error {
-	tr.mu.Lock()
-	defer tr.mu.Unlock()
-	if tr.Authenticator == nil {
-		return nil
-	}
-	return tr.Authenticator.PreAuthenticate(req)
-}
+func (tr *ProxyAuthTransport) RoundTrip(request *http.Request) (*http.Response, error) {
 
-func (tr *ProxyAuthTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
-	if req.URL.Scheme != "http" {
-		return nil, fmt.Errorf("%s unsupported", req.URL.Scheme)
+	if request.URL.Scheme != "http" {
+		return nil, fmt.Errorf("%s unsupported", request.URL.Scheme)
 	}
-	err = tr.preAuthenticateRequest(req)
-	if err != nil {
-		return nil, err
+
+	// Notes:
+	//
+	// - The 407 authentication loop assumes no concurrent calls of RoundTrip
+	//   and additionally assumes that serial RoundTrip calls will always
+	//   reuse any existing HTTP persistent conn. The entire authentication
+	//   handshake must occur on the same HTTP persistent conn.
+	//
+	// - Requests are cloned for the lifetime of the ProxyAuthTransport,
+	//   since we don't know when the next initial RoundTrip may need to enter
+	//   the 407 authentication loop, which requires the initial request to be
+	//   cloned and replayable. Even if we hook into the Close call for any
+	//   existing HTTP persistent conn, it could be that it closes only after
+	//   RoundTrip is called.
+	//
+	// - Cloning reuses a buffer (clonedBodyBuffer) to store the request body
+	//   to avoid excessive allocations.
+
+	var cachedRequestBody []byte
+	if request.Body != nil {
+		tr.clonedBodyBuffer.Reset()
+		tr.clonedBodyBuffer.ReadFrom(request.Body)
+		request.Body.Close()
+		cachedRequestBody = tr.clonedBodyBuffer.Bytes()
 	}
 
-	var ha HttpAuthenticator
+	clonedRequest := cloneRequest(
+		request, tr.customHeaders, cachedRequestBody)
 
-	// Clone request early because RoundTrip will destroy request Body
-	// Also add custom headers to the cloned request
-	newReq := cloneRequest(req, tr.CustomHeaders)
+	if tr.authenticator != nil {
 
-	resp, err = tr.Transport.RoundTrip(newReq)
+		// For some authentication schemes (e.g., non-connection-based), once
+		// an initial 407 has been handled, add necessary and sufficient
+		// authentication headers to every request.
+
+		err := tr.authenticator.PreAuthenticate(clonedRequest)
+		if err != nil {
+			return nil, err
+		}
+	}
 
+	response, err := tr.Transport.RoundTrip(clonedRequest)
 	if err != nil {
-		return resp, proxyError(err)
+		return response, proxyError(err)
 	}
 
-	if resp.StatusCode == 407 {
-		tr.mu.Lock()
-		defer tr.mu.Unlock()
-		ha, err = NewHttpAuthenticator(resp, tr.Username, tr.Password)
+	if response.StatusCode == 407 {
+
+		authenticator, err := NewHttpAuthenticator(
+			response, tr.username, tr.password)
 		if err != nil {
+			response.Body.Close()
 			return nil, err
 		}
-		if ha.IsConnectionBased() {
-			return nil, proxyError(fmt.Errorf("Connection based auth was not handled by transportConn!"))
-		}
-		tr.Authenticator = ha
-	authenticationLoop:
+
 		for {
-			newReq = cloneRequest(req, tr.CustomHeaders)
-			err = tr.Authenticator.Authenticate(newReq, resp)
+			clonedRequest = cloneRequest(
+				request, tr.customHeaders, cachedRequestBody)
+
+			err = authenticator.Authenticate(clonedRequest, response)
+			response.Body.Close()
 			if err != nil {
 				return nil, err
 			}
-			resp, err = tr.Transport.RoundTrip(newReq)
 
+			response, err = tr.Transport.RoundTrip(clonedRequest)
 			if err != nil {
-				return resp, proxyError(err)
+				return nil, proxyError(err)
 			}
-			if resp.StatusCode != 407 {
-				if tr.Authenticator != nil && tr.Authenticator.IsComplete() {
-					tr.Authenticator.Reset()
-				}
-				break authenticationLoop
-			} else {
+
+			if response.StatusCode != 407 {
+
+				// Save the authenticator result to use for PreAuthenticate.
+
+				tr.authenticator = authenticator
+				break
 			}
 		}
 	}
-	return resp, err
 
+	return response, nil
 }
 
 // Based on https://github.com/golang/oauth2/blob/master/transport.go
 // Copyright 2014 The Go Authors. All rights reserved.
-func cloneRequest(r *http.Request, ch http.Header) *http.Request {
+func cloneRequest(r *http.Request, ch http.Header, body []byte) *http.Request {
 	// shallow copy of the struct
 	r2 := new(http.Request)
 	*r2 = *r
@@ -192,13 +194,7 @@ func cloneRequest(r *http.Request, ch http.Header) *http.Request {
 		}
 	}
 
-	if r.Body != nil {
-		body, _ := ioutil.ReadAll(r.Body)
-		defer r.Body.Close()
-		// restore original request Body
-		// drained by ReadAll()
-		r.Body = ioutil.NopCloser(bytes.NewReader(body))
-
+	if body != nil {
 		r2.Body = ioutil.NopCloser(bytes.NewReader(body))
 	}
 
@@ -207,174 +203,3 @@ func cloneRequest(r *http.Request, ch http.Header) *http.Request {
 
 	return r2
 }
-
-type transportConn struct {
-	net.Conn
-	ctx                context.Context
-	requestInterceptor io.Writer
-	reqDone            chan struct{}
-	errChannel         chan error
-	lastRequest        *http.Request
-	authenticator      HttpAuthenticator
-	transport          *ProxyAuthTransport
-}
-
-func newTransportConn(
-	ctx context.Context,
-	c net.Conn,
-	tr *ProxyAuthTransport) *transportConn {
-
-	// TODOs:
-	//
-	// - Additional dials made by transportConn, for authentication, use the
-	//   original conn's dial context. If authentication can be requested at any
-	//   time, instead of just at the start of a connection, then any deadline for
-	//   this context will be inappropriate.
-	//
-	// - The "intercept" goroutine spawned below will never terminate? Even if the
-	//   transportConn is closed, nothing will unblock reads of the pipe made by
-	//   http.ReadRequest. There should be a call to pw.Close() in transportConn.Close().
-	//
-	// - The ioutil.ReadAll in the "intercept" goroutine allocates new buffers for
-	//   every request. To avoid GC churn it should use a bytes.Buffer to reuse a
-	//   single buffer. In practice, there will be a reasonably small maximum request
-	//   body size, so it's better to retain and reuse a buffer than to continuously
-	//   reallocate.
-	//
-	// - transportConn.Read will not do anything if the caller passes in a very small
-	//   read buffer. This should be documented, as it's assuming that the caller is
-	//   fully reading at least HTTP_STAT_LINE_LENGTH at the start of request.
-	//
-	// - As a net.Conn, transportConn.Read should always be interrupted by a call to
-	//   Close, but it may be possible for Read to remain blocked:
-	//   1. caller writes less than a full request to Write
-	//   2. "intercept" call to http.ReadRequest will not return
-	//   3. caller calls Close, which just calls transportConn.Conn.Close
-	//   4. any existing call to Read remains blocked in the select
-
-	tc := &transportConn{
-		Conn:       c,
-		ctx:        ctx,
-		reqDone:    make(chan struct{}),
-		errChannel: make(chan error),
-		transport:  tr,
-	}
-	// Intercept outgoing request as it is written out to server and store it
-	// in case it needs to be authenticated and replayed
-	//NOTE that pipelining is currently not supported
-	pr, pw := io.Pipe()
-	tc.requestInterceptor = pw
-	requestReader := bufio.NewReader(pr)
-	go func() {
-	requestInterceptLoop:
-		for {
-			req, err := http.ReadRequest(requestReader)
-			if err != nil {
-				tc.Conn.Close()
-				pr.Close()
-				pw.Close()
-				tc.errChannel <- fmt.Errorf("intercept request loop http.ReadRequest error: %s", err)
-				break requestInterceptLoop
-			}
-			//read and copy entire body
-			body, _ := ioutil.ReadAll(req.Body)
-			tc.lastRequest = req
-			tc.lastRequest.Body = ioutil.NopCloser(bytes.NewReader(body))
-			//Signal when we have a complete request
-			tc.reqDone <- struct{}{}
-		}
-	}()
-	return tc
-}
-
-// Read peeks into the new response and checks if the proxy requests authentication
-// If so, the last intercepted request is authenticated against the response
-// in case of connection based auth scheme(i.e. NTLM)
-// All the non-connection based schemes are handled by the ProxyAuthTransport.RoundTrip()
-func (tc *transportConn) Read(p []byte) (n int, readErr error) {
-	n, readErr = tc.Conn.Read(p)
-	if n < HTTP_STAT_LINE_LENGTH {
-		return
-	}
-	select {
-	case _ = <-tc.reqDone:
-		line := string(p[:HTTP_STAT_LINE_LENGTH])
-		//This is a new response
-		//Let's see if proxy requests authentication
-		f := strings.SplitN(line, " ", 2)
-
-		readBufferReader := io.NewSectionReader(bytes.NewReader(p), 0, int64(n))
-		responseReader := bufio.NewReader(readBufferReader)
-		if (f[0] == "HTTP/1.0" || f[0] == "HTTP/1.1") && f[1] == "407" {
-			resp, err := http.ReadResponse(responseReader, nil)
-			if err != nil {
-				return 0, err
-			}
-			ha, err := NewHttpAuthenticator(resp, tc.transport.Username, tc.transport.Password)
-			if err != nil {
-				return 0, err
-			}
-			// If connection based auth is requested, we are going to
-			// authenticate request on this very connection
-			// otherwise just return what we read
-			if !ha.IsConnectionBased() {
-				return
-			}
-
-			// Drain the rest of the response
-			// in order to perform auth handshake
-			// on the connection
-			readBufferReader.Seek(0, 0)
-			responseReader = bufio.NewReader(io.MultiReader(readBufferReader, tc.Conn))
-			resp, err = http.ReadResponse(responseReader, nil)
-			if err != nil {
-				return 0, err
-			}
-
-			ioutil.ReadAll(resp.Body)
-			resp.Body.Close()
-
-			if tc.authenticator == nil {
-				tc.authenticator = ha
-			}
-
-			if resp.Close == true {
-				// Server side indicated that it is closing this connection,
-				// dial a new one
-				addr := tc.Conn.RemoteAddr()
-				tc.Conn.Close()
-
-				// Additional dials are made within the context of the dial of the
-				// outer conn this transportConn is wrapping, so the scope of outer
-				// dial timeouts includes these additional dials. This is also to
-				// ensure these dials are interrupted when the context is canceled.
-
-				tc.Conn, err = tc.transport.Transport.DialContext(
-					tc.ctx, addr.Network(), addr.String())
-
-				if err != nil {
-					return 0, err
-				}
-			}
-
-			// Authenticate and replay the request on the connection
-			err = tc.authenticator.Authenticate(tc.lastRequest, resp)
-			if err != nil {
-				return 0, err
-			}
-			tc.lastRequest.WriteProxy(tc)
-			return tc.Read(p)
-		}
-	case err := <-tc.errChannel:
-		return 0, err
-	default:
-	}
-	return
-}
-
-func (tc *transportConn) Write(p []byte) (n int, err error) {
-	n, err = tc.Conn.Write(p)
-	//also write data to the request interceptor
-	tc.requestInterceptor.Write(p[:n])
-	return n, err
-}