Преглед изворни кода

RoundTripper to use in place of http.Transport for doing HTTP via authenticating proxy

Eugene Fryntov пре 10 година
родитељ
комит
ae44941c68

+ 1 - 1
psiphon/upstreamproxy/auth_basic.go

@@ -21,7 +21,7 @@ func newBasicAuthenticator() *BasicHttpAuthenticator {
 	return &BasicHttpAuthenticator{state: BASIC_HTTP_AUTH_STATE_CHALLENGE_RECEIVED}
 	return &BasicHttpAuthenticator{state: BASIC_HTTP_AUTH_STATE_CHALLENGE_RECEIVED}
 }
 }
 
 
-func (a *BasicHttpAuthenticator) authenticate(req *http.Request, resp *http.Response, username, password string) error {
+func (a *BasicHttpAuthenticator) Authenticate(req *http.Request, resp *http.Response, username, password string) error {
 	if a.state == BASIC_HTTP_AUTH_STATE_CHALLENGE_RECEIVED {
 	if a.state == BASIC_HTTP_AUTH_STATE_CHALLENGE_RECEIVED {
 		auth := username + ":" + password
 		auth := username + ":" + password
 		req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
 		req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))

+ 1 - 1
psiphon/upstreamproxy/auth_digest.go

@@ -106,7 +106,7 @@ func h(data string) string {
 	return fmt.Sprintf("%x", digest.Sum(nil))
 	return fmt.Sprintf("%x", digest.Sum(nil))
 }
 }
 
 
-func (a *DigestHttpAuthenticator) authenticate(req *http.Request, resp *http.Response, username, password string) error {
+func (a *DigestHttpAuthenticator) Authenticate(req *http.Request, resp *http.Response, username, password string) error {
 	if a.state != DIGEST_HTTP_AUTH_STATE_CHALLENGE_RECEIVED {
 	if a.state != DIGEST_HTTP_AUTH_STATE_CHALLENGE_RECEIVED {
 		return errors.New("upstreamproxy: Authorization is not accepted by the proxy server")
 		return errors.New("upstreamproxy: Authorization is not accepted by the proxy server")
 	}
 	}

+ 9 - 1
psiphon/upstreamproxy/auth_ntlm.go

@@ -24,10 +24,18 @@ func newNTLMAuthenticator() *NTLMHttpAuthenticator {
 	return &NTLMHttpAuthenticator{state: NTLM_HTTP_AUTH_STATE_CHALLENGE_RECEIVED}
 	return &NTLMHttpAuthenticator{state: NTLM_HTTP_AUTH_STATE_CHALLENGE_RECEIVED}
 }
 }
 
 
-func (a *NTLMHttpAuthenticator) authenticate(req *http.Request, resp *http.Response, username, password string) error {
+func (a *NTLMHttpAuthenticator) Authenticate(req *http.Request, resp *http.Response, username, password string) error {
+	if a.state == NTLM_HTTP_AUTH_STATE_RESPONSE_TYPE3_GENERATED {
+		return errors.New("upstreamproxy: Authorization is not accepted by the proxy server")
+	}
 	challenges, err := parseAuthChallenge(resp)
 	challenges, err := parseAuthChallenge(resp)
 
 
 	challenge, ok := challenges["NTLM"]
 	challenge, ok := challenges["NTLM"]
+	if challenge == "" {
+		a.state = NTLM_HTTP_AUTH_STATE_CHALLENGE_RECEIVED
+	} else {
+		a.state = NTLM_HTTP_AUTH_STATE_RESPONSE_TYPE1_GENERATED
+	}
 	if !ok {
 	if !ok {
 		return errors.New("upstreamproxy: Bad proxy response, no NTLM challenge for NTLMHttpAuthenticator")
 		return errors.New("upstreamproxy: Bad proxy response, no NTLM challenge for NTLMHttpAuthenticator")
 	}
 	}

+ 2 - 2
psiphon/upstreamproxy/http_authenticator.go

@@ -16,7 +16,7 @@ const (
 )
 )
 
 
 type HttpAuthenticator interface {
 type HttpAuthenticator interface {
-	authenticate(req *http.Request, resp *http.Response, username, pasword string) error
+	Authenticate(req *http.Request, resp *http.Response, username, pasword string) error
 }
 }
 
 
 func parseAuthChallenge(resp *http.Response) (map[string]string, error) {
 func parseAuthChallenge(resp *http.Response) (map[string]string, error) {
@@ -38,7 +38,7 @@ func parseAuthChallenge(resp *http.Response) (map[string]string, error) {
 	return challenges, nil
 	return challenges, nil
 }
 }
 
 
-func newHttpAuthenticator(resp *http.Response) (HttpAuthenticator, error) {
+func NewHttpAuthenticator(resp *http.Response) (HttpAuthenticator, error) {
 
 
 	challenges, err := parseAuthChallenge(resp)
 	challenges, err := parseAuthChallenge(resp)
 	if err != nil {
 	if err != nil {

+ 2 - 3
psiphon/upstreamproxy/proxy_http.go

@@ -80,7 +80,6 @@ func newHTTP(uri *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
 
 
 func (hp *httpProxy) Dial(network, addr string) (net.Conn, error) {
 func (hp *httpProxy) Dial(network, addr string) (net.Conn, error) {
 	// Dial and create the http client connection.
 	// Dial and create the http client connection.
-
 	pc := &proxyConn{authState: HTTP_AUTH_STATE_UNCHALLENGED}
 	pc := &proxyConn{authState: HTTP_AUTH_STATE_UNCHALLENGED}
 	err := pc.makeNewClientConn(hp.forward, hp.hostPort)
 	err := pc.makeNewClientConn(hp.forward, hp.hostPort)
 	if err != nil {
 	if err != nil {
@@ -143,7 +142,7 @@ func (pc *proxyConn) handshake(addr, username, password string) error {
 	req.Header.Set("User-Agent", "")
 	req.Header.Set("User-Agent", "")
 
 
 	if pc.authState == HTTP_AUTH_STATE_CHALLENGED {
 	if pc.authState == HTTP_AUTH_STATE_CHALLENGED {
-		err := pc.authenticator.authenticate(req, pc.authResponse, username, password)
+		err := pc.authenticator.Authenticate(req, pc.authResponse, username, password)
 		if err != nil {
 		if err != nil {
 			pc.authState = HTTP_AUTH_STATE_FAILURE
 			pc.authState = HTTP_AUTH_STATE_FAILURE
 			return err
 			return err
@@ -166,7 +165,7 @@ func (pc *proxyConn) handshake(addr, username, password string) error {
 	if resp.StatusCode == 407 {
 	if resp.StatusCode == 407 {
 		if pc.authState == HTTP_AUTH_STATE_UNCHALLENGED {
 		if pc.authState == HTTP_AUTH_STATE_UNCHALLENGED {
 			var auth_err error = nil
 			var auth_err error = nil
-			pc.authenticator, auth_err = newHttpAuthenticator(resp)
+			pc.authenticator, auth_err = NewHttpAuthenticator(resp)
 			if auth_err != nil {
 			if auth_err != nil {
 				pc.httpClientConn.Close()
 				pc.httpClientConn.Close()
 				pc.authState = HTTP_AUTH_STATE_FAILURE
 				pc.authState = HTTP_AUTH_STATE_FAILURE

+ 141 - 33
psiphon/upstreamproxy/transport_proxy_auth.go

@@ -1,51 +1,159 @@
 package upstreamproxy
 package upstreamproxy
 
 
 import (
 import (
+	"bufio"
+	"bytes"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"net"
 	"net/http"
 	"net/http"
+	"net/url"
+	"strings"
+	"time"
 )
 )
 
 
-type Transport struct {
-	Username  string
-	Password  string
-	transport http.RoundTripper
+type ProxyAuthTransport struct {
+	*http.Transport
+	Dial     DialFunc
+	Username string
+	Password string
 }
 }
 
 
-func NewTransport(username, password string, dialFn DialFunc) *Transport {
-	t := &Transport{
-		Username: username,
-		Password: password,
+func NewProxyAuthTransport(proxy string, dialFn DialFunc, responseHeaderTimeout time.Duration) (*ProxyAuthTransport, error) {
+	tr := &ProxyAuthTransport{Dial: dialFn}
+
+	wrappedDialFn := tr.wrapTransportDial()
+	proxyUrl, err := url.Parse(proxy)
+	if err != nil {
+		return nil, err
+	}
+	tr.Username = proxyUrl.User.Username()
+	tr.Password, _ = proxyUrl.User.Password()
+	tr.Transport = &http.Transport{
+		Dial:  wrappedDialFn,
+		Proxy: http.ProxyURL(proxyUrl),
+		ResponseHeaderTimeout: responseHeaderTimeout,
 	}
 	}
-	t.transport = &http.Transport{Dial: dialFn}
-	return t
+	return tr, nil
 }
 }
 
 
-/*
-   func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
-
-	// TODO: Check if we cached auth header for the transport ProxyURL
-	resp, err := t.transport.RoundTrip(req)
-	if resp.StatusCode == 407 {
-		//TODO: Generate new auth header and cache it
-		req2 := cloneRequest(req)
-		err = authenticateRequest(req2, resp, t.Username, t.Password)
+func (tr *ProxyAuthTransport) wrapTransportDial() DialFunc {
+	return func(network, addr string) (net.Conn, error) {
+		c, err := tr.Dial("tcp", addr)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
-		//TODO: avoid going into endless loop due to bad proxy credentials
-		return t.RoundTrip(req2)
+		tc := newTransportConn(c, tr.Dial, tr)
+		return tc, nil
 	}
 	}
-	return resp, err
 }
 }
-*/
-
-func cloneRequest(r *http.Request) *http.Request {
-	// shallow copy of the struct
-	r2 := new(http.Request)
-	*r2 = *r
-	// deep copy of the Header
-	r2.Header = make(http.Header, len(r.Header))
-	for k, s := range r.Header {
-		r2.Header[k] = append([]string(nil), s...)
+
+type transportConn struct {
+	net.Conn
+	requestWriter io.Writer
+	reqDone       chan struct{}
+	connReader    *bufio.Reader
+	lastRequest   *http.Request
+	Dial          DialFunc
+	authenticator HttpAuthenticator
+	authState     HttpAuthState
+	transport     *ProxyAuthTransport
+}
+
+func newTransportConn(c net.Conn, dialFn DialFunc, tr *ProxyAuthTransport) *transportConn {
+	tc := &transportConn{
+		Conn:       c,
+		reqDone:    make(chan struct{}),
+		connReader: bufio.NewReader(c),
+		Dial:       dialFn,
+		transport:  tr,
 	}
 	}
-	return r2
+	go func() {
+		pr, pw := io.Pipe()
+		defer pr.Close()
+		defer pw.Close()
+		tc.requestWriter = pw
+		for {
+			//Request intercepting loop
+			req, err := http.ReadRequest(bufio.NewReader(pr))
+			if err != nil {
+				fmt.Println("http.ReadRequest error: ", err)
+			}
+			//read and copy entire body
+			body, _ := ioutil.ReadAll(req.Body)
+			tc.lastRequest = req
+			tc.lastRequest.Body = ioutil.NopCloser(bytes.NewReader(body))
+			tc.reqDone <- struct{}{}
+		}
+	}()
+	return tc
+}
+
+func (tc *transportConn) Read(p []byte) (int, error) {
+	/*
+	   The first Read on a new RoundTrip will occur *before* Write and
+	   will block until request is written out completely and response
+	   headers are read in
+
+	   Peek will actually call Read and buffer read data
+	*/
+	peeked, err := tc.connReader.Peek(12)
+	if err != nil {
+		return 0, err
+	}
+	line := string(peeked)
+	select {
+	case _ = <-tc.reqDone:
+		//Brand new response
+		f := strings.SplitN(line, " ", 2)
+		if (f[0] == "HTTP/1.0" || f[0] == "HTTP/1.1") && f[1] == "407" {
+			resp, err := http.ReadResponse(tc.connReader, nil)
+			if err != nil {
+				return 0, err
+			}
+			// make sure we read the body of the response so that
+			// we don't block the reader
+			ioutil.ReadAll(resp.Body)
+			resp.Body.Close()
+
+			if tc.authState == HTTP_AUTH_STATE_UNCHALLENGED {
+				tc.authenticator, err = NewHttpAuthenticator(resp)
+				if err != nil {
+					return 0, err
+				}
+				tc.authState = HTTP_AUTH_STATE_CHALLENGED
+			}
+
+			if resp.Close == true {
+				// Server side indicated that it is closing this connection,
+				// dial a new one
+				addr := tc.Conn.RemoteAddr()
+				tc.Conn.Close()
+				tc.Conn, err = tc.Dial(addr.Network(), addr.String())
+				if err != nil {
+					return 0, err
+				}
+			}
+
+			err = tc.authenticator.Authenticate(tc.lastRequest, resp, tc.transport.Username, tc.transport.Password)
+			if err != nil {
+				return 0, err
+			}
+
+			//TODO: eliminate possible race condition
+			//Replay authenticated request
+			tc.lastRequest.WriteProxy(tc)
+			return tc.Read(p)
+		}
+	default:
+	}
+	n, err := tc.connReader.Read(p)
+	return n, err
+}
+
+func (tc *transportConn) Write(p []byte) (n int, err error) {
+	n, err = tc.Conn.Write(p)
+	tc.requestWriter.Write(p[:n])
+	return n, err
 }
 }