|
|
@@ -28,16 +28,21 @@ import (
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"strings"
|
|
|
+ "sync"
|
|
|
)
|
|
|
|
|
|
+const HTTP_STAT_LINE_LENGTH = 12
|
|
|
+
|
|
|
// ProxyAuthTransport provides support for proxy authentication when doing plain HTTP
|
|
|
// by tapping into HTTP conversation and adding authentication headers to the requests
|
|
|
// when requested by server
|
|
|
type ProxyAuthTransport struct {
|
|
|
*http.Transport
|
|
|
- Dial DialFunc
|
|
|
- Username string
|
|
|
- Password string
|
|
|
+ Dial DialFunc
|
|
|
+ Username string
|
|
|
+ Password string
|
|
|
+ Authenticator HttpAuthenticator
|
|
|
+ mu sync.Mutex
|
|
|
}
|
|
|
|
|
|
func NewProxyAuthTransport(rawTransport *http.Transport) (*ProxyAuthTransport, error) {
|
|
|
@@ -49,6 +54,7 @@ func NewProxyAuthTransport(rawTransport *http.Transport) (*ProxyAuthTransport, e
|
|
|
proxyUrlFn := rawTransport.Proxy
|
|
|
if proxyUrlFn != nil {
|
|
|
wrappedDialFn := tr.wrapTransportDial()
|
|
|
+ rawTransport.Dial = wrappedDialFn
|
|
|
proxyUrl, err := proxyUrlFn(nil)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
@@ -56,20 +62,83 @@ func NewProxyAuthTransport(rawTransport *http.Transport) (*ProxyAuthTransport, e
|
|
|
if proxyUrl.Scheme != "http" {
|
|
|
return nil, fmt.Errorf("Only HTTP proxy supported, for SOCKS use http.Transport with custom dialers & upstreamproxy.NewProxyDialFunc")
|
|
|
}
|
|
|
- tr.Username = proxyUrl.User.Username()
|
|
|
- tr.Password, _ = proxyUrl.User.Password()
|
|
|
- rawTransport.Dial = wrappedDialFn
|
|
|
+ if proxyUrl.User != nil {
|
|
|
+ tr.Username = proxyUrl.User.Username()
|
|
|
+ tr.Password, _ = proxyUrl.User.Password()
|
|
|
+ }
|
|
|
+ // strip username and password from the proxyURL because
|
|
|
+ // we do not want the wrapped transport to handle authentication
|
|
|
+ proxyUrl.User = nil
|
|
|
+ rawTransport.Proxy = http.ProxyURL(proxyUrl)
|
|
|
}
|
|
|
|
|
|
tr.Transport = rawTransport
|
|
|
return tr, nil
|
|
|
}
|
|
|
|
|
|
+func (tr *ProxyAuthTransport) preAuthenticateRequest(req *http.Request) error {
|
|
|
+ tr.mu.Lock()
|
|
|
+ defer tr.mu.Unlock()
|
|
|
+ if tr.Authenticator == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return tr.Authenticator.PreAuthenticate(req)
|
|
|
+}
|
|
|
+
|
|
|
func (tr *ProxyAuthTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
|
|
if req.URL.Scheme != "http" {
|
|
|
return nil, fmt.Errorf("Only plain HTTP supported, for HTTPS use http.Transport with DialTLS & upstreamproxy.NewProxyDialFunc")
|
|
|
}
|
|
|
- return tr.Transport.RoundTrip(req)
|
|
|
+ err = tr.preAuthenticateRequest(req)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ var ha HttpAuthenticator = nil
|
|
|
+
|
|
|
+ //Clone request early because RoundTrip will destroy request Body
|
|
|
+ newReq := cloneRequest(req)
|
|
|
+
|
|
|
+ resp, err = tr.Transport.RoundTrip(newReq)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ return resp, proxyError(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if resp.StatusCode == 407 {
|
|
|
+ tr.mu.Lock()
|
|
|
+ defer tr.mu.Unlock()
|
|
|
+ ha, err = NewHttpAuthenticator(resp, tr.Username, tr.Password)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if ha.IsConnectionBased() {
|
|
|
+ return nil, proxyError(fmt.Errorf("Connection based auth was not handled by transportConn!"))
|
|
|
+ }
|
|
|
+ tr.Authenticator = ha
|
|
|
+ authenticationLoop:
|
|
|
+ for {
|
|
|
+ newReq = cloneRequest(req)
|
|
|
+ err = tr.Authenticator.Authenticate(newReq, resp)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ resp, err = tr.Transport.RoundTrip(newReq)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ return resp, proxyError(err)
|
|
|
+ }
|
|
|
+ if resp.StatusCode != 407 {
|
|
|
+ if tr.Authenticator != nil && tr.Authenticator.IsComplete() {
|
|
|
+ tr.Authenticator.Reset()
|
|
|
+ }
|
|
|
+ break authenticationLoop
|
|
|
+ } else {
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return resp, err
|
|
|
+
|
|
|
}
|
|
|
|
|
|
// wrapTransportDial wraps original transport Dial function
|
|
|
@@ -87,19 +156,36 @@ func (tr *ProxyAuthTransport) wrapTransportDial() DialFunc {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func cloneRequest(r *http.Request) *http.Request {
|
|
|
+ // shallow copy of the struct
|
|
|
+ r2 := new(http.Request)
|
|
|
+ *r2 = *r
|
|
|
+ // deep copy of the Header
|
|
|
+ r2.Header = make(http.Header)
|
|
|
+ for k, s := range r.Header {
|
|
|
+ r2.Header[k] = s
|
|
|
+ }
|
|
|
+
|
|
|
+ if r.Body != nil {
|
|
|
+ body, _ := ioutil.ReadAll(r.Body)
|
|
|
+ defer r.Body.Close()
|
|
|
+ // restore original request Body
|
|
|
+ // drained by ReadAll()
|
|
|
+ r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
|
+
|
|
|
+ r2.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
|
+ }
|
|
|
+ return r2
|
|
|
+}
|
|
|
+
|
|
|
type transportConn struct {
|
|
|
net.Conn
|
|
|
requestInterceptor io.Writer
|
|
|
reqDone chan struct{}
|
|
|
errChannel chan error
|
|
|
- // a buffered Reader from the raw net.Conn so we could Peek at the data
|
|
|
- // without advancing the 'read' pointer
|
|
|
- connReader *bufio.Reader
|
|
|
- // last written request holder
|
|
|
- lastRequest *http.Request
|
|
|
- authenticator HttpAuthenticator
|
|
|
- authState HttpAuthState
|
|
|
- transport *ProxyAuthTransport
|
|
|
+ lastRequest *http.Request
|
|
|
+ authenticator HttpAuthenticator
|
|
|
+ transport *ProxyAuthTransport
|
|
|
}
|
|
|
|
|
|
func newTransportConn(c net.Conn, tr *ProxyAuthTransport) *transportConn {
|
|
|
@@ -107,7 +193,6 @@ func newTransportConn(c net.Conn, tr *ProxyAuthTransport) *transportConn {
|
|
|
Conn: c,
|
|
|
reqDone: make(chan struct{}),
|
|
|
errChannel: make(chan error),
|
|
|
- connReader: bufio.NewReader(c),
|
|
|
transport: tr,
|
|
|
}
|
|
|
// Intercept outgoing request as it is written out to server and store it
|
|
|
@@ -115,10 +200,11 @@ func newTransportConn(c net.Conn, tr *ProxyAuthTransport) *transportConn {
|
|
|
//NOTE that pipelining is currently not supported
|
|
|
pr, pw := io.Pipe()
|
|
|
tc.requestInterceptor = pw
|
|
|
+ requestReader := bufio.NewReader(pr)
|
|
|
go func() {
|
|
|
requestInterceptLoop:
|
|
|
for {
|
|
|
- req, err := http.ReadRequest(bufio.NewReader(pr))
|
|
|
+ req, err := http.ReadRequest(requestReader)
|
|
|
if err != nil {
|
|
|
tc.Conn.Close()
|
|
|
pr.Close()
|
|
|
@@ -139,34 +225,53 @@ func newTransportConn(c net.Conn, tr *ProxyAuthTransport) *transportConn {
|
|
|
|
|
|
// Read peeks into the new response and checks if the proxy requests authentication
|
|
|
// If so, the last intercepted request is authenticated against the response
|
|
|
-// authentication challenge and replayed
|
|
|
-func (tc *transportConn) Read(p []byte) (int, error) {
|
|
|
- peeked, err := tc.connReader.Peek(12)
|
|
|
- if err != nil {
|
|
|
- return 0, err
|
|
|
+// in case of connection based auth scheme(i.e. NTLM)
|
|
|
+// All the non-connection based schemes are handled by the ProxyAuthTransport.RoundTrip()
|
|
|
+func (tc *transportConn) Read(p []byte) (n int, read_err error) {
|
|
|
+ n, read_err = tc.Conn.Read(p)
|
|
|
+ if n < HTTP_STAT_LINE_LENGTH {
|
|
|
+ return
|
|
|
}
|
|
|
- line := string(peeked)
|
|
|
select {
|
|
|
case _ = <-tc.reqDone:
|
|
|
+ line := string(p[:HTTP_STAT_LINE_LENGTH])
|
|
|
//This is a new response
|
|
|
//Let's see if proxy requests authentication
|
|
|
f := strings.SplitN(line, " ", 2)
|
|
|
+
|
|
|
+ readBufferReader := io.NewSectionReader(bytes.NewReader(p), 0, int64(n))
|
|
|
+ responseReader := bufio.NewReader(readBufferReader)
|
|
|
if (f[0] == "HTTP/1.0" || f[0] == "HTTP/1.1") && f[1] == "407" {
|
|
|
- resp, err := http.ReadResponse(tc.connReader, nil)
|
|
|
+ resp, err := http.ReadResponse(responseReader, nil)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
- // make sure we read the body of the response so that
|
|
|
- // we don't block the reader
|
|
|
+ ha, err := NewHttpAuthenticator(resp, tc.transport.Username, tc.transport.Password)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ // If connection based auth is requested, we are going to
|
|
|
+ // authenticate request on this very connection
|
|
|
+ // otherwise just return what we read
|
|
|
+ if !ha.IsConnectionBased() {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Drain the rest of the response
|
|
|
+ // in order to perform auth handshake
|
|
|
+ // on the connection
|
|
|
+ readBufferReader.Seek(0, 0)
|
|
|
+ responseReader = bufio.NewReader(io.MultiReader(readBufferReader, tc.Conn))
|
|
|
+ resp, err = http.ReadResponse(responseReader, nil)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
ioutil.ReadAll(resp.Body)
|
|
|
resp.Body.Close()
|
|
|
|
|
|
- if tc.authState == HTTP_AUTH_STATE_UNCHALLENGED {
|
|
|
- tc.authenticator, err = NewHttpAuthenticator(resp)
|
|
|
- if err != nil {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
- tc.authState = HTTP_AUTH_STATE_CHALLENGED
|
|
|
+ if tc.authenticator == nil {
|
|
|
+ tc.authenticator = ha
|
|
|
}
|
|
|
|
|
|
if resp.Close == true {
|
|
|
@@ -178,23 +283,21 @@ func (tc *transportConn) Read(p []byte) (int, error) {
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
- tc.connReader = bufio.NewReader(tc.Conn)
|
|
|
}
|
|
|
|
|
|
- // Authenticate and replay the request
|
|
|
- err = tc.authenticator.Authenticate(tc.lastRequest, resp, tc.transport.Username, tc.transport.Password)
|
|
|
+ // Authenticate and replay the request on the connection
|
|
|
+ err = tc.authenticator.Authenticate(tc.lastRequest, resp)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
tc.lastRequest.WriteProxy(tc)
|
|
|
return tc.Read(p)
|
|
|
}
|
|
|
- case err = <-tc.errChannel:
|
|
|
+ case err := <-tc.errChannel:
|
|
|
return 0, err
|
|
|
default:
|
|
|
}
|
|
|
- n, err := tc.connReader.Read(p)
|
|
|
- return n, err
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
func (tc *transportConn) Write(p []byte) (n int, err error) {
|