Răsfoiți Sursa

Merge pull request #431 from rod-hynes/master

 Remove upstreamproxy.transportConn
Rod Hynes 8 ani în urmă
părinte
comite
7e228eb375
2 a modificat fișierele cu 88 adăugiri și 259 ștergeri
  1. 10 6
      psiphon/common/osl/osl.go
  2. 78 253
      psiphon/upstreamproxy/transport_proxy_auth.go

+ 10 - 6
psiphon/common/osl/osl.go

@@ -1389,12 +1389,12 @@ func NewOSLReader(
 		return nil, common.ContextError(errors.New("unseeded OSL"))
 	}
 
-	if len(fileKey) != 32 {
+	if len(fileKey) != KEY_LENGTH_BYTES {
 		return nil, common.ContextError(errors.New("invalid key length"))
 	}
 
 	var nonce [24]byte
-	var key [32]byte
+	var key [KEY_LENGTH_BYTES]byte
 	copy(key[:], fileKey)
 
 	unboxer, err := secretbox.NewOpenReadSeeker(oslFileContent, &nonce, &key)
@@ -1426,6 +1426,10 @@ func (z *zeroReader) Read(p []byte) (int, error) {
 // purpose CSPRNG.
 func newSeededKeyMaterialReader(seed []byte) (io.Reader, error) {
 
+	if len(seed) != KEY_LENGTH_BYTES {
+		return nil, common.ContextError(errors.New("invalid key length"))
+	}
+
 	aesCipher, err := aes.NewCipher(seed)
 	if err != nil {
 		return nil, common.ContextError(err)
@@ -1516,11 +1520,11 @@ func shamirCombine(shares [][]byte) []byte {
 // A constant nonce is used, which is secure so long as
 // each key is used to encrypt only one message.
 func box(key, plaintext []byte) ([]byte, error) {
-	if len(key) != 32 {
+	if len(key) != KEY_LENGTH_BYTES {
 		return nil, common.ContextError(errors.New("invalid key length"))
 	}
 	var nonce [24]byte
-	var secretboxKey [32]byte
+	var secretboxKey [KEY_LENGTH_BYTES]byte
 	copy(secretboxKey[:], key)
 	box := secretbox.Seal(nil, plaintext, &nonce, &secretboxKey)
 	return box, nil
@@ -1528,11 +1532,11 @@ func box(key, plaintext []byte) ([]byte, error) {
 
 // unbox is a helper wrapper for secretbox.Open
 func unbox(key, box []byte) ([]byte, error) {
-	if len(key) != 32 {
+	if len(key) != KEY_LENGTH_BYTES {
 		return nil, common.ContextError(errors.New("invalid key length"))
 	}
 	var nonce [24]byte
-	var secretboxKey [32]byte
+	var secretboxKey [KEY_LENGTH_BYTES]byte
 	copy(secretboxKey[:], key)
 	plaintext, ok := secretbox.Open(nil, box, &nonce, &secretboxKey)
 	if !ok {

+ 78 - 253
psiphon/upstreamproxy/transport_proxy_auth.go

@@ -20,61 +20,41 @@
 package upstreamproxy
 
 import (
-	"bufio"
 	"bytes"
-	"context"
 	"fmt"
-	"io"
 	"io/ioutil"
-	"net"
 	"net/http"
-	"strings"
-	"sync"
 )
 
-const HTTP_STAT_LINE_LENGTH = 12
-
 // ProxyAuthTransport provides support for proxy authentication when doing plain HTTP
 // by tapping into HTTP conversation and adding authentication headers to the requests
 // when requested by server
+//
+// Limitation: in violation of https://golang.org/pkg/net/http/#RoundTripper,
+// ProxyAuthTransport is _not_ safe for concurrent RoundTrip calls. This is acceptable
+// for its use in Psiphon to provide upstream proxy support for meek, which makes only
+// serial RoundTrip calls. Concurrent RoundTrip calls will result in data race conditions
+// and undefined behavior during an authentication handshake.
 type ProxyAuthTransport struct {
 	*http.Transport
-	Username      string
-	Password      string
-	Authenticator HttpAuthenticator
-	mu            sync.Mutex
-	CustomHeaders http.Header
+	username         string
+	password         string
+	authenticator    HttpAuthenticator
+	customHeaders    http.Header
+	clonedBodyBuffer bytes.Buffer
 }
 
 func NewProxyAuthTransport(
 	rawTransport *http.Transport,
 	customHeaders http.Header) (*ProxyAuthTransport, error) {
 
-	if rawTransport.DialContext == nil {
-		return nil, fmt.Errorf("rawTransport must have DialContext")
-	}
-
 	if rawTransport.Proxy == nil {
 		return nil, fmt.Errorf("rawTransport must have Proxy")
 	}
 
 	tr := &ProxyAuthTransport{
 		Transport:     rawTransport,
-		CustomHeaders: customHeaders,
-	}
-
-	// Wrap the original transport's custom dialed conns in transportConns,
-	// which handle connection-based authentication.
-	originalDialContext := rawTransport.DialContext
-	rawTransport.DialContext = func(
-		ctx context.Context, network, addr string) (net.Conn, error) {
-		conn, err := originalDialContext(ctx, "tcp", addr)
-		if err != nil {
-			return nil, err
-		}
-		// Any additional dials made by transportConn are within
-		// the original dial context.
-		return newTransportConn(ctx, conn, tr), nil
+		customHeaders: customHeaders,
 	}
 
 	proxyUrl, err := rawTransport.Proxy(nil)
@@ -85,8 +65,8 @@ func NewProxyAuthTransport(
 		return nil, fmt.Errorf("%s unsupported", proxyUrl.Scheme)
 	}
 	if proxyUrl.User != nil {
-		tr.Username = proxyUrl.User.Username()
-		tr.Password, _ = proxyUrl.User.Password()
+		tr.username = proxyUrl.User.Username()
+		tr.password, _ = proxyUrl.User.Password()
 	}
 	// strip username and password from the proxyURL because
 	// we do not want the wrapped transport to handle authentication
@@ -96,75 +76,97 @@ func NewProxyAuthTransport(
 	return tr, nil
 }
 
-func (tr *ProxyAuthTransport) preAuthenticateRequest(req *http.Request) error {
-	tr.mu.Lock()
-	defer tr.mu.Unlock()
-	if tr.Authenticator == nil {
-		return nil
-	}
-	return tr.Authenticator.PreAuthenticate(req)
-}
+func (tr *ProxyAuthTransport) RoundTrip(request *http.Request) (*http.Response, error) {
 
-func (tr *ProxyAuthTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
-	if req.URL.Scheme != "http" {
-		return nil, fmt.Errorf("%s unsupported", req.URL.Scheme)
+	if request.URL.Scheme != "http" {
+		return nil, fmt.Errorf("%s unsupported", request.URL.Scheme)
 	}
-	err = tr.preAuthenticateRequest(req)
-	if err != nil {
-		return nil, err
+
+	// Notes:
+	//
+	// - The 407 authentication loop assumes no concurrent calls of RoundTrip
+	//   and additionally assumes that serial RoundTrip calls will always
+	//   resuse any existing HTTP persistent conn. The entire authentication
+	//   handshake must occur on the same HTTP persistent conn.
+	//
+	// - Requests are cloned for the lifetime of the ProxyAuthTransport,
+	//   since we don't know when the next initial RoundTrip may need to enter
+	//   the 407 authentication loop, which requires the initial request to be
+	//   cloned and replayable. Even if we hook into the Close call for any
+	//   existing HTTP persistent conn, it could be that it closes only after
+	//   RoundTrip is called.
+	//
+	// - Cloning reuses a buffer (clonedBodyBuffer) to store the request body
+	//   to avoid excessive allocations.
+
+	var cachedRequestBody []byte
+	if request.Body != nil {
+		tr.clonedBodyBuffer.Reset()
+		tr.clonedBodyBuffer.ReadFrom(request.Body)
+		request.Body.Close()
+		cachedRequestBody = tr.clonedBodyBuffer.Bytes()
 	}
 
-	var ha HttpAuthenticator
+	clonedRequest := cloneRequest(
+		request, tr.customHeaders, cachedRequestBody)
 
-	// Clone request early because RoundTrip will destroy request Body
-	// Also add custom headers to the cloned request
-	newReq := cloneRequest(req, tr.CustomHeaders)
+	if tr.authenticator != nil {
 
-	resp, err = tr.Transport.RoundTrip(newReq)
+		// For some authentication schemes (e.g., non-connection-based), once
+		// an initial 407 has been handled, add necessary and sufficient
+		// authentication headers to every request.
+
+		err := tr.authenticator.PreAuthenticate(clonedRequest)
+		if err != nil {
+			return nil, err
+		}
+	}
 
+	response, err := tr.Transport.RoundTrip(clonedRequest)
 	if err != nil {
-		return resp, proxyError(err)
+		return response, proxyError(err)
 	}
 
-	if resp.StatusCode == 407 {
-		tr.mu.Lock()
-		defer tr.mu.Unlock()
-		ha, err = NewHttpAuthenticator(resp, tr.Username, tr.Password)
+	if response.StatusCode == 407 {
+
+		authenticator, err := NewHttpAuthenticator(
+			response, tr.username, tr.password)
 		if err != nil {
+			response.Body.Close()
 			return nil, err
 		}
-		if ha.IsConnectionBased() {
-			return nil, proxyError(fmt.Errorf("Connection based auth was not handled by transportConn!"))
-		}
-		tr.Authenticator = ha
-	authenticationLoop:
+
 		for {
-			newReq = cloneRequest(req, tr.CustomHeaders)
-			err = tr.Authenticator.Authenticate(newReq, resp)
+			clonedRequest = cloneRequest(
+				request, tr.customHeaders, cachedRequestBody)
+
+			err = authenticator.Authenticate(clonedRequest, response)
+			response.Body.Close()
 			if err != nil {
 				return nil, err
 			}
-			resp, err = tr.Transport.RoundTrip(newReq)
 
+			response, err = tr.Transport.RoundTrip(clonedRequest)
 			if err != nil {
-				return resp, proxyError(err)
+				return nil, proxyError(err)
 			}
-			if resp.StatusCode != 407 {
-				if tr.Authenticator != nil && tr.Authenticator.IsComplete() {
-					tr.Authenticator.Reset()
-				}
-				break authenticationLoop
-			} else {
+
+			if response.StatusCode != 407 {
+
+				// Save the authenticator result to use for PreAuthenticate.
+
+				tr.authenticator = authenticator
+				break
 			}
 		}
 	}
-	return resp, err
 
+	return response, nil
 }
 
 // Based on https://github.com/golang/oauth2/blob/master/transport.go
 // Copyright 2014 The Go Authors. All rights reserved.
-func cloneRequest(r *http.Request, ch http.Header) *http.Request {
+func cloneRequest(r *http.Request, ch http.Header, body []byte) *http.Request {
 	// shallow copy of the struct
 	r2 := new(http.Request)
 	*r2 = *r
@@ -192,13 +194,7 @@ func cloneRequest(r *http.Request, ch http.Header) *http.Request {
 		}
 	}
 
-	if r.Body != nil {
-		body, _ := ioutil.ReadAll(r.Body)
-		defer r.Body.Close()
-		// restore original request Body
-		// drained by ReadAll()
-		r.Body = ioutil.NopCloser(bytes.NewReader(body))
-
+	if body != nil {
 		r2.Body = ioutil.NopCloser(bytes.NewReader(body))
 	}
 
@@ -207,174 +203,3 @@ func cloneRequest(r *http.Request, ch http.Header) *http.Request {
 
 	return r2
 }
-
-type transportConn struct {
-	net.Conn
-	ctx                context.Context
-	requestInterceptor io.Writer
-	reqDone            chan struct{}
-	errChannel         chan error
-	lastRequest        *http.Request
-	authenticator      HttpAuthenticator
-	transport          *ProxyAuthTransport
-}
-
-func newTransportConn(
-	ctx context.Context,
-	c net.Conn,
-	tr *ProxyAuthTransport) *transportConn {
-
-	// TODOs:
-	//
-	// - Additional dials made by transportConn, for authentication, use the
-	//   original conn's dial context. If authentication can be requested at any
-	//   time, instead of just at the start of a connection, then any deadline for
-	//   this context will be inappropriate.
-	//
-	// - The "intercept" goroutine spawned below will never terminate? Even if the
-	//   transportConn is closed, nothing will unblock reads of the pipe made by
-	//   http.ReadRequest. There should be a call to pw.Close() in transportConn.Close().
-	//
-	// - The ioutil.ReadAll in the "intercept" goroutine allocates new buffers for
-	//   every request. To avoid GC churn it should use a byte.Buffer to reuse a
-	//   single buffer. In practise, there will be a reasonably small maximum request
-	//   body size, so its better to retain and reuse a buffer than to continously
-	//   reallocate.
-	//
-	// - transportConn.Read will not do anything if the caller passes in a very small
-	//   read buffer. This should be documented, as its assuming that the caller is
-	//   fully reading at least HTTP_STAT_LINE_LENGTH at the start of request.
-	//
-	// - As a net.Conn, transportConn.Read should always be interrupted by a call to
-	//   Close, but it may be possible for Read to remain blocked:
-	//   1. caller writes less than a full request to Write
-	//   2. "intercept" call to http.ReadRequest will not return
-	//   3. caller calls Close, which just calls transportConn.Conn.Close
-	//   4. any existing call to Read remains blocked in the select
-
-	tc := &transportConn{
-		Conn:       c,
-		ctx:        ctx,
-		reqDone:    make(chan struct{}),
-		errChannel: make(chan error),
-		transport:  tr,
-	}
-	// Intercept outgoing request as it is written out to server and store it
-	// in case it needs to be authenticated and replayed
-	//NOTE that pipelining is currently not supported
-	pr, pw := io.Pipe()
-	tc.requestInterceptor = pw
-	requestReader := bufio.NewReader(pr)
-	go func() {
-	requestInterceptLoop:
-		for {
-			req, err := http.ReadRequest(requestReader)
-			if err != nil {
-				tc.Conn.Close()
-				pr.Close()
-				pw.Close()
-				tc.errChannel <- fmt.Errorf("intercept request loop http.ReadRequest error: %s", err)
-				break requestInterceptLoop
-			}
-			//read and copy entire body
-			body, _ := ioutil.ReadAll(req.Body)
-			tc.lastRequest = req
-			tc.lastRequest.Body = ioutil.NopCloser(bytes.NewReader(body))
-			//Signal when we have a complete request
-			tc.reqDone <- struct{}{}
-		}
-	}()
-	return tc
-}
-
-// Read peeks into the new response and checks if the proxy requests authentication
-// If so, the last intercepted request is authenticated against the response
-// in case of connection based auth scheme(i.e. NTLM)
-// All the non-connection based schemes are handled by the ProxyAuthTransport.RoundTrip()
-func (tc *transportConn) Read(p []byte) (n int, readErr error) {
-	n, readErr = tc.Conn.Read(p)
-	if n < HTTP_STAT_LINE_LENGTH {
-		return
-	}
-	select {
-	case _ = <-tc.reqDone:
-		line := string(p[:HTTP_STAT_LINE_LENGTH])
-		//This is a new response
-		//Let's see if proxy requests authentication
-		f := strings.SplitN(line, " ", 2)
-
-		readBufferReader := io.NewSectionReader(bytes.NewReader(p), 0, int64(n))
-		responseReader := bufio.NewReader(readBufferReader)
-		if (f[0] == "HTTP/1.0" || f[0] == "HTTP/1.1") && f[1] == "407" {
-			resp, err := http.ReadResponse(responseReader, nil)
-			if err != nil {
-				return 0, err
-			}
-			ha, err := NewHttpAuthenticator(resp, tc.transport.Username, tc.transport.Password)
-			if err != nil {
-				return 0, err
-			}
-			// If connection based auth is requested, we are going to
-			// authenticate request on this very connection
-			// otherwise just return what we read
-			if !ha.IsConnectionBased() {
-				return
-			}
-
-			// Drain the rest of the response
-			// in order to perform auth handshake
-			// on the connection
-			readBufferReader.Seek(0, 0)
-			responseReader = bufio.NewReader(io.MultiReader(readBufferReader, tc.Conn))
-			resp, err = http.ReadResponse(responseReader, nil)
-			if err != nil {
-				return 0, err
-			}
-
-			ioutil.ReadAll(resp.Body)
-			resp.Body.Close()
-
-			if tc.authenticator == nil {
-				tc.authenticator = ha
-			}
-
-			if resp.Close == true {
-				// Server side indicated that it is closing this connection,
-				// dial a new one
-				addr := tc.Conn.RemoteAddr()
-				tc.Conn.Close()
-
-				// Additional dials are made within the context of the dial of the
-				// outer conn this transportConn is wrapping, so the scope of outer
-				// dial timeouts includes these additional dials. This is also to
-				// ensure these dials are interrupted when the context is canceled.
-
-				tc.Conn, err = tc.transport.Transport.DialContext(
-					tc.ctx, addr.Network(), addr.String())
-
-				if err != nil {
-					return 0, err
-				}
-			}
-
-			// Authenticate and replay the request on the connection
-			err = tc.authenticator.Authenticate(tc.lastRequest, resp)
-			if err != nil {
-				return 0, err
-			}
-			tc.lastRequest.WriteProxy(tc)
-			return tc.Read(p)
-		}
-	case err := <-tc.errChannel:
-		return 0, err
-	default:
-	}
-	return
-}
-
-func (tc *transportConn) Write(p []byte) (n int, err error) {
-	n, err = tc.Conn.Write(p)
-	//also write data to the request interceptor
-	tc.requestInterceptor.Write(p[:n])
-	return n, err
-}