Преглед изворни кода

Fix: cancel timers when no longer needed

Always call Timer.Stop to ensure timer resources are cleaned up
when a timer or Timer.AfterFunc is no longer needed. Don't use
time.After, as it provides no way to cancel its timer.

One serious case was Controller.connectedReporter, which started
but never canceled a timer after every connected event. Since the
timer duration in that case is 24 hours, many unnecessary timer
objects would accumulate when frequently reconnecting.
Rod Hynes пре 8 година
родитељ
комит
c768e570dd
5 измењених фајлова са 48 додато и 21 уклоњено
  1. 22 11
      psiphon/controller.go
  2. 2 0
      psiphon/meekConn.go
  3. 7 2
      psiphon/server/tunnelServer.go
  4. 2 1
      psiphon/tlsDialer.go
  5. 15 7
      psiphon/tunnel.go

+ 22 - 11
psiphon/controller.go

@@ -401,10 +401,11 @@ fetcherLoop:
 
 			NoticeAlert("failed to fetch %s remote server list: %s", name, err)
 
-			timeout := time.After(retryPeriod)
+			timer := time.NewTimer(retryPeriod)
 			select {
-			case <-timeout:
+			case <-timer.C:
 			case <-controller.shutdownBroadcast:
+				timer.Stop()
 				break fetcherLoop
 			}
 		}
@@ -421,10 +422,12 @@ fetcherLoop:
 func (controller *Controller) establishTunnelWatcher() {
 	defer controller.runWaitGroup.Done()
 
-	timeout := time.After(
+	timer := time.NewTimer(
 		time.Duration(*controller.config.EstablishTunnelTimeoutSeconds) * time.Second)
+	defer timer.Stop()
+
 	select {
-	case <-timeout:
+	case <-timer.C:
 		if !controller.hasEstablishedOnce() {
 			NoticeAlert("failed to establish tunnel before timeout")
 			controller.SignalComponentFailure()
@@ -468,13 +471,14 @@ loop:
 		} else {
 			duration = PSIPHON_API_CONNECTED_REQUEST_RETRY_PERIOD
 		}
-		timeout := time.After(duration)
+		timer := time.NewTimer(duration)
 		select {
 		case <-controller.signalReportConnected:
-		case <-timeout:
+			timer.Stop()
+		case <-timer.C:
 			// Make another connected request
-
 		case <-controller.shutdownBroadcast:
+			timer.Stop()
 			break loop
 		}
 	}
@@ -570,11 +574,12 @@ downloadLoop:
 
 			NoticeAlert("failed to download upgrade: %s", err)
 
-			timeout := time.After(
+			timer := time.NewTimer(
 				time.Duration(*controller.config.DownloadUpgradeRetryPeriodSeconds) * time.Second)
 			select {
-			case <-timeout:
+			case <-timer.C:
 			case <-controller.shutdownBroadcast:
+				timer.Stop()
 				break downloadLoop
 			}
 		}
@@ -1340,8 +1345,10 @@ loop:
 				case <-timer.C:
 				case <-controller.serverAffinityDoneBroadcast:
 				case <-controller.stopEstablishingBroadcast:
+					timer.Stop()
 					break loop
 				case <-controller.shutdownBroadcast:
+					timer.Stop()
 					break loop
 				}
 			} else if controller.config.StaggerConnectionWorkersMilliseconds != 0 {
@@ -1353,8 +1360,10 @@ loop:
 				select {
 				case <-timer.C:
 				case <-controller.stopEstablishingBroadcast:
+					timer.Stop()
 					break loop
 				case <-controller.shutdownBroadcast:
+					timer.Stop()
 					break loop
 				}
 			}
@@ -1396,14 +1405,16 @@ loop:
 		// network conditions to change. Also allows for fetch remote to complete,
 		// in typical conditions (it isn't strictly necessary to wait for this, there will
 		// be more rounds if required).
-		timeout := time.After(
+		timer := time.NewTimer(
 			time.Duration(*controller.config.EstablishTunnelPausePeriodSeconds) * time.Second)
 		select {
-		case <-timeout:
+		case <-timer.C:
 			// Retry iterating
 		case <-controller.stopEstablishingBroadcast:
+			timer.Stop()
 			break loop
 		case <-controller.shutdownBroadcast:
+			timer.Stop()
 			break loop
 		}
 

+ 2 - 0
psiphon/meekConn.go

@@ -629,6 +629,7 @@ func (meek *MeekConn) relay() {
 		MIN_POLL_INTERVAL_JITTER)
 
 	timeout := time.NewTimer(interval)
+	defer timeout.Stop()
 
 	for {
 		timeout.Reset(interval)
@@ -954,6 +955,7 @@ func (meek *MeekConn) roundTrip(sendBuffer *bytes.Buffer) (int64, error) {
 		select {
 		case <-delayTimer.C:
 		case <-meek.runContext.Done():
+			delayTimer.Stop()
 			return 0, common.ContextError(err)
 		}
 

+ 7 - 2
psiphon/server/tunnelServer.go

@@ -865,8 +865,9 @@ func (sshClient *sshClient) run(clientConn net.Conn) {
 
 	resultChannel := make(chan *sshNewServerConnResult, 2)
 
+	var afterFunc *time.Timer
 	if SSH_HANDSHAKE_TIMEOUT > 0 {
-		time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
+		afterFunc = time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
 			resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
 		})
 	}
@@ -909,11 +910,15 @@ func (sshClient *sshClient) run(clientConn net.Conn) {
 	case result = <-resultChannel:
 	case <-sshClient.sshServer.shutdownBroadcast:
 		// Close() will interrupt an ongoing handshake
-		// TODO: wait for goroutine to exit before returning?
+		// TODO: wait for SSH handshake goroutines to exit before returning?
 		clientConn.Close()
 		return
 	}
 
+	if afterFunc != nil {
+		afterFunc.Stop()
+	}
+
 	if result.err != nil {
 		clientConn.Close()
 		// This is a Debug log due to noise. The handshake often fails due to I/O

+ 2 - 1
psiphon/tlsDialer.go

@@ -195,9 +195,10 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (net.Conn, err
 	var errChannel chan error
 	if config.Timeout != 0 {
 		errChannel = make(chan error, 2)
-		time.AfterFunc(config.Timeout, func() {
+		timeoutFunc := time.AfterFunc(config.Timeout, func() {
 			errChannel <- errors.New("timed out")
 		})
+		defer timeoutFunc.Stop()
 	}
 
 	dialAddr := addr

+ 15 - 7
psiphon/tunnel.go

@@ -243,6 +243,9 @@ func (tunnel *Tunnel) Activate(
 				tunnel.Close(true)
 				<-resultChannel
 			}
+
+			timer.Stop()
+
 		} else {
 
 			select {
@@ -319,10 +322,10 @@ func (tunnel *Tunnel) Close(isDiscarded bool) {
 		// precedence over the PSIPHON_API_SERVER_TIMEOUT http.Client.Timeout
 		// value set in makePsiphonHttpsClient.
 		if isActivated {
-			timer := time.AfterFunc(TUNNEL_OPERATE_SHUTDOWN_TIMEOUT, func() { tunnel.conn.Close() })
+			afterFunc := time.AfterFunc(TUNNEL_OPERATE_SHUTDOWN_TIMEOUT, func() { tunnel.conn.Close() })
 			close(tunnel.shutdownOperateBroadcast)
 			tunnel.operateWaitGroup.Wait()
-			timer.Stop()
+			afterFunc.Stop()
 		}
 		tunnel.sshClient.Close()
 		// tunnel.conn.Close() may get called multiple times, which is allowed.
@@ -385,9 +388,10 @@ func (tunnel *Tunnel) Dial(
 	}
 	resultChannel := make(chan *tunnelDialResult, 2)
 	if *tunnel.config.TunnelPortForwardDialTimeoutSeconds > 0 {
-		time.AfterFunc(time.Duration(*tunnel.config.TunnelPortForwardDialTimeoutSeconds)*time.Second, func() {
+		afterFunc := time.AfterFunc(time.Duration(*tunnel.config.TunnelPortForwardDialTimeoutSeconds)*time.Second, func() {
 			resultChannel <- &tunnelDialResult{nil, errors.New("tunnel dial timeout")}
 		})
+		defer afterFunc.Stop()
 	}
 	go func() {
 		sshPortForwardConn, err := tunnel.sshClient.Dial("tcp", remoteAddr)
@@ -990,9 +994,10 @@ func dialSsh(
 	}
 	resultChannel := make(chan *sshNewClientResult, 2)
 	if *config.TunnelConnectTimeoutSeconds > 0 {
-		time.AfterFunc(time.Duration(*config.TunnelConnectTimeoutSeconds)*time.Second, func() {
+		afterFunc := time.AfterFunc(time.Duration(*config.TunnelConnectTimeoutSeconds)*time.Second, func() {
 			resultChannel <- &sshNewClientResult{nil, nil, errors.New("ssh dial timeout")}
 		})
+		defer afterFunc.Stop()
 	}
 
 	go func() {
@@ -1216,11 +1221,13 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 				if failCount > PSIPHON_API_CLIENT_VERIFICATION_REQUEST_MAX_RETRIES {
 					return
 				}
-				timeout := time.After(PSIPHON_API_CLIENT_VERIFICATION_REQUEST_RETRY_PERIOD)
+				timer := time.NewTimer(PSIPHON_API_CLIENT_VERIFICATION_REQUEST_RETRY_PERIOD)
 				select {
-				case <-timeout:
+				case <-timer.C:
 				case clientVerificationPayload = <-tunnel.newClientVerificationPayload:
+					timer.Stop()
 				case <-signalStopClientVerificationRequests:
+					timer.Stop()
 					return
 				}
 			}
@@ -1430,9 +1437,10 @@ func sendSshKeepAlive(
 
 	errChannel := make(chan error, 2)
 	if timeout > 0 {
-		time.AfterFunc(timeout, func() {
+		afterFunc := time.AfterFunc(timeout, func() {
 			errChannel <- errors.New("timed out")
 		})
+		defer afterFunc.Stop()
 	}
 
 	go func() {