فهرست منبع

make Dial and Stop threadsafe

Adam Pritchard 1 سال پیش
والد
کامیت
45e1212906
1فایلهای تغییر یافته به همراه41 افزوده شده و 26 حذف شده
  1. 41 26
      ClientLibrary/clientlib/clientlib.go

+ 41 - 26
ClientLibrary/clientlib/clientlib.go

@@ -77,9 +77,10 @@ type Parameters struct {
 // PsiphonTunnel is the tunnel object. It can be used for stopping the tunnel and
 // retrieving proxy ports.
 type PsiphonTunnel struct {
+	mu                          sync.Mutex
+	cancelTunnelCtx             context.CancelCauseFunc
 	embeddedServerListWaitGroup sync.WaitGroup
 	controllerWaitGroup         sync.WaitGroup
-	stopController              context.CancelFunc
 	controllerDial              func(string, net.Conn) (net.Conn, error)
 
 	// The port on which the HTTP proxy is running
@@ -188,10 +189,10 @@ func StartTunnel(
 		}
 	}
 
-	// Will receive a value when the tunnel has successfully connected.
-	connected := make(chan struct{}, 1)
-	// Will receive a value if an error occurs during the connection sequence.
-	errored := make(chan error, 1)
+	// Will be closed when the tunnel has successfully connected
+	connectedCh := make(chan struct{})
+	// Will receive a value if an error occurs during the connection sequence
+	erroredCh := make(chan error, 1)
 
 	// Create the tunnel object
 	tunnel := new(PsiphonTunnel)
@@ -206,7 +207,7 @@ func StartTunnel(
 				// We'll interpret it as a connection error and abort.
 				err = errors.TraceMsg(err, "failed to unmarshal notice JSON")
 				select {
-				case errored <- err:
+				case erroredCh <- err:
 				default:
 				}
 				return
@@ -220,16 +221,13 @@ func StartTunnel(
 				tunnel.SOCKSProxyPort = int(port)
 			} else if event.Type == "EstablishTunnelTimeout" {
 				select {
-				case errored <- ErrTimeout:
+				case erroredCh <- ErrTimeout:
 				default:
 				}
 			} else if event.Type == "Tunnels" {
 				count := event.Data["count"].(float64)
 				if count > 0 {
-					select {
-					case connected <- struct{}{}:
-					default:
-					}
+					close(connectedCh)
 				}
 			}
 
@@ -244,6 +242,7 @@ func StartTunnel(
 	if err != nil {
 		return nil, errors.TraceMsg(err, "failed to open data store")
 	}
+
 	// Make sure we close the datastore in case of error
 	defer func() {
 		if retErr != nil {
@@ -254,8 +253,8 @@ func StartTunnel(
 	}()
 
 	// Create a cancelable context that will be used for stopping the tunnel
-	var controllerCtx context.Context
-	controllerCtx, tunnel.stopController = context.WithCancel(ctx)
+	var tunnelCtx context.Context
+	tunnelCtx, tunnel.cancelTunnelCtx = context.WithCancelCause(ctx)
 
 	// If specified, the embedded server list is loaded and stored. When there
 	// are no server candidates at all, we wait for this import to complete
@@ -274,7 +273,7 @@ func StartTunnel(
 		defer tunnel.embeddedServerListWaitGroup.Done()
 
 		err := psiphon.ImportEmbeddedServerEntries(
-			controllerCtx,
+			tunnelCtx,
 			config,
 			"",
 			embeddedServerEntryList)
@@ -291,8 +290,7 @@ func StartTunnel(
 	// Create the Psiphon controller
 	controller, err := psiphon.NewController(config)
 	if err != nil {
-		tunnel.stopController()
-		tunnel.embeddedServerListWaitGroup.Wait()
+		tunnel.cancelTunnelCtx(fmt.Errorf("psiphon.NewController failed: %w", err))
 		return nil, errors.TraceMsg(err, "psiphon.NewController failed")
 	}
 
@@ -304,12 +302,12 @@ func StartTunnel(
 		defer tunnel.controllerWaitGroup.Done()
 
 		// Start the tunnel. Only returns on error (or internal timeout).
-		controller.Run(controllerCtx)
+		controller.Run(tunnelCtx)
 
 		// controller.Run does not exit until the goroutine that posts
 		// EstablishTunnelTimeout has terminated; so, if there was a
 		// EstablishTunnelTimeout event, ErrTimeout is guaranteed to be sent to
-		// errord before this next error and will be the StartTunnel return value.
+		// errored before this next error and will be the StartTunnel return value.
 
 		var err error
 		switch ctx.Err() {
@@ -321,17 +319,17 @@ func StartTunnel(
 			err = errors.TraceNew("controller.Run exited unexpectedly")
 		}
 		select {
-		case errored <- err:
+		case erroredCh <- err:
 		default:
 		}
 	}()
 
 	// Wait for an active tunnel or error
 	select {
-	case <-connected:
+	case <-connectedCh:
 		return tunnel, nil
-	case err := <-errored:
-		tunnel.Stop()
+	case err := <-erroredCh:
+		tunnel.cancelTunnelCtx(fmt.Errorf("tunnel establishment failed: %w", err))
 		if err != ErrTimeout {
 			err = errors.TraceMsg(err, "tunnel start produced error")
 		}
@@ -342,16 +340,33 @@ func StartTunnel(
 // Stop stops/disconnects/shuts down the tunnel. It is safe to call when not connected.
 // Not safe to call concurrently with Start.
 func (tunnel *PsiphonTunnel) Stop() {
-	if tunnel.stopController == nil {
+	tunnel.mu.Lock()
+	cancelTunnelCtx := tunnel.cancelTunnelCtx
+	tunnel.cancelTunnelCtx = nil
+	tunnel.controllerDial = nil
+	tunnel.mu.Unlock()
+
+	if cancelTunnelCtx == nil {
 		return
 	}
-	tunnel.stopController()
-	tunnel.controllerWaitGroup.Wait()
+
+	cancelTunnelCtx(fmt.Errorf("Stop called"))
 	tunnel.embeddedServerListWaitGroup.Wait()
+	tunnel.controllerWaitGroup.Wait()
 	psiphon.CloseDataStore()
 }
 
 // Dial connects to the specified address through the Psiphon tunnel.
 func (tunnel *PsiphonTunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
-	return tunnel.controllerDial(remoteAddr, nil)
+	// Ensure the dial is accessed in a thread-safe manner, without holding the lock
+	// while calling the dial function.
+	// Note that it is safe for controller.Dial to be called even after or during a tunnel
+	// shutdown (i.e., if the context has been canceled).
+	tunnel.mu.Lock()
+	dial := tunnel.controllerDial
+	tunnel.mu.Unlock()
+	if dial == nil {
+		return nil, errors.TraceNew("tunnel not started")
+	}
+	return dial(remoteAddr, nil)
 }