|
|
@@ -8,9 +8,7 @@ import (
|
|
|
"io/ioutil"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
- "net/url"
|
|
|
"strings"
|
|
|
- "time"
|
|
|
)
|
|
|
|
|
|
type ProxyAuthTransport struct {
|
|
|
@@ -20,31 +18,45 @@ type ProxyAuthTransport struct {
|
|
|
Password string
|
|
|
}
|
|
|
|
|
|
-func NewProxyAuthTransport(proxy string, dialFn DialFunc, responseHeaderTimeout time.Duration) (*ProxyAuthTransport, error) {
|
|
|
- tr := &ProxyAuthTransport{Dial: dialFn}
|
|
|
-
|
|
|
- wrappedDialFn := tr.wrapTransportDial()
|
|
|
- proxyUrl, err := url.Parse(proxy)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
+func NewProxyAuthTransport(rawTransport *http.Transport) (*ProxyAuthTransport, error) {
|
|
|
+ dialFn := rawTransport.Dial
|
|
|
+ if dialFn == nil {
|
|
|
+ dialFn = net.Dial
|
|
|
}
|
|
|
- tr.Username = proxyUrl.User.Username()
|
|
|
- tr.Password, _ = proxyUrl.User.Password()
|
|
|
- tr.Transport = &http.Transport{
|
|
|
- Dial: wrappedDialFn,
|
|
|
- Proxy: http.ProxyURL(proxyUrl),
|
|
|
- ResponseHeaderTimeout: responseHeaderTimeout,
|
|
|
+ tr := &ProxyAuthTransport{Dial: dialFn}
|
|
|
+ proxyUrlFn := rawTransport.Proxy
|
|
|
+ if proxyUrlFn != nil {
|
|
|
+ wrappedDialFn := tr.wrapTransportDial()
|
|
|
+ proxyUrl, err := proxyUrlFn(nil)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if proxyUrl.Scheme != "http" {
|
|
|
+ return nil, fmt.Errorf("Only HTTP proxysupported, for SOCKS use http.Transport with custom dialers & upstreamproxy.NewProxyDialFunc")
|
|
|
+ }
|
|
|
+ tr.Username = proxyUrl.User.Username()
|
|
|
+ tr.Password, _ = proxyUrl.User.Password()
|
|
|
+ rawTransport.Dial = wrappedDialFn
|
|
|
}
|
|
|
+
|
|
|
+ tr.Transport = rawTransport
|
|
|
return tr, nil
|
|
|
}
|
|
|
|
|
|
+func (tr *ProxyAuthTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
|
|
+ if req.URL.Scheme != "http" {
|
|
|
+ return nil, fmt.Errorf("Only plain HTTP supported, for HTTPS use http.Transport with DialTLS & upstreamproxy.NewProxyDialFunc")
|
|
|
+ }
|
|
|
+ return tr.Transport.RoundTrip(req)
|
|
|
+}
|
|
|
+
|
|
|
func (tr *ProxyAuthTransport) wrapTransportDial() DialFunc {
|
|
|
return func(network, addr string) (net.Conn, error) {
|
|
|
c, err := tr.Dial("tcp", addr)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- tc := newTransportConn(c, tr.Dial, tr)
|
|
|
+ tc := newTransportConn(c, tr)
|
|
|
return tc, nil
|
|
|
}
|
|
|
}
|
|
|
@@ -53,32 +65,35 @@ type transportConn struct {
|
|
|
net.Conn
|
|
|
requestWriter io.Writer
|
|
|
reqDone chan struct{}
|
|
|
+ errChannel chan error
|
|
|
connReader *bufio.Reader
|
|
|
lastRequest *http.Request
|
|
|
- Dial DialFunc
|
|
|
authenticator HttpAuthenticator
|
|
|
authState HttpAuthState
|
|
|
transport *ProxyAuthTransport
|
|
|
}
|
|
|
|
|
|
-func newTransportConn(c net.Conn, dialFn DialFunc, tr *ProxyAuthTransport) *transportConn {
|
|
|
+func newTransportConn(c net.Conn, tr *ProxyAuthTransport) *transportConn {
|
|
|
tc := &transportConn{
|
|
|
Conn: c,
|
|
|
reqDone: make(chan struct{}),
|
|
|
+ errChannel: make(chan error),
|
|
|
connReader: bufio.NewReader(c),
|
|
|
- Dial: dialFn,
|
|
|
transport: tr,
|
|
|
}
|
|
|
+ pr, pw := io.Pipe()
|
|
|
+ tc.requestWriter = pw
|
|
|
go func() {
|
|
|
- pr, pw := io.Pipe()
|
|
|
- defer pr.Close()
|
|
|
- defer pw.Close()
|
|
|
- tc.requestWriter = pw
|
|
|
+ requestInterceptLoop:
|
|
|
for {
|
|
|
- //Request intercepting loop
|
|
|
req, err := http.ReadRequest(bufio.NewReader(pr))
|
|
|
if err != nil {
|
|
|
- fmt.Println("http.ReadRequest error: ", err)
|
|
|
+ tc.Close()
|
|
|
+ tc.errChannel <- fmt.Errorf("intercept request loop http.ReadRequest error: %s", err)
|
|
|
+ pr.Close()
|
|
|
+ pw.Close()
|
|
|
+ tc.Close()
|
|
|
+ break requestInterceptLoop
|
|
|
}
|
|
|
//read and copy entire body
|
|
|
body, _ := ioutil.ReadAll(req.Body)
|
|
|
@@ -130,10 +145,11 @@ func (tc *transportConn) Read(p []byte) (int, error) {
|
|
|
// dial a new one
|
|
|
addr := tc.Conn.RemoteAddr()
|
|
|
tc.Conn.Close()
|
|
|
- tc.Conn, err = tc.Dial(addr.Network(), addr.String())
|
|
|
+ tc.Conn, err = tc.transport.Dial(addr.Network(), addr.String())
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
+ tc.connReader = bufio.NewReader(tc.Conn)
|
|
|
}
|
|
|
|
|
|
err = tc.authenticator.Authenticate(tc.lastRequest, resp, tc.transport.Username, tc.transport.Password)
|
|
|
@@ -143,9 +159,13 @@ func (tc *transportConn) Read(p []byte) (int, error) {
|
|
|
|
|
|
//TODO: eliminate possible race condition
|
|
|
//Replay authenticated request
|
|
|
+ //block until a new cycle of the request intercept loop started
|
|
|
+ //<-tc.writeBlocker
|
|
|
tc.lastRequest.WriteProxy(tc)
|
|
|
return tc.Read(p)
|
|
|
}
|
|
|
+ case err = <-tc.errChannel:
|
|
|
+ return 0, err
|
|
|
default:
|
|
|
}
|
|
|
n, err := tc.connReader.Read(p)
|