Explorar o código

only allow single tunnel instance

Adam Pritchard hai 1 ano
pai
achega
007d81eeb2
Modificáronse 2 ficheiros con 83 adicións e 10 borrados
  1. 25 10
      ClientLibrary/clientlib/clientlib.go
  2. 58 0
      ClientLibrary/clientlib/clientlib_test.go

+ 25 - 10
ClientLibrary/clientlib/clientlib.go

@@ -27,6 +27,7 @@ import (
 	"net"
 	"path/filepath"
 	"sync"
+	"sync/atomic"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
@@ -104,6 +105,10 @@ type NoticeEvent struct {
 
 // ErrTimeout is returned when the tunnel establishment attempt fails due to timeout
 var ErrTimeout = std_errors.New("clientlib: tunnel establishment timeout")
+var errMultipleStart = std_errors.New("clientlib: StartTunnel called multiple times")
+
+// started is used to ensure StartTunnel is called only once
+var started atomic.Bool
 
 // StartTunnel establishes a Psiphon tunnel. It returns an error if the establishment
 // was not successful. If the returned error is nil, the returned tunnel can be used
@@ -131,9 +136,13 @@ func StartTunnel(
 	paramsDelta ParametersDelta,
 	noticeReceiver func(NoticeEvent)) (retTunnel *PsiphonTunnel, retErr error) {
 
+	if !started.CompareAndSwap(false, true) {
+		return nil, errMultipleStart
+	}
+
 	config, err := psiphon.LoadConfig(configJSON)
 	if err != nil {
-		return nil, errors.TraceMsg(err, "failed to load config file")
+		return nil, fmt.Errorf("failed to load config file: %w", err)
 	}
 
 	// Use params.DataRootDirectory to set related config values.
@@ -177,15 +186,14 @@ func StartTunnel(
 	// or attempting to connect.
 	err = config.Commit(true)
 	if err != nil {
-		return nil, errors.TraceMsg(err, "config.Commit failed")
+		return nil, fmt.Errorf("config.Commit failed: %w", err)
 	}
 
 	// If supplied, apply the parameters delta
 	if len(paramsDelta) > 0 {
 		err = config.SetParameters("", false, paramsDelta)
 		if err != nil {
-			return nil, errors.TraceMsg(
-				err, fmt.Sprintf("SetParameters failed for delta: %v", paramsDelta))
+			return nil, fmt.Errorf("SetParameters failed for delta: %v; %w", paramsDelta, err)
 		}
 	}
 
@@ -205,7 +213,7 @@ func StartTunnel(
 			if err != nil {
 				// This is unexpected and probably indicates something fatal has occurred.
 				// We'll interpret it as a connection error and abort.
-				err = errors.TraceMsg(err, "failed to unmarshal notice JSON")
+				err = fmt.Errorf("failed to unmarshal notice JSON: %w", err)
 				select {
 				case erroredCh <- err:
 				default:
@@ -240,7 +248,7 @@ func StartTunnel(
 
 	err = psiphon.OpenDataStore(config)
 	if err != nil {
-		return nil, errors.TraceMsg(err, "failed to open data store")
+		return nil, fmt.Errorf("failed to open data store: %w", err)
 	}
 
 	// Make sure we close the datastore in case of error
@@ -249,6 +257,7 @@ func StartTunnel(
 			tunnel.controllerWaitGroup.Wait()
 			tunnel.embeddedServerListWaitGroup.Wait()
 			psiphon.CloseDataStore()
+			started.Store(false)
 		}
 	}()
 
@@ -291,7 +300,7 @@ func StartTunnel(
 	controller, err := psiphon.NewController(config)
 	if err != nil {
 		tunnel.cancelTunnelCtx(fmt.Errorf("psiphon.NewController failed: %w", err))
-		return nil, errors.TraceMsg(err, "psiphon.NewController failed")
+		return nil, fmt.Errorf("psiphon.NewController failed: %w", err)
 	}
 
 	tunnel.controllerDial = controller.Dial
@@ -331,14 +340,15 @@ func StartTunnel(
 	case err := <-erroredCh:
 		tunnel.cancelTunnelCtx(fmt.Errorf("tunnel establishment failed: %w", err))
 		if err != ErrTimeout {
-			err = errors.TraceMsg(err, "tunnel start produced error")
+			err = fmt.Errorf("tunnel start produced error: %w", err)
 		}
 		return nil, err
 	}
 }
 
-// Stop stops/disconnects/shuts down the tunnel. It is safe to call when not connected.
-// Not safe to call concurrently with Start.
+// Stop stops/disconnects/shuts down the tunnel.
+// It is safe to call Stop multiple times.
+// It is safe to call concurrently with Dial and with itself.
 func (tunnel *PsiphonTunnel) Stop() {
 	tunnel.mu.Lock()
 	cancelTunnelCtx := tunnel.cancelTunnelCtx
@@ -354,9 +364,14 @@ func (tunnel *PsiphonTunnel) Stop() {
 	tunnel.embeddedServerListWaitGroup.Wait()
 	tunnel.controllerWaitGroup.Wait()
 	psiphon.CloseDataStore()
+
+	// Reset the started flag so that StartTunnel can be called again
+	started.Store(false)
 }
 
 // Dial connects to the specified address through the Psiphon tunnel.
+// It is safe to call Dial after the tunnel has been stopped.
+// It is safe to call Dial concurrently with Stop.
 func (tunnel *PsiphonTunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
 	// Ensure the dial is accessed in a thread-safe manner, without holding the lock
 	// while calling the dial function.

+ 58 - 0
ClientLibrary/clientlib/clientlib_test.go

@@ -252,6 +252,64 @@ func TestStartTunnel(t *testing.T) {
 	}
 }
 
+func TestMultpleStartTunnel(t *testing.T) {
+	configJSON, err := os.ReadFile("../../psiphon/controller_test.config")
+	if err != nil {
+		// What to do if config file is not present?
+		t.Skipf("error loading configuration file: %s", err)
+	}
+
+	testDataDirName, err := os.MkdirTemp("", "psiphon-clientlib-test")
+	if err != nil {
+		t.Fatalf("ioutil.TempDir failed: %v", err)
+	}
+	defer os.RemoveAll(testDataDirName)
+
+	ctx := context.Background()
+
+	tunnel1, err := StartTunnel(
+		ctx,
+		configJSON,
+		"",
+		Parameters{DataRootDirectory: &testDataDirName},
+		nil,
+		nil)
+
+	if err != nil {
+		t.Fatalf("first StartTunnel() error = %v", err)
+	}
+
+	// We have not stopped the tunnel, so a second StartTunnel() should fail
+	_, err = StartTunnel(
+		ctx,
+		configJSON,
+		"",
+		Parameters{DataRootDirectory: &testDataDirName},
+		nil,
+		nil)
+
+	if err != errMultipleStart {
+		t.Fatalf("second StartTunnel() should have failed with errMultipleStart; got %v", err)
+	}
+
+	// Stop the tunnel and try again
+	tunnel1.Stop()
+	tunnel3, err := StartTunnel(
+		ctx,
+		configJSON,
+		"",
+		Parameters{DataRootDirectory: &testDataDirName},
+		nil,
+		nil)
+
+	if err != nil {
+		t.Fatalf("third StartTunnel() error = %v", err)
+	}
+
+	// Stop the tunnel so it doesn't interfere with other tests
+	tunnel3.Stop()
+}
+
 func TestPsiphonTunnel_Dial(t *testing.T) {
 	trueVal := true
 	configJSON, err := os.ReadFile("../../psiphon/controller_test.config")