Browse Source

Final HTTP auth proxy transport, updated README with code examples

Eugene Fryntov 10 years ago
parent
commit
09075b0214
2 changed files with 126 additions and 32 deletions
  1. 80 6
      psiphon/upstreamproxy/README.md
  2. 46 26
      psiphon/upstreamproxy/transport_proxy_auth.go

+ 80 - 6
psiphon/upstreamproxy/README.md

@@ -10,6 +10,8 @@ Currently supported protocols:
 
 # Usage
 
+Note: `NewProxyDialFunc` returns `ForwardDialFunc` if `ProxyURIString` is empty
+
 ```
 /* 
    Proxy URI examples:
@@ -18,11 +20,83 @@ Currently supported protocols:
    "http://NTDOMAIN\NTUser:password@proxyhost:3375"
 */
 
-var proxyDialer psiphon.Dialer 
-proxyDialer = NewProxyDialFunc((
-            ForwardDialFunc: psiphon.NewTCPDialer(tcpDialerConfig),
-            ProxyURIString: "http://user:password@proxyhost:8080"
-            })
+//Plain HTTP transport via HTTP proxy
+func doAuthenticatedHTTP() {
+	proxyUrl, err := url.Parse("http://user:password@172.16.1.1:8080")
+	transport := &http.Transport{Proxy: http.ProxyURL(proxyUrl)}
+
+	authHttpTransport, err := upstreamproxy.NewProxyAuthTransport(transport)
+	if err != nil {
+		fmt.Println("Error: ", err)
+		return
+	}
+	r, err := http.NewRequest("GET", "http://www.reddit.com", bytes.NewReader(data))
+	if err != nil {
+		fmt.Println("Error: ", err)
+		return
+	}
+	resp, err := authHttpTransport.RoundTrip(r)
+	if err != nil {
+		fmt.Println("RoundTrip Error: ", err)
+		return
+	}
+	ioutil.ReadAll(resp.Body)
+	fmt.Println(string(resp.Status))
+}
+
+//HTTPS transport via HTTP proxy
+func doAuthenticatedHTTPS() {
+	dialTlsFn := func(netw, addr string) (net.Conn, error) {
+		config := &upstreamproxy.UpstreamProxyConfig{
+			ForwardDialFunc: net.Dial,
+			ProxyURIString:  "http://user:password@172.16.1.1:8080",
+		}
+
+		proxyDialFunc := upstreamproxy.NewProxyDialFunc(config)
+
+		conn, err := proxyDialFunc(netw, addr)
+		if err != nil {
+			return nil, err
+		}
+		tlsconfig := &tls.Config{InsecureSkipVerify: true}
+		tlsConn := tls.Client(conn, tlsconfig)
+
+		return tlsConn, tlsConn.Handshake()
+	}
+
+	r, err := http.NewRequest("GET", "https://www.reddit.com", bytes.NewReader(data))
+	transport = &http.Transport{DialTLS: dialTlsFn}
+	resp, err := transport.RoundTrip(r)
+	if err != nil {
+		log.Println("RoundTrip Error: ", err)
+		return
+	}
+	ioutil.ReadAll(resp.Body)
+	fmt.Println(string(resp.Status))
+}
+
+//HTTP transport via SOCKS5 proxy
+func doAuthenticatedHttpSocks() {
+	dialFn := func(netw, addr string) (net.Conn, error) {
+		config := &upstreamproxy.UpstreamProxyConfig{
+			ForwardDialFunc: net.Dial,
+			ProxyURIString:  "socks5://user:password@172.16.1.1:5555",
+		}
+
+		proxyDialFunc := upstreamproxy.NewProxyDialFunc(config)
+
+		return proxyDialFunc(netw, addr)
+	}
+
+	r, err := http.NewRequest("GET", "https://www.reddit.com", bytes.NewReader(data))
+	transport = &http.Transport{Dial: dialFn}
+	resp, err := transport.RoundTrip(r)
+	if err != nil {
+		log.Println("RoundTrip Error: ", err)
+		return
+	}
+	ioutil.ReadAll(resp.Body)
+	fmt.Println(string(resp.Status))
+}
 ```
 
-Note: `NewProxyDialFunc` returns `ForwardDialFunc` if `ProxyURIString` is empty

+ 46 - 26
psiphon/upstreamproxy/transport_proxy_auth.go

@@ -8,9 +8,7 @@ import (
 	"io/ioutil"
 	"net"
 	"net/http"
-	"net/url"
 	"strings"
-	"time"
 )
 
 type ProxyAuthTransport struct {
@@ -20,31 +18,45 @@ type ProxyAuthTransport struct {
 	Password string
 }
 
-func NewProxyAuthTransport(proxy string, dialFn DialFunc, responseHeaderTimeout time.Duration) (*ProxyAuthTransport, error) {
-	tr := &ProxyAuthTransport{Dial: dialFn}
-
-	wrappedDialFn := tr.wrapTransportDial()
-	proxyUrl, err := url.Parse(proxy)
-	if err != nil {
-		return nil, err
+func NewProxyAuthTransport(rawTransport *http.Transport) (*ProxyAuthTransport, error) {
+	dialFn := rawTransport.Dial
+	if dialFn == nil {
+		dialFn = net.Dial
 	}
-	tr.Username = proxyUrl.User.Username()
-	tr.Password, _ = proxyUrl.User.Password()
-	tr.Transport = &http.Transport{
-		Dial:  wrappedDialFn,
-		Proxy: http.ProxyURL(proxyUrl),
-		ResponseHeaderTimeout: responseHeaderTimeout,
+	tr := &ProxyAuthTransport{Dial: dialFn}
+	proxyUrlFn := rawTransport.Proxy
+	if proxyUrlFn != nil {
+		wrappedDialFn := tr.wrapTransportDial()
+		proxyUrl, err := proxyUrlFn(nil)
+		if err != nil {
+			return nil, err
+		}
+		if proxyUrl.Scheme != "http" {
+			return nil, fmt.Errorf("Only HTTP proxysupported, for SOCKS use http.Transport with custom dialers & upstreamproxy.NewProxyDialFunc")
+		}
+		tr.Username = proxyUrl.User.Username()
+		tr.Password, _ = proxyUrl.User.Password()
+		rawTransport.Dial = wrappedDialFn
 	}
+
+	tr.Transport = rawTransport
 	return tr, nil
 }
 
+func (tr *ProxyAuthTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
+	if req.URL.Scheme != "http" {
+		return nil, fmt.Errorf("Only plain HTTP supported, for HTTPS use http.Transport with DialTLS & upstreamproxy.NewProxyDialFunc")
+	}
+	return tr.Transport.RoundTrip(req)
+}
+
 func (tr *ProxyAuthTransport) wrapTransportDial() DialFunc {
 	return func(network, addr string) (net.Conn, error) {
 		c, err := tr.Dial("tcp", addr)
 		if err != nil {
 			return nil, err
 		}
-		tc := newTransportConn(c, tr.Dial, tr)
+		tc := newTransportConn(c, tr)
 		return tc, nil
 	}
 }
@@ -53,32 +65,35 @@ type transportConn struct {
 	net.Conn
 	requestWriter io.Writer
 	reqDone       chan struct{}
+	errChannel    chan error
 	connReader    *bufio.Reader
 	lastRequest   *http.Request
-	Dial          DialFunc
 	authenticator HttpAuthenticator
 	authState     HttpAuthState
 	transport     *ProxyAuthTransport
 }
 
-func newTransportConn(c net.Conn, dialFn DialFunc, tr *ProxyAuthTransport) *transportConn {
+func newTransportConn(c net.Conn, tr *ProxyAuthTransport) *transportConn {
 	tc := &transportConn{
 		Conn:       c,
 		reqDone:    make(chan struct{}),
+		errChannel: make(chan error),
 		connReader: bufio.NewReader(c),
-		Dial:       dialFn,
 		transport:  tr,
 	}
+	pr, pw := io.Pipe()
+	tc.requestWriter = pw
 	go func() {
-		pr, pw := io.Pipe()
-		defer pr.Close()
-		defer pw.Close()
-		tc.requestWriter = pw
+	requestInterceptLoop:
 		for {
-			//Request intercepting loop
 			req, err := http.ReadRequest(bufio.NewReader(pr))
 			if err != nil {
-				fmt.Println("http.ReadRequest error: ", err)
+				tc.Close()
+				tc.errChannel <- fmt.Errorf("intercept request loop http.ReadRequest error: %s", err)
+				pr.Close()
+				pw.Close()
+				tc.Close()
+				break requestInterceptLoop
 			}
 			//read and copy entire body
 			body, _ := ioutil.ReadAll(req.Body)
@@ -130,10 +145,11 @@ func (tc *transportConn) Read(p []byte) (int, error) {
 				// dial a new one
 				addr := tc.Conn.RemoteAddr()
 				tc.Conn.Close()
-				tc.Conn, err = tc.Dial(addr.Network(), addr.String())
+				tc.Conn, err = tc.transport.Dial(addr.Network(), addr.String())
 				if err != nil {
 					return 0, err
 				}
+				tc.connReader = bufio.NewReader(tc.Conn)
 			}
 
 			err = tc.authenticator.Authenticate(tc.lastRequest, resp, tc.transport.Username, tc.transport.Password)
@@ -143,9 +159,13 @@ func (tc *transportConn) Read(p []byte) (int, error) {
 
 			//TODO: eliminate possible race condition
 			//Replay authenticated request
+			//block until a new cycle of the request intercept loop started
+			//<-tc.writeBlocker
 			tc.lastRequest.WriteProxy(tc)
 			return tc.Read(p)
 		}
+	case err = <-tc.errChannel:
+		return 0, err
 	default:
 	}
 	n, err := tc.connReader.Read(p)