Преглед изворни кода

Merge pull request #427 from rod-hynes/master

Fix: cancel timers when no longer needed
Rod Hynes пре 8 година
родитељ
комит
e794576bb7
5 измењених фајлова са 92 додато и 33 уклоњено
  1. 37 14
      psiphon/controller.go
  2. 28 9
      psiphon/meekConn.go
  3. 7 2
      psiphon/server/tunnelServer.go
  4. 2 1
      psiphon/tlsDialer.go
  5. 18 7
      psiphon/tunnel.go

+ 37 - 14
psiphon/controller.go

@@ -401,10 +401,11 @@ fetcherLoop:
 
 
 			NoticeAlert("failed to fetch %s remote server list: %s", name, err)
 			NoticeAlert("failed to fetch %s remote server list: %s", name, err)
 
 
-			timeout := time.After(retryPeriod)
+			timer := time.NewTimer(retryPeriod)
 			select {
 			select {
-			case <-timeout:
+			case <-timer.C:
 			case <-controller.shutdownBroadcast:
 			case <-controller.shutdownBroadcast:
+				timer.Stop()
 				break fetcherLoop
 				break fetcherLoop
 			}
 			}
 		}
 		}
@@ -421,10 +422,12 @@ fetcherLoop:
 func (controller *Controller) establishTunnelWatcher() {
 func (controller *Controller) establishTunnelWatcher() {
 	defer controller.runWaitGroup.Done()
 	defer controller.runWaitGroup.Done()
 
 
-	timeout := time.After(
+	timer := time.NewTimer(
 		time.Duration(*controller.config.EstablishTunnelTimeoutSeconds) * time.Second)
 		time.Duration(*controller.config.EstablishTunnelTimeoutSeconds) * time.Second)
+	defer timer.Stop()
+
 	select {
 	select {
-	case <-timeout:
+	case <-timer.C:
 		if !controller.hasEstablishedOnce() {
 		if !controller.hasEstablishedOnce() {
 			NoticeAlert("failed to establish tunnel before timeout")
 			NoticeAlert("failed to establish tunnel before timeout")
 			controller.SignalComponentFailure()
 			controller.SignalComponentFailure()
@@ -468,13 +471,17 @@ loop:
 		} else {
 		} else {
 			duration = PSIPHON_API_CONNECTED_REQUEST_RETRY_PERIOD
 			duration = PSIPHON_API_CONNECTED_REQUEST_RETRY_PERIOD
 		}
 		}
-		timeout := time.After(duration)
+		timer := time.NewTimer(duration)
+		doBreak := false
 		select {
 		select {
 		case <-controller.signalReportConnected:
 		case <-controller.signalReportConnected:
-		case <-timeout:
+		case <-timer.C:
 			// Make another connected request
 			// Make another connected request
-
 		case <-controller.shutdownBroadcast:
 		case <-controller.shutdownBroadcast:
+			doBreak = true
+		}
+		timer.Stop()
+		if doBreak {
 			break loop
 			break loop
 		}
 		}
 	}
 	}
@@ -570,11 +577,12 @@ downloadLoop:
 
 
 			NoticeAlert("failed to download upgrade: %s", err)
 			NoticeAlert("failed to download upgrade: %s", err)
 
 
-			timeout := time.After(
+			timer := time.NewTimer(
 				time.Duration(*controller.config.DownloadUpgradeRetryPeriodSeconds) * time.Second)
 				time.Duration(*controller.config.DownloadUpgradeRetryPeriodSeconds) * time.Second)
 			select {
 			select {
-			case <-timeout:
+			case <-timer.C:
 			case <-controller.shutdownBroadcast:
 			case <-controller.shutdownBroadcast:
+				timer.Stop()
 				break downloadLoop
 				break downloadLoop
 			}
 			}
 		}
 		}
@@ -1336,12 +1344,17 @@ loop:
 				// and the grace period has elapsed.
 				// and the grace period has elapsed.
 
 
 				timer := time.NewTimer(ESTABLISH_TUNNEL_SERVER_AFFINITY_GRACE_PERIOD)
 				timer := time.NewTimer(ESTABLISH_TUNNEL_SERVER_AFFINITY_GRACE_PERIOD)
+				doBreak := false
 				select {
 				select {
 				case <-timer.C:
 				case <-timer.C:
 				case <-controller.serverAffinityDoneBroadcast:
 				case <-controller.serverAffinityDoneBroadcast:
 				case <-controller.stopEstablishingBroadcast:
 				case <-controller.stopEstablishingBroadcast:
-					break loop
+					doBreak = true
 				case <-controller.shutdownBroadcast:
 				case <-controller.shutdownBroadcast:
+					doBreak = true
+				}
+				timer.Stop()
+				if doBreak {
 					break loop
 					break loop
 				}
 				}
 			} else if controller.config.StaggerConnectionWorkersMilliseconds != 0 {
 			} else if controller.config.StaggerConnectionWorkersMilliseconds != 0 {
@@ -1350,11 +1363,16 @@ loop:
 
 
 				timer := time.NewTimer(time.Millisecond * time.Duration(
 				timer := time.NewTimer(time.Millisecond * time.Duration(
 					controller.config.StaggerConnectionWorkersMilliseconds))
 					controller.config.StaggerConnectionWorkersMilliseconds))
+				doBreak := false
 				select {
 				select {
 				case <-timer.C:
 				case <-timer.C:
 				case <-controller.stopEstablishingBroadcast:
 				case <-controller.stopEstablishingBroadcast:
-					break loop
+					doBreak = true
 				case <-controller.shutdownBroadcast:
 				case <-controller.shutdownBroadcast:
+					doBreak = true
+				}
+				timer.Stop()
+				if doBreak {
 					break loop
 					break loop
 				}
 				}
 			}
 			}
@@ -1396,14 +1414,19 @@ loop:
 		// network conditions to change. Also allows for fetch remote to complete,
 		// network conditions to change. Also allows for fetch remote to complete,
 		// in typical conditions (it isn't strictly necessary to wait for this, there will
 		// in typical conditions (it isn't strictly necessary to wait for this, there will
 		// be more rounds if required).
 		// be more rounds if required).
-		timeout := time.After(
+		timer := time.NewTimer(
 			time.Duration(*controller.config.EstablishTunnelPausePeriodSeconds) * time.Second)
 			time.Duration(*controller.config.EstablishTunnelPausePeriodSeconds) * time.Second)
+		doBreak := false
 		select {
 		select {
-		case <-timeout:
+		case <-timer.C:
 			// Retry iterating
 			// Retry iterating
 		case <-controller.stopEstablishingBroadcast:
 		case <-controller.stopEstablishingBroadcast:
-			break loop
+			doBreak = true
 		case <-controller.shutdownBroadcast:
 		case <-controller.shutdownBroadcast:
+			doBreak = true
+		}
+		timer.Stop()
+		if doBreak {
 			break loop
 			break loop
 		}
 		}
 
 

+ 28 - 9
psiphon/meekConn.go

@@ -629,6 +629,7 @@ func (meek *MeekConn) relay() {
 		MIN_POLL_INTERVAL_JITTER)
 		MIN_POLL_INTERVAL_JITTER)
 
 
 	timeout := time.NewTimer(interval)
 	timeout := time.NewTimer(interval)
+	defer timeout.Stop()
 
 
 	for {
 	for {
 		timeout.Reset(interval)
 		timeout.Reset(interval)
@@ -723,14 +724,19 @@ func (meek *MeekConn) relay() {
 // RoundTrip has called Close and will no longer use the buffer.
 // RoundTrip has called Close and will no longer use the buffer.
 // See: https://golang.org/pkg/net/http/#RoundTripper
 // See: https://golang.org/pkg/net/http/#RoundTripper
 type readCloseSignaller struct {
 type readCloseSignaller struct {
-	reader io.Reader
-	closed chan struct{}
+	context context.Context
+	reader  io.Reader
+	closed  chan struct{}
 }
 }
 
 
-func NewReadCloseSignaller(reader io.Reader) *readCloseSignaller {
+func NewReadCloseSignaller(
+	context context.Context,
+	reader io.Reader) *readCloseSignaller {
+
 	return &readCloseSignaller{
 	return &readCloseSignaller{
-		reader: reader,
-		closed: make(chan struct{}, 1),
+		context: context,
+		reader:  reader,
+		closed:  make(chan struct{}, 1),
 	}
 	}
 }
 }
 
 
@@ -746,8 +752,13 @@ func (r *readCloseSignaller) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func (r *readCloseSignaller) AwaitClosed() {
-	<-r.closed
+func (r *readCloseSignaller) AwaitClosed() bool {
+	select {
+	case <-r.context.Done():
+	case <-r.closed:
+		return true
+	}
+	return false
 }
 }
 
 
 // roundTrip configures and makes the actual HTTP POST request
 // roundTrip configures and makes the actual HTTP POST request
@@ -816,7 +827,7 @@ func (meek *MeekConn) roundTrip(sendBuffer *bytes.Buffer) (int64, error) {
 			// still reading the current round trip response. signaller provides
 			// still reading the current round trip response. signaller provides
 			// the hook for awaiting RoundTrip's call to Close.
 			// the hook for awaiting RoundTrip's call to Close.
 
 
-			signaller = NewReadCloseSignaller(bytes.NewReader(sendBuffer.Bytes()))
+			signaller = NewReadCloseSignaller(meek.runContext, bytes.NewReader(sendBuffer.Bytes()))
 			requestBody = signaller
 			requestBody = signaller
 			contentLength = sendBuffer.Len()
 			contentLength = sendBuffer.Len()
 		}
 		}
@@ -864,7 +875,14 @@ func (meek *MeekConn) roundTrip(sendBuffer *bytes.Buffer) (int64, error) {
 		// subsequently replace sendBuffer in both the success and
 		// subsequently replace sendBuffer in both the success and
 		// error cases.
 		// error cases.
 		if signaller != nil {
 		if signaller != nil {
-			signaller.AwaitClosed()
+			if !signaller.AwaitClosed() {
+				// AwaitClosed encountered Done(). Abort immediately. Do not
+				// replace sendBuffer, as we cannot be certain RoundTrip is
+				// done with it. MeekConn.Write will exit on Done and not hang
+				// awaiting sendBuffer.
+				sendBuffer = nil
+				return 0, common.ContextError(errors.New("meek connection has closed"))
+			}
 		}
 		}
 
 
 		if err != nil {
 		if err != nil {
@@ -954,6 +972,7 @@ func (meek *MeekConn) roundTrip(sendBuffer *bytes.Buffer) (int64, error) {
 		select {
 		select {
 		case <-delayTimer.C:
 		case <-delayTimer.C:
 		case <-meek.runContext.Done():
 		case <-meek.runContext.Done():
+			delayTimer.Stop()
 			return 0, common.ContextError(err)
 			return 0, common.ContextError(err)
 		}
 		}
 
 

+ 7 - 2
psiphon/server/tunnelServer.go

@@ -865,8 +865,9 @@ func (sshClient *sshClient) run(clientConn net.Conn) {
 
 
 	resultChannel := make(chan *sshNewServerConnResult, 2)
 	resultChannel := make(chan *sshNewServerConnResult, 2)
 
 
+	var afterFunc *time.Timer
 	if SSH_HANDSHAKE_TIMEOUT > 0 {
 	if SSH_HANDSHAKE_TIMEOUT > 0 {
-		time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
+		afterFunc = time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
 			resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
 			resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
 		})
 		})
 	}
 	}
@@ -909,11 +910,15 @@ func (sshClient *sshClient) run(clientConn net.Conn) {
 	case result = <-resultChannel:
 	case result = <-resultChannel:
 	case <-sshClient.sshServer.shutdownBroadcast:
 	case <-sshClient.sshServer.shutdownBroadcast:
 		// Close() will interrupt an ongoing handshake
 		// Close() will interrupt an ongoing handshake
-		// TODO: wait for goroutine to exit before returning?
+		// TODO: wait for SSH handshake goroutines to exit before returning?
 		clientConn.Close()
 		clientConn.Close()
 		return
 		return
 	}
 	}
 
 
+	if afterFunc != nil {
+		afterFunc.Stop()
+	}
+
 	if result.err != nil {
 	if result.err != nil {
 		clientConn.Close()
 		clientConn.Close()
 		// This is a Debug log due to noise. The handshake often fails due to I/O
 		// This is a Debug log due to noise. The handshake often fails due to I/O

+ 2 - 1
psiphon/tlsDialer.go

@@ -195,9 +195,10 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (net.Conn, err
 	var errChannel chan error
 	var errChannel chan error
 	if config.Timeout != 0 {
 	if config.Timeout != 0 {
 		errChannel = make(chan error, 2)
 		errChannel = make(chan error, 2)
-		time.AfterFunc(config.Timeout, func() {
+		timeoutFunc := time.AfterFunc(config.Timeout, func() {
 			errChannel <- errors.New("timed out")
 			errChannel <- errors.New("timed out")
 		})
 		})
+		defer timeoutFunc.Stop()
 	}
 	}
 
 
 	dialAddr := addr
 	dialAddr := addr

+ 18 - 7
psiphon/tunnel.go

@@ -243,6 +243,9 @@ func (tunnel *Tunnel) Activate(
 				tunnel.Close(true)
 				tunnel.Close(true)
 				<-resultChannel
 				<-resultChannel
 			}
 			}
+
+			timer.Stop()
+
 		} else {
 		} else {
 
 
 			select {
 			select {
@@ -319,10 +322,10 @@ func (tunnel *Tunnel) Close(isDiscarded bool) {
 		// precedence over the PSIPHON_API_SERVER_TIMEOUT http.Client.Timeout
 		// precedence over the PSIPHON_API_SERVER_TIMEOUT http.Client.Timeout
 		// value set in makePsiphonHttpsClient.
 		// value set in makePsiphonHttpsClient.
 		if isActivated {
 		if isActivated {
-			timer := time.AfterFunc(TUNNEL_OPERATE_SHUTDOWN_TIMEOUT, func() { tunnel.conn.Close() })
+			afterFunc := time.AfterFunc(TUNNEL_OPERATE_SHUTDOWN_TIMEOUT, func() { tunnel.conn.Close() })
 			close(tunnel.shutdownOperateBroadcast)
 			close(tunnel.shutdownOperateBroadcast)
 			tunnel.operateWaitGroup.Wait()
 			tunnel.operateWaitGroup.Wait()
-			timer.Stop()
+			afterFunc.Stop()
 		}
 		}
 		tunnel.sshClient.Close()
 		tunnel.sshClient.Close()
 		// tunnel.conn.Close() may get called multiple times, which is allowed.
 		// tunnel.conn.Close() may get called multiple times, which is allowed.
@@ -385,9 +388,10 @@ func (tunnel *Tunnel) Dial(
 	}
 	}
 	resultChannel := make(chan *tunnelDialResult, 2)
 	resultChannel := make(chan *tunnelDialResult, 2)
 	if *tunnel.config.TunnelPortForwardDialTimeoutSeconds > 0 {
 	if *tunnel.config.TunnelPortForwardDialTimeoutSeconds > 0 {
-		time.AfterFunc(time.Duration(*tunnel.config.TunnelPortForwardDialTimeoutSeconds)*time.Second, func() {
+		afterFunc := time.AfterFunc(time.Duration(*tunnel.config.TunnelPortForwardDialTimeoutSeconds)*time.Second, func() {
 			resultChannel <- &tunnelDialResult{nil, errors.New("tunnel dial timeout")}
 			resultChannel <- &tunnelDialResult{nil, errors.New("tunnel dial timeout")}
 		})
 		})
+		defer afterFunc.Stop()
 	}
 	}
 	go func() {
 	go func() {
 		sshPortForwardConn, err := tunnel.sshClient.Dial("tcp", remoteAddr)
 		sshPortForwardConn, err := tunnel.sshClient.Dial("tcp", remoteAddr)
@@ -990,9 +994,10 @@ func dialSsh(
 	}
 	}
 	resultChannel := make(chan *sshNewClientResult, 2)
 	resultChannel := make(chan *sshNewClientResult, 2)
 	if *config.TunnelConnectTimeoutSeconds > 0 {
 	if *config.TunnelConnectTimeoutSeconds > 0 {
-		time.AfterFunc(time.Duration(*config.TunnelConnectTimeoutSeconds)*time.Second, func() {
+		afterFunc := time.AfterFunc(time.Duration(*config.TunnelConnectTimeoutSeconds)*time.Second, func() {
 			resultChannel <- &sshNewClientResult{nil, nil, errors.New("ssh dial timeout")}
 			resultChannel <- &sshNewClientResult{nil, nil, errors.New("ssh dial timeout")}
 		})
 		})
+		defer afterFunc.Stop()
 	}
 	}
 
 
 	go func() {
 	go func() {
@@ -1216,11 +1221,16 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 				if failCount > PSIPHON_API_CLIENT_VERIFICATION_REQUEST_MAX_RETRIES {
 				if failCount > PSIPHON_API_CLIENT_VERIFICATION_REQUEST_MAX_RETRIES {
 					return
 					return
 				}
 				}
-				timeout := time.After(PSIPHON_API_CLIENT_VERIFICATION_REQUEST_RETRY_PERIOD)
+				timer := time.NewTimer(PSIPHON_API_CLIENT_VERIFICATION_REQUEST_RETRY_PERIOD)
+				doReturn := false
 				select {
 				select {
-				case <-timeout:
+				case <-timer.C:
 				case clientVerificationPayload = <-tunnel.newClientVerificationPayload:
 				case clientVerificationPayload = <-tunnel.newClientVerificationPayload:
 				case <-signalStopClientVerificationRequests:
 				case <-signalStopClientVerificationRequests:
+					doReturn = true
+				}
+				timer.Stop()
+				if doReturn {
 					return
 					return
 				}
 				}
 			}
 			}
@@ -1430,9 +1440,10 @@ func sendSshKeepAlive(
 
 
 	errChannel := make(chan error, 2)
 	errChannel := make(chan error, 2)
 	if timeout > 0 {
 	if timeout > 0 {
-		time.AfterFunc(timeout, func() {
+		afterFunc := time.AfterFunc(timeout, func() {
 			errChannel <- errors.New("timed out")
 			errChannel <- errors.New("timed out")
 		})
 		})
+		defer afterFunc.Stop()
 	}
 	}
 
 
 	go func() {
 	go func() {