|
|
@@ -28,6 +28,7 @@ import (
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"strings"
|
|
|
+ "sync"
|
|
|
)
|
|
|
|
|
|
const HTTP_STAT_LINE_LENGTH = 12
|
|
|
@@ -37,9 +38,12 @@ const HTTP_STAT_LINE_LENGTH = 12
|
|
|
// when requested by server
|
|
|
type ProxyAuthTransport struct {
|
|
|
*http.Transport
|
|
|
- Dial DialFunc
|
|
|
- Username string
|
|
|
- Password string
|
|
|
+ Dial DialFunc
|
|
|
+ Username string
|
|
|
+ Password string
|
|
|
+ Authenticator HttpAuthenticator
|
|
|
+ authState HttpAuthState
|
|
|
+ mu sync.Mutex
|
|
|
}
|
|
|
|
|
|
func NewProxyAuthTransport(rawTransport *http.Transport) (*ProxyAuthTransport, error) {
|
|
|
@@ -47,10 +51,11 @@ func NewProxyAuthTransport(rawTransport *http.Transport) (*ProxyAuthTransport, e
|
|
|
if dialFn == nil {
|
|
|
dialFn = net.Dial
|
|
|
}
|
|
|
- tr := &ProxyAuthTransport{Dial: dialFn}
|
|
|
+ tr := &ProxyAuthTransport{Dial: dialFn, authState: HTTP_AUTH_STATE_UNCHALLENGED}
|
|
|
proxyUrlFn := rawTransport.Proxy
|
|
|
if proxyUrlFn != nil {
|
|
|
wrappedDialFn := tr.wrapTransportDial()
|
|
|
+ rawTransport.Dial = wrappedDialFn
|
|
|
proxyUrl, err := proxyUrlFn(nil)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
@@ -60,18 +65,79 @@ func NewProxyAuthTransport(rawTransport *http.Transport) (*ProxyAuthTransport, e
|
|
|
}
|
|
|
tr.Username = proxyUrl.User.Username()
|
|
|
tr.Password, _ = proxyUrl.User.Password()
|
|
|
- rawTransport.Dial = wrappedDialFn
|
|
|
+ // strip username and password from the proxyURL because
|
|
|
+ // we do not want the wrapped transport to handle authentication
|
|
|
+ proxyUrl.User = nil
|
|
|
+ rawTransport.Proxy = http.ProxyURL(proxyUrl)
|
|
|
}
|
|
|
|
|
|
tr.Transport = rawTransport
|
|
|
return tr, nil
|
|
|
}
|
|
|
|
|
|
+func (tr *ProxyAuthTransport) preAuthenticateRequest(req *http.Request) error {
|
|
|
+ if tr.Authenticator == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return tr.Authenticator.PreAuthenticate(req)
|
|
|
+}
|
|
|
+
|
|
|
func (tr *ProxyAuthTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
|
|
if req.URL.Scheme != "http" {
|
|
|
return nil, fmt.Errorf("Only plain HTTP supported, for HTTPS use http.Transport with DialTLS & upstreamproxy.NewProxyDialFunc")
|
|
|
}
|
|
|
- return tr.Transport.RoundTrip(req)
|
|
|
+ err = tr.preAuthenticateRequest(req)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ //Clone request early because RoundTrip will destroy request Body
|
|
|
+ var ha HttpAuthenticator = nil
|
|
|
+ newReq := cloneRequest(req)
|
|
|
+ //authState := HTTP_AUTH_STATE_UNCHALLENGED
|
|
|
+
|
|
|
+ resp, err = tr.Transport.RoundTrip(newReq)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ return resp, proxyError(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if resp.StatusCode == 407 {
|
|
|
+ fmt.Println("407!")
|
|
|
+ tr.mu.Lock()
|
|
|
+ defer tr.mu.Unlock()
|
|
|
+ if tr.Authenticator == nil {
|
|
|
+ ha, err = NewHttpAuthenticator(resp, tr.Username, tr.Password)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if ha.IsConnectionBased() {
|
|
|
+ return nil, proxyError(fmt.Errorf("Connection based auth was not handled by transportConn"))
|
|
|
+ }
|
|
|
+ tr.Authenticator = ha
|
|
|
+ }
|
|
|
+ authenticationLoop:
|
|
|
+ for {
|
|
|
+ newReq = cloneRequest(req)
|
|
|
+ err = tr.Authenticator.Authenticate(newReq, resp)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ resp, err = tr.Transport.RoundTrip(newReq)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ return resp, proxyError(err)
|
|
|
+ }
|
|
|
+ if resp.StatusCode != 407 {
|
|
|
+ if tr.Authenticator != nil && tr.Authenticator.IsComplete() {
|
|
|
+ tr.Authenticator.Reset()
|
|
|
+ }
|
|
|
+ break authenticationLoop
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return resp, err
|
|
|
+
|
|
|
}
|
|
|
|
|
|
// wrapTransportDial wraps original transport Dial function
|
|
|
@@ -89,6 +155,28 @@ func (tr *ProxyAuthTransport) wrapTransportDial() DialFunc {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// cloneRequest returns a clone of the provided *http.Request. The clone is a
|
|
|
+// shallow copy of the struct and its Header map.
|
|
|
+func cloneRequest(r *http.Request) *http.Request {
|
|
|
+ // shallow copy of the struct
|
|
|
+ r2 := new(http.Request)
|
|
|
+ *r2 = *r
|
|
|
+ // deep copy of the Header
|
|
|
+ r2.Header = make(http.Header)
|
|
|
+ for k, s := range r.Header {
|
|
|
+ r2.Header[k] = s
|
|
|
+ }
|
|
|
+
|
|
|
+ if r.Body != nil {
|
|
|
+ body, _ := ioutil.ReadAll(r.Body)
|
|
|
+ defer r.Body.Close()
|
|
|
+ //restore original request Body
|
|
|
+ r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
|
+ r2.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
|
+ }
|
|
|
+ return r2
|
|
|
+}
|
|
|
+
|
|
|
type transportConn struct {
|
|
|
net.Conn
|
|
|
requestInterceptor io.Writer
|
|
|
@@ -98,8 +186,8 @@ type transportConn struct {
|
|
|
lastRequest *http.Request
|
|
|
authenticator HttpAuthenticator
|
|
|
authState HttpAuthState
|
|
|
- authCache string
|
|
|
transport *ProxyAuthTransport
|
|
|
+ //mutex *sync.Mutex
|
|
|
}
|
|
|
|
|
|
func newTransportConn(c net.Conn, tr *ProxyAuthTransport) *transportConn {
|
|
|
@@ -139,8 +227,8 @@ func newTransportConn(c net.Conn, tr *ProxyAuthTransport) *transportConn {
|
|
|
|
|
|
// Read peeks into the new response and checks if the proxy requests authentication
|
|
|
// If so, the last intercepted request is authenticated against the response
|
|
|
-func (tc *transportConn) Read(p []byte) (n int, err error) {
|
|
|
- n, err = tc.Conn.Read(p)
|
|
|
+func (tc *transportConn) Read(p []byte) (n int, read_err error) {
|
|
|
+ n, read_err = tc.Conn.Read(p)
|
|
|
if n < HTTP_STAT_LINE_LENGTH {
|
|
|
return
|
|
|
}
|
|
|
@@ -150,24 +238,39 @@ func (tc *transportConn) Read(p []byte) (n int, err error) {
|
|
|
//This is a new response
|
|
|
//Let's see if proxy requests authentication
|
|
|
f := strings.SplitN(line, " ", 2)
|
|
|
- readBufferReader := bytes.NewReader(p)
|
|
|
- responseReader := io.MultiReader(readBufferReader, tc.Conn)
|
|
|
+
|
|
|
+ readBufferReader := io.NewSectionReader(bytes.NewReader(p), 0, int64(n))
|
|
|
+ responseReader := bufio.NewReader(readBufferReader)
|
|
|
if (f[0] == "HTTP/1.0" || f[0] == "HTTP/1.1") && f[1] == "407" {
|
|
|
- resp, err := http.ReadResponse(bufio.NewReader(responseReader), nil)
|
|
|
+ resp, err := http.ReadResponse(responseReader, nil)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
- // make sure we read the body of the response so that
|
|
|
- // we don't block the reader
|
|
|
+ ha, err := NewHttpAuthenticator(resp, tc.transport.Username, tc.transport.Password)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ // If connection based auth is requested, we are going to
|
|
|
+ // authenticate this very connection
|
|
|
+
|
|
|
+ if !ha.IsConnectionBased() {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Drain the rest of the response
|
|
|
+ // in order to perform auth handshake
|
|
|
+ readBufferReader.Seek(0, 0)
|
|
|
+ responseReader = bufio.NewReader(io.MultiReader(readBufferReader, tc.Conn))
|
|
|
+ resp, err = http.ReadResponse(responseReader, nil)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
ioutil.ReadAll(resp.Body)
|
|
|
resp.Body.Close()
|
|
|
|
|
|
- if tc.authState == HTTP_AUTH_STATE_UNCHALLENGED {
|
|
|
- tc.authenticator, err = NewHttpAuthenticator(resp)
|
|
|
- if err != nil {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
- tc.authState = HTTP_AUTH_STATE_CHALLENGED
|
|
|
+ if tc.authenticator == nil {
|
|
|
+ tc.authenticator = ha
|
|
|
}
|
|
|
|
|
|
if resp.Close == true {
|
|
|
@@ -181,15 +284,15 @@ func (tc *transportConn) Read(p []byte) (n int, err error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // Authenticate and replay the request
|
|
|
- err = tc.authenticator.Authenticate(tc.lastRequest, resp, tc.transport.Username, tc.transport.Password)
|
|
|
+ // Authenticate and replay the request on the connection
|
|
|
+ err = tc.authenticator.Authenticate(tc.lastRequest, resp)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
tc.lastRequest.WriteProxy(tc)
|
|
|
return tc.Read(p)
|
|
|
}
|
|
|
- case err = <-tc.errChannel:
|
|
|
+ case err := <-tc.errChannel:
|
|
|
return 0, err
|
|
|
default:
|
|
|
}
|