Browse Source

Fix: WaitGroup panic in sshClient.stop()
- the panic occurred when a port forward was created at the
same time as a different goroutine (the duplicate session
closer) stopped the client.
- rearrange code so that only sshClient.runTunnel()
creates and waits for client worker goroutines.
- also rearrange sshClient.run() and sshClient.stop():
stop() signals run() to stop, and run() is responsible
for all cleanup, including logging tunnel stats after
all port forward goroutines have terminated.

Rod Hynes 9 years ago
parent
commit
a337accb8e
1 changed files with 92 additions and 87 deletions
  1. 92 87
      psiphon/server/tunnelServer.go

+ 92 - 87
psiphon/server/tunnelServer.go

@@ -586,25 +586,24 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 
 type sshClient struct {
 	sync.Mutex
-	sshServer               *sshServer
-	tunnelProtocol          string
-	sshConn                 ssh.Conn
-	activityConn            *common.ActivityMonitoredConn
-	throttledConn           *common.ThrottledConn
-	geoIPData               GeoIPData
-	sessionID               string
-	supportsServerRequests  bool
-	handshakeState          handshakeState
-	udpChannel              ssh.Channel
-	trafficRules            TrafficRules
-	tcpTrafficState         trafficState
-	udpTrafficState         trafficState
-	qualityMetrics          qualityMetrics
-	channelHandlerWaitGroup *sync.WaitGroup
-	tcpPortForwardLRU       *common.LRUConns
-	oslClientSeedState      *osl.ClientSeedState
-	signalIssueSLOKs        chan struct{}
-	stopBroadcast           chan struct{}
+	sshServer              *sshServer
+	tunnelProtocol         string
+	sshConn                ssh.Conn
+	activityConn           *common.ActivityMonitoredConn
+	throttledConn          *common.ThrottledConn
+	geoIPData              GeoIPData
+	sessionID              string
+	supportsServerRequests bool
+	handshakeState         handshakeState
+	udpChannel             ssh.Channel
+	trafficRules           TrafficRules
+	tcpTrafficState        trafficState
+	udpTrafficState        trafficState
+	qualityMetrics         qualityMetrics
+	tcpPortForwardLRU      *common.LRUConns
+	oslClientSeedState     *osl.ClientSeedState
+	signalIssueSLOKs       chan struct{}
+	stopBroadcast          chan struct{}
 }
 
 type trafficState struct {
@@ -636,13 +635,12 @@ type handshakeState struct {
 func newSshClient(
 	sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
 	return &sshClient{
-		sshServer:               sshServer,
-		tunnelProtocol:          tunnelProtocol,
-		geoIPData:               geoIPData,
-		channelHandlerWaitGroup: new(sync.WaitGroup),
-		tcpPortForwardLRU:       common.NewLRUConns(),
-		signalIssueSLOKs:        make(chan struct{}, 1),
-		stopBroadcast:           make(chan struct{}),
+		sshServer:         sshServer,
+		tunnelProtocol:    tunnelProtocol,
+		geoIPData:         geoIPData,
+		tcpPortForwardLRU: common.NewLRUConns(),
+		signalIssueSLOKs:  make(chan struct{}, 1),
+		stopBroadcast:     make(chan struct{}),
 	}
 }
 
@@ -760,12 +758,20 @@ func (sshClient *sshClient) run(clientConn net.Conn) {
 		log.WithContext().Warning("register failed")
 		return
 	}
-	defer sshClient.sshServer.unregisterEstablishedClient(sshClient.sessionID)
 
 	sshClient.runTunnel(result.channels, result.requests)
 
 	// Note: sshServer.unregisterEstablishedClient calls sshClient.stop(),
 	// which also closes underlying transport Conn.
+
+	sshClient.sshServer.unregisterEstablishedClient(sshClient.sessionID)
+
+	sshClient.logTunnel()
+
+	// Initiate cleanup of the GeoIP session cache. To allow for post-tunnel
+	// final status requests, the lifetime of cached GeoIP records exceeds the
+	// lifetime of the sshClient.
+	sshClient.sshServer.support.GeoIPService.MarkSessionCacheToExpire(sshClient.sessionID)
 }
 
 func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
@@ -863,54 +869,13 @@ func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string
 	}
 }
 
+// stop signals the ssh connection to shut down. After stop() returns,
+// the connection has terminated but sshClient.run() may still be
+// running and in the process of exiting.
 func (sshClient *sshClient) stop() {
 
 	sshClient.sshConn.Close()
 	sshClient.sshConn.Wait()
-
-	close(sshClient.stopBroadcast)
-	sshClient.channelHandlerWaitGroup.Wait()
-
-	// Note: reporting duration based on last confirmed data transfer, which
-	// is reads for sshClient.activityConn.GetActiveDuration(), and not
-	// connection closing is important for protocols such as meek. For
-	// meek, the connection remains open until the HTTP session expires,
-	// which may be some time after the tunnel has closed. (The meek
-	// protocol has no allowance for signalling payload EOF, and even if
-	// it did the client may not have the opportunity to send a final
-	// request with an EOF flag set.)
-
-	sshClient.Lock()
-
-	logFields := getRequestLogFields(
-		sshClient.sshServer.support,
-		"server_tunnel",
-		sshClient.geoIPData,
-		sshClient.handshakeState.apiParams,
-		baseRequestParams)
-
-	logFields["handshake_completed"] = sshClient.handshakeState.completed
-	logFields["start_time"] = sshClient.activityConn.GetStartTime()
-	logFields["duration"] = sshClient.activityConn.GetActiveDuration() / time.Millisecond
-	logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
-	logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
-	logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
-	logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount
-	logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp
-	logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown
-	logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
-	logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
-
-	sessionID := sshClient.sessionID
-
-	sshClient.Unlock()
-
-	// Initiate cleanup of the GeoIP session cache. To allow for post-tunnel
-	// final status requests, the lifetime of cached GeoIP records exceeds the
-	// lifetime of the sshClient.
-	sshClient.sshServer.support.GeoIPService.MarkSessionCacheToExpire(sessionID)
-
-	log.LogRawFieldsWithTimestamp(logFields)
 }
 
 // runTunnel handles/dispatches new channel and new requests from the client.
@@ -919,13 +884,11 @@ func (sshClient *sshClient) stop() {
 func (sshClient *sshClient) runTunnel(
 	channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
 
-	stopBroadcast := make(chan struct{})
+	waitGroup := new(sync.WaitGroup)
 
-	requestsWaitGroup := new(sync.WaitGroup)
-
-	requestsWaitGroup.Add(1)
+	waitGroup.Add(1)
 	go func() {
-		defer requestsWaitGroup.Done()
+		defer waitGroup.Done()
 
 		for request := range requests {
 
@@ -959,10 +922,10 @@ func (sshClient *sshClient) runTunnel(
 	}()
 
 	if sshClient.supportsServerRequests {
-		requestsWaitGroup.Add(1)
+		waitGroup.Add(1)
 		go func() {
-			defer requestsWaitGroup.Done()
-			sshClient.runOSLSender(stopBroadcast)
+			defer waitGroup.Done()
+			sshClient.runOSLSender()
 		}()
 	}
 
@@ -974,23 +937,66 @@ func (sshClient *sshClient) runTunnel(
 		}
 
 		// process each port forward concurrently
-		sshClient.channelHandlerWaitGroup.Add(1)
-		go sshClient.handleNewPortForwardChannel(newChannel)
+		waitGroup.Add(1)
+		go func() {
+			defer waitGroup.Done()
+			sshClient.handleNewPortForwardChannel(newChannel)
+		}()
 	}
 
-	close(stopBroadcast)
+	// The channel loop is interrupted by a client
+	// disconnect or by calling sshClient.stop().
+
+	close(sshClient.stopBroadcast)
+
+	waitGroup.Wait()
+}
+
+func (sshClient *sshClient) logTunnel() {
+
+	// Note: reporting duration based on last confirmed data transfer, which
+	// is reads for sshClient.activityConn.GetActiveDuration(), and not
+	// connection closing is important for protocols such as meek. For
+	// meek, the connection remains open until the HTTP session expires,
+	// which may be some time after the tunnel has closed. (The meek
+	// protocol has no allowance for signalling payload EOF, and even if
+	// it did the client may not have the opportunity to send a final
+	// request with an EOF flag set.)
+
+	sshClient.Lock()
+
+	logFields := getRequestLogFields(
+		sshClient.sshServer.support,
+		"server_tunnel",
+		sshClient.geoIPData,
+		sshClient.handshakeState.apiParams,
+		baseRequestParams)
+
+	logFields["handshake_completed"] = sshClient.handshakeState.completed
+	logFields["start_time"] = sshClient.activityConn.GetStartTime()
+	logFields["duration"] = sshClient.activityConn.GetActiveDuration() / time.Millisecond
+	logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
+	logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
+	logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
+	logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount
+	logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp
+	logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown
+	logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
+	logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
+
+	sshClient.Unlock()
 
-	requestsWaitGroup.Wait()
+	log.LogRawFieldsWithTimestamp(logFields)
 }
 
-func (sshClient *sshClient) runOSLSender(stopBroadcast <-chan struct{}) {
+func (sshClient *sshClient) runOSLSender() {
 
 	for {
 		// Await a signal that there are SLOKs to send
 		// TODO: use reflect.SelectCase, and optionally await timer here?
 		select {
 		case <-sshClient.signalIssueSLOKs:
-		case <-stopBroadcast:
+		case <-sshClient.stopBroadcast:
 			return
 		}
 
@@ -1008,7 +1014,7 @@ func (sshClient *sshClient) runOSLSender(stopBroadcast <-chan struct{}) {
 			select {
 			case <-retryTimer.C:
 			case <-sshClient.signalIssueSLOKs:
-			case <-stopBroadcast:
+			case <-sshClient.stopBroadcast:
 				retryTimer.Stop()
 				return
 			}
@@ -1070,7 +1076,6 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason s
 }
 
 func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
-	defer sshClient.channelHandlerWaitGroup.Done()
 
 	// http://tools.ietf.org/html/rfc4254#section-7.2
 	var directTcpipExtraData struct {