|
|
@@ -77,9 +77,10 @@ type Parameters struct {
|
|
|
// PsiphonTunnel is the tunnel object. It can be used for stopping the tunnel and
|
|
|
// retrieving proxy ports.
|
|
|
type PsiphonTunnel struct {
|
|
|
+ mu sync.Mutex
|
|
|
+ cancelTunnelCtx context.CancelCauseFunc
|
|
|
embeddedServerListWaitGroup sync.WaitGroup
|
|
|
controllerWaitGroup sync.WaitGroup
|
|
|
- stopController context.CancelFunc
|
|
|
controllerDial func(string, net.Conn) (net.Conn, error)
|
|
|
|
|
|
// The port on which the HTTP proxy is running
|
|
|
@@ -188,10 +189,10 @@ func StartTunnel(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // Will receive a value when the tunnel has successfully connected.
|
|
|
- connected := make(chan struct{}, 1)
|
|
|
- // Will receive a value if an error occurs during the connection sequence.
|
|
|
- errored := make(chan error, 1)
|
|
|
+ // Will be closed when the tunnel has successfully connected
|
|
|
+ connectedCh := make(chan struct{})
|
|
|
+ // Will receive a value if an error occurs during the connection sequence
|
|
|
+ erroredCh := make(chan error, 1)
|
|
|
|
|
|
// Create the tunnel object
|
|
|
tunnel := new(PsiphonTunnel)
|
|
|
@@ -206,7 +207,7 @@ func StartTunnel(
|
|
|
// We'll interpret it as a connection error and abort.
|
|
|
err = errors.TraceMsg(err, "failed to unmarshal notice JSON")
|
|
|
select {
|
|
|
- case errored <- err:
|
|
|
+ case erroredCh <- err:
|
|
|
default:
|
|
|
}
|
|
|
return
|
|
|
@@ -220,16 +221,13 @@ func StartTunnel(
|
|
|
tunnel.SOCKSProxyPort = int(port)
|
|
|
} else if event.Type == "EstablishTunnelTimeout" {
|
|
|
select {
|
|
|
- case errored <- ErrTimeout:
|
|
|
+ case erroredCh <- ErrTimeout:
|
|
|
default:
|
|
|
}
|
|
|
} else if event.Type == "Tunnels" {
|
|
|
count := event.Data["count"].(float64)
|
|
|
if count > 0 {
|
|
|
- select {
|
|
|
- case connected <- struct{}{}:
|
|
|
- default:
|
|
|
- }
|
|
|
+ close(connectedCh)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -244,6 +242,7 @@ func StartTunnel(
|
|
|
if err != nil {
|
|
|
return nil, errors.TraceMsg(err, "failed to open data store")
|
|
|
}
|
|
|
+
|
|
|
// Make sure we close the datastore in case of error
|
|
|
defer func() {
|
|
|
if retErr != nil {
|
|
|
@@ -254,8 +253,8 @@ func StartTunnel(
|
|
|
}()
|
|
|
|
|
|
// Create a cancelable context that will be used for stopping the tunnel
|
|
|
- var controllerCtx context.Context
|
|
|
- controllerCtx, tunnel.stopController = context.WithCancel(ctx)
|
|
|
+ var tunnelCtx context.Context
|
|
|
+ tunnelCtx, tunnel.cancelTunnelCtx = context.WithCancelCause(ctx)
|
|
|
|
|
|
// If specified, the embedded server list is loaded and stored. When there
|
|
|
// are no server candidates at all, we wait for this import to complete
|
|
|
@@ -274,7 +273,7 @@ func StartTunnel(
|
|
|
defer tunnel.embeddedServerListWaitGroup.Done()
|
|
|
|
|
|
err := psiphon.ImportEmbeddedServerEntries(
|
|
|
- controllerCtx,
|
|
|
+ tunnelCtx,
|
|
|
config,
|
|
|
"",
|
|
|
embeddedServerEntryList)
|
|
|
@@ -291,8 +290,7 @@ func StartTunnel(
|
|
|
// Create the Psiphon controller
|
|
|
controller, err := psiphon.NewController(config)
|
|
|
if err != nil {
|
|
|
- tunnel.stopController()
|
|
|
- tunnel.embeddedServerListWaitGroup.Wait()
|
|
|
+ tunnel.cancelTunnelCtx(fmt.Errorf("psiphon.NewController failed: %w", err))
|
|
|
return nil, errors.TraceMsg(err, "psiphon.NewController failed")
|
|
|
}
|
|
|
|
|
|
@@ -304,12 +302,12 @@ func StartTunnel(
|
|
|
defer tunnel.controllerWaitGroup.Done()
|
|
|
|
|
|
// Start the tunnel. Only returns on error (or internal timeout).
|
|
|
- controller.Run(controllerCtx)
|
|
|
+ controller.Run(tunnelCtx)
|
|
|
|
|
|
// controller.Run does not exit until the goroutine that posts
|
|
|
// EstablishTunnelTimeout has terminated; so, if there was a
|
|
|
// EstablishTunnelTimeout event, ErrTimeout is guaranteed to be sent to
|
|
|
- // errord before this next error and will be the StartTunnel return value.
|
|
|
+ // errored before this next error and will be the StartTunnel return value.
|
|
|
|
|
|
var err error
|
|
|
switch ctx.Err() {
|
|
|
@@ -321,17 +319,17 @@ func StartTunnel(
|
|
|
err = errors.TraceNew("controller.Run exited unexpectedly")
|
|
|
}
|
|
|
select {
|
|
|
- case errored <- err:
|
|
|
+ case erroredCh <- err:
|
|
|
default:
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
// Wait for an active tunnel or error
|
|
|
select {
|
|
|
- case <-connected:
|
|
|
+ case <-connectedCh:
|
|
|
return tunnel, nil
|
|
|
- case err := <-errored:
|
|
|
- tunnel.Stop()
|
|
|
+ case err := <-erroredCh:
|
|
|
+ tunnel.cancelTunnelCtx(fmt.Errorf("tunnel establishment failed: %w", err))
|
|
|
if err != ErrTimeout {
|
|
|
err = errors.TraceMsg(err, "tunnel start produced error")
|
|
|
}
|
|
|
@@ -342,16 +340,33 @@ func StartTunnel(
|
|
|
// Stop stops/disconnects/shuts down the tunnel. It is safe to call when not connected.
|
|
|
// Not safe to call concurrently with Start.
|
|
|
func (tunnel *PsiphonTunnel) Stop() {
|
|
|
- if tunnel.stopController == nil {
|
|
|
+ tunnel.mu.Lock()
|
|
|
+ cancelTunnelCtx := tunnel.cancelTunnelCtx
|
|
|
+ tunnel.cancelTunnelCtx = nil
|
|
|
+ tunnel.controllerDial = nil
|
|
|
+ tunnel.mu.Unlock()
|
|
|
+
|
|
|
+ if cancelTunnelCtx == nil {
|
|
|
return
|
|
|
}
|
|
|
- tunnel.stopController()
|
|
|
- tunnel.controllerWaitGroup.Wait()
|
|
|
+
|
|
|
+ cancelTunnelCtx(fmt.Errorf("Stop called"))
|
|
|
tunnel.embeddedServerListWaitGroup.Wait()
|
|
|
+ tunnel.controllerWaitGroup.Wait()
|
|
|
psiphon.CloseDataStore()
|
|
|
}
|
|
|
|
|
|
// Dial connects to the specified address through the Psiphon tunnel.
|
|
|
func (tunnel *PsiphonTunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
|
|
|
- return tunnel.controllerDial(remoteAddr, nil)
|
|
|
+ // Ensure the dial is accessed in a thread-safe manner, without holding the lock
|
|
|
+ // while calling the dial function.
|
|
|
+ // Note that it is safe for controller.Dial to be called even after or during a tunnel
|
|
|
+ // shutdown (i.e., if the context has been canceled).
|
|
|
+ tunnel.mu.Lock()
|
|
|
+ dial := tunnel.controllerDial
|
|
|
+ tunnel.mu.Unlock()
|
|
|
+ if dial == nil {
|
|
|
+ return nil, errors.TraceNew("tunnel not started")
|
|
|
+ }
|
|
|
+ return dial(remoteAddr, nil)
|
|
|
}
|