Просмотр исходного кода

improve tunnel stop logic

Encapsulate the stop logic in exactly one place (to make it easier to reason about, etc.).

Make stop wait for the tunnel to actually stop before returning.
Adam Pritchard 1 год назад
Родитель
Commit
b6982b7d55
2 измененных файлов с 24 добавлено и 23 удалено
  1. 23 22
      ClientLibrary/clientlib/clientlib.go
  2. 1 1
      ClientLibrary/clientlib/clientlib_test.go

+ 23 - 22
ClientLibrary/clientlib/clientlib.go

@@ -79,7 +79,7 @@ type Parameters struct {
 // retrieving proxy ports.
 type PsiphonTunnel struct {
 	mu                          sync.Mutex
-	cancelTunnelCtx             context.CancelCauseFunc
+	stop                        func()
 	embeddedServerListWaitGroup sync.WaitGroup
 	controllerWaitGroup         sync.WaitGroup
 	controllerDial              func(string, net.Conn) (net.Conn, error)
@@ -251,20 +251,29 @@ func StartTunnel(
 		return nil, fmt.Errorf("failed to open data store: %w", err)
 	}
 
-	// Make sure we close the datastore in case of error
+	// Create a cancelable context that will be used for stopping the tunnel
+	tunnelCtx, cancelTunnelCtx := context.WithCancel(ctx)
+
+	// Because the tunnel object is only returned on success, there are at least two
+	// problems that we don't need to guard against:
+	// 1. This stop function being called twice: once by the error-defer below and once
+	//    via the tunnel's Stop method.
+	// 2. This stop function being called via the tunnel's Stop method before the
+	//    WaitGroups are incremented (which would be a race condition).
+	tunnel.stop = func() {
+		cancelTunnelCtx()
+		tunnel.embeddedServerListWaitGroup.Wait()
+		tunnel.controllerWaitGroup.Wait()
+		psiphon.CloseDataStore()
+		started.Store(false)
+	}
+
 	defer func() {
 		if retErr != nil {
-			tunnel.controllerWaitGroup.Wait()
-			tunnel.embeddedServerListWaitGroup.Wait()
-			psiphon.CloseDataStore()
-			started.Store(false)
+			tunnel.stop()
 		}
 	}()
 
-	// Create a cancelable context that will be used for stopping the tunnel
-	var tunnelCtx context.Context
-	tunnelCtx, tunnel.cancelTunnelCtx = context.WithCancelCause(ctx)
-
 	// If specified, the embedded server list is loaded and stored. When there
 	// are no server candidates at all, we wait for this import to complete
 	// before starting the Psiphon controller. Otherwise, we import while
@@ -299,7 +308,6 @@ func StartTunnel(
 	// Create the Psiphon controller
 	controller, err := psiphon.NewController(config)
 	if err != nil {
-		tunnel.cancelTunnelCtx(fmt.Errorf("psiphon.NewController failed: %w", err))
 		return nil, fmt.Errorf("psiphon.NewController failed: %w", err)
 	}
 
@@ -338,7 +346,6 @@ func StartTunnel(
 	case <-connectedCh:
 		return tunnel, nil
 	case err := <-erroredCh:
-		tunnel.cancelTunnelCtx(fmt.Errorf("tunnel establishment failed: %w", err))
 		if err != ErrTimeout {
 			err = fmt.Errorf("tunnel start produced error: %w", err)
 		}
@@ -351,22 +358,16 @@ func StartTunnel(
 // It is safe to call concurrently with Dial and with itself.
 func (tunnel *PsiphonTunnel) Stop() {
 	tunnel.mu.Lock()
-	cancelTunnelCtx := tunnel.cancelTunnelCtx
-	tunnel.cancelTunnelCtx = nil
+	stop := tunnel.stop
+	tunnel.stop = nil
 	tunnel.controllerDial = nil
 	tunnel.mu.Unlock()
 
-	if cancelTunnelCtx == nil {
+	if stop == nil {
 		return
 	}
 
-	cancelTunnelCtx(fmt.Errorf("Stop called"))
-	tunnel.embeddedServerListWaitGroup.Wait()
-	tunnel.controllerWaitGroup.Wait()
-	psiphon.CloseDataStore()
-
-	// Reset the started flag so that StartTunnel can be called again
-	started.Store(false)
+	stop()
 }
 
 // Dial connects to the specified address through the Psiphon tunnel.

+ 1 - 1
ClientLibrary/clientlib/clientlib_test.go

@@ -252,7 +252,7 @@ func TestStartTunnel(t *testing.T) {
 	}
 }
 
-func TestMultpleStartTunnel(t *testing.T) {
+func TestMultipleStartTunnel(t *testing.T) {
 	configJSON, err := os.ReadFile("../../psiphon/controller_test.config")
 	if err != nil {
 		// What to do if config file is not present?