|
|
@@ -36,6 +36,7 @@ type HttpProxy struct {
|
|
|
listener net.Listener
|
|
|
waitGroup *sync.WaitGroup
|
|
|
httpRelay *http.Transport
|
|
|
+ openConns map[net.Conn]bool
|
|
|
}
|
|
|
|
|
|
// NewHttpProxy initializes and runs a new HTTP proxy server.
|
|
|
@@ -59,6 +60,7 @@ func NewHttpProxy(listenPort int, tunnel *Tunnel, stoppedSignal chan struct{}) (
|
|
|
listener: listener,
|
|
|
waitGroup: new(sync.WaitGroup),
|
|
|
httpRelay: transport,
|
|
|
+ openConns: make(map[net.Conn]bool),
|
|
|
}
|
|
|
proxy.waitGroup.Add(1)
|
|
|
go proxy.serveHttpRequests()
|
|
|
@@ -70,6 +72,10 @@ func NewHttpProxy(listenPort int, tunnel *Tunnel, stoppedSignal chan struct{}) (
|
|
|
func (proxy *HttpProxy) Close() {
|
|
|
proxy.listener.Close()
|
|
|
proxy.waitGroup.Wait()
|
|
|
+ // Close local->proxy persistent connections
|
|
|
+ proxy.closeOpenConns()
|
|
|
+ // Close idle proxy->origin persistent connections
|
|
|
+ // TODO: also close active connections
|
|
|
proxy.httpRelay.CloseIdleConnections()
|
|
|
}
|
|
|
|
|
|
@@ -99,7 +105,7 @@ func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *h
|
|
|
return
|
|
|
}
|
|
|
go func() {
|
|
|
- err := httpConnectHandler(proxy.tunnel, conn, request.URL.Host)
|
|
|
+ err := proxy.httpConnectHandler(proxy.tunnel, conn, request.URL.Host)
|
|
|
if err != nil {
|
|
|
Notice(NOTICE_ALERT, "%s", ContextError(err))
|
|
|
}
|
|
|
@@ -121,7 +127,7 @@ func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *h
|
|
|
response, err := proxy.httpRelay.RoundTrip(request)
|
|
|
if err != nil {
|
|
|
Notice(NOTICE_ALERT, "%s", ContextError(err))
|
|
|
- http.Error(responseWriter, "", http.StatusInternalServerError)
|
|
|
+ forceClose(responseWriter)
|
|
|
return
|
|
|
}
|
|
|
defer response.Body.Close()
|
|
|
@@ -142,11 +148,22 @@ func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *h
|
|
|
_, err = io.Copy(responseWriter, response.Body)
|
|
|
if err != nil {
|
|
|
Notice(NOTICE_ALERT, "%s", ContextError(err))
|
|
|
- http.Error(responseWriter, "", http.StatusInternalServerError)
|
|
|
+ forceClose(responseWriter)
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// forceClose hijacks and closes persistent connections. This is used
|
|
|
+// to ensure local persistent connections into the HTTP proxy are closed
|
|
|
+// when ServeHTTP encounters an error.
|
|
|
+func forceClose(responseWriter http.ResponseWriter) {
|
|
|
+ hijacker, _ := responseWriter.(http.Hijacker)
|
|
|
+ conn, _, err := hijacker.Hijack()
|
|
|
+ if err == nil {
|
|
|
+ conn.Close()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// From // https://golang.org/src/pkg/net/http/httputil/reverseproxy.go:
|
|
|
// Hop-by-hop headers. These are removed when sent to the backend.
|
|
|
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
|
|
@@ -162,8 +179,10 @@ var hopHeaders = []string{
|
|
|
"Upgrade",
|
|
|
}
|
|
|
|
|
|
-func httpConnectHandler(tunnel *Tunnel, localHttpConn net.Conn, target string) (err error) {
|
|
|
+func (proxy *HttpProxy) httpConnectHandler(tunnel *Tunnel, localHttpConn net.Conn, target string) (err error) {
|
|
|
defer localHttpConn.Close()
|
|
|
+ defer proxy.removeOpenConn(localHttpConn)
|
|
|
+ proxy.addOpenConn(localHttpConn)
|
|
|
remoteSshForward, err := tunnel.sshClient.Dial("tcp", target)
|
|
|
if err != nil {
|
|
|
return ContextError(err)
|
|
|
@@ -177,11 +196,44 @@ func httpConnectHandler(tunnel *Tunnel, localHttpConn net.Conn, target string) (
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+// httpConnStateCallback is called by http.Server when the state of a local->proxy
|
|
|
+// connection changes. Open connections are tracked so that all local->proxy persistent
|
|
|
+// connections can be closed by HttpProxy.Close()
|
|
|
+// TODO: if the HttpProxy is decoupled from a single Tunnel instance and
|
|
|
+// instead uses the "current" Tunnel, it may not be necessary to close
|
|
|
+// local persistent connections when the tunnel reconnects.
|
|
|
+func (proxy *HttpProxy) httpConnStateCallback(conn net.Conn, connState http.ConnState) {
|
|
|
+ switch connState {
|
|
|
+ case http.StateNew:
|
|
|
+ proxy.addOpenConn(conn)
|
|
|
+ case http.StateActive, http.StateIdle:
|
|
|
+ // No action
|
|
|
+ case http.StateHijacked, http.StateClosed:
|
|
|
+ proxy.removeOpenConn(conn)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (proxy *HttpProxy) addOpenConn(conn net.Conn) {
|
|
|
+ proxy.openConns[conn] = true
|
|
|
+}
|
|
|
+
|
|
|
+func (proxy *HttpProxy) removeOpenConn(conn net.Conn) {
|
|
|
+ delete(proxy.openConns, conn)
|
|
|
+}
|
|
|
+
|
|
|
+func (proxy *HttpProxy) closeOpenConns() {
|
|
|
+ for conn, _ := range proxy.openConns {
|
|
|
+ conn.Close()
|
|
|
+ }
|
|
|
+ proxy.openConns = make(map[net.Conn]bool)
|
|
|
+}
|
|
|
+
|
|
|
func (proxy *HttpProxy) serveHttpRequests() {
|
|
|
defer proxy.listener.Close()
|
|
|
defer proxy.waitGroup.Done()
|
|
|
httpServer := &http.Server{
|
|
|
- Handler: proxy,
|
|
|
+ Handler: proxy,
|
|
|
+ ConnState: proxy.httpConnStateCallback,
|
|
|
}
|
|
|
// Note: will be interrupted by listener.Close() call made by proxy.Close()
|
|
|
err := httpServer.Serve(proxy.listener)
|
|
|
@@ -191,7 +243,6 @@ func (proxy *HttpProxy) serveHttpRequests() {
|
|
|
default:
|
|
|
}
|
|
|
Notice(NOTICE_ALERT, "%s", ContextError(err))
|
|
|
- return
|
|
|
}
|
|
|
Notice(NOTICE_HTTP_PROXY, "HTTP proxy stopped")
|
|
|
}
|