Переглянути джерело

Merge pull request #686 from adam-p/clientlib-dial

Add Dial method
Rod Hynes 1 рік тому
батько
коміт
1150f1258e

+ 105 - 41
ClientLibrary/clientlib/clientlib.go

@@ -24,8 +24,11 @@ import (
 	"encoding/json"
 	std_errors "errors"
 	"fmt"
+	"io"
+	"net"
 	"path/filepath"
 	"sync"
+	"sync/atomic"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
@@ -65,14 +68,22 @@ type Parameters struct {
 	// notices to noticeReceiver. Has no effect unless the tunnel
 	// config.EmitDiagnosticNotices flag is set.
 	EmitDiagnosticNoticesToFiles bool
+
+	// DisableLocalSocksProxy disables running the local SOCKS proxy.
+	DisableLocalSocksProxy *bool
+
+	// DisableLocalHTTPProxy disables running the local HTTP proxy.
+	DisableLocalHTTPProxy *bool
 }
 
 // PsiphonTunnel is the tunnel object. It can be used for stopping the tunnel and
 // retrieving proxy ports.
 type PsiphonTunnel struct {
+	mu                          sync.Mutex
+	stop                        func()
 	embeddedServerListWaitGroup sync.WaitGroup
 	controllerWaitGroup         sync.WaitGroup
-	stopController              context.CancelFunc
+	controllerDial              func(string, net.Conn) (net.Conn, error)
 
 	// The port on which the HTTP proxy is running
 	HTTPProxyPort int
@@ -95,6 +106,10 @@ type NoticeEvent struct {
 
 // ErrTimeout is returned when the tunnel establishment attempt fails due to timeout
 var ErrTimeout = std_errors.New("clientlib: tunnel establishment timeout")
+var errMultipleStart = std_errors.New("clientlib: StartTunnel called multiple times")
+
+// started is used to ensure that only one tunnel is started at a time
+var started atomic.Bool
 
 // StartTunnel establishes a Psiphon tunnel. It returns an error if the establishment
 // was not successful. If the returned error is nil, the returned tunnel can be used
@@ -122,6 +137,10 @@ func StartTunnel(
 	paramsDelta ParametersDelta,
 	noticeReceiver func(NoticeEvent)) (retTunnel *PsiphonTunnel, retErr error) {
 
+	if !started.CompareAndSwap(false, true) {
+		return nil, errMultipleStart
+	}
+
 	config, err := psiphon.LoadConfig(configJSON)
 	if err != nil {
 		return nil, errors.TraceMsg(err, "failed to load config file")
@@ -156,6 +175,14 @@ func StartTunnel(
 		}
 	} // else use the value in the config
 
+	if params.DisableLocalSocksProxy != nil {
+		config.DisableLocalSocksProxy = *params.DisableLocalSocksProxy
+	} // else use the value in the config
+
+	if params.DisableLocalHTTPProxy != nil {
+		config.DisableLocalHTTPProxy = *params.DisableLocalHTTPProxy
+	} // else use the value in the config
+
 	// config.Commit must be called before calling config.SetParameters
 	// or attempting to connect.
 	err = config.Commit(true)
@@ -167,15 +194,14 @@ func StartTunnel(
 	if len(paramsDelta) > 0 {
 		err = config.SetParameters("", false, paramsDelta)
 		if err != nil {
-			return nil, errors.TraceMsg(
-				err, fmt.Sprintf("SetParameters failed for delta: %v", paramsDelta))
+			return nil, errors.TraceMsg(err, fmt.Sprintf("SetParameters failed for delta: %v", paramsDelta))
 		}
 	}
 
-	// Will receive a value when the tunnel has successfully connected.
-	connected := make(chan struct{}, 1)
-	// Will receive a value if an error occurs during the connection sequence.
-	errored := make(chan error, 1)
+	// Will be closed when the tunnel has successfully connected
+	connectedSignal := make(chan struct{})
+	// Will receive a value if an error occurs during the connection sequence
+	erroredCh := make(chan error, 1)
 
 	// Create the tunnel object
 	tunnel := new(PsiphonTunnel)
@@ -190,7 +216,7 @@ func StartTunnel(
 				// We'll interpret it as a connection error and abort.
 				err = errors.TraceMsg(err, "failed to unmarshal notice JSON")
 				select {
-				case errored <- err:
+				case erroredCh <- err:
 				default:
 				}
 				return
@@ -204,16 +230,13 @@ func StartTunnel(
 				tunnel.SOCKSProxyPort = int(port)
 			} else if event.Type == "EstablishTunnelTimeout" {
 				select {
-				case errored <- ErrTimeout:
+				case erroredCh <- ErrTimeout:
 				default:
 				}
 			} else if event.Type == "Tunnels" {
 				count := event.Data["count"].(float64)
 				if count > 0 {
-					select {
-					case connected <- struct{}{}:
-					default:
-					}
+					close(connectedSignal)
 				}
 			}
 
@@ -228,19 +251,30 @@ func StartTunnel(
 	if err != nil {
 		return nil, errors.TraceMsg(err, "failed to open data store")
 	}
-	// Make sure we close the datastore in case of error
+
+	// Create a cancelable context that will be used for stopping the tunnel
+	tunnelCtx, cancelTunnelCtx := context.WithCancel(ctx)
+
+	// Because the tunnel object is only returned on success, there are at least two
+	// problems that we don't need to worry about:
+	// 1. This stop function is called both by the error-defer here and by a call to the
+	//    tunnel's Stop method.
+	// 2. This stop function is called via the tunnel's Stop method before the WaitGroups
+	//    are incremented (causing a race condition).
+	tunnel.stop = func() {
+		cancelTunnelCtx()
+		tunnel.embeddedServerListWaitGroup.Wait()
+		tunnel.controllerWaitGroup.Wait()
+		psiphon.CloseDataStore()
+		started.Store(false)
+	}
+
 	defer func() {
 		if retErr != nil {
-			tunnel.controllerWaitGroup.Wait()
-			tunnel.embeddedServerListWaitGroup.Wait()
-			psiphon.CloseDataStore()
+			tunnel.stop()
 		}
 	}()
 
-	// Create a cancelable context that will be used for stopping the tunnel
-	var controllerCtx context.Context
-	controllerCtx, tunnel.stopController = context.WithCancel(ctx)
-
 	// If specified, the embedded server list is loaded and stored. When there
 	// are no server candidates at all, we wait for this import to complete
 	// before starting the Psiphon controller. Otherwise, we import while
@@ -258,7 +292,7 @@ func StartTunnel(
 		defer tunnel.embeddedServerListWaitGroup.Done()
 
 		err := psiphon.ImportEmbeddedServerEntries(
-			controllerCtx,
+			tunnelCtx,
 			config,
 			"",
 			embeddedServerEntryList)
@@ -275,45 +309,44 @@ func StartTunnel(
 	// Create the Psiphon controller
 	controller, err := psiphon.NewController(config)
 	if err != nil {
-		tunnel.stopController()
-		tunnel.embeddedServerListWaitGroup.Wait()
 		return nil, errors.TraceMsg(err, "psiphon.NewController failed")
 	}
 
+	tunnel.controllerDial = controller.Dial
+
 	// Begin tunnel connection
 	tunnel.controllerWaitGroup.Add(1)
 	go func() {
 		defer tunnel.controllerWaitGroup.Done()
 
 		// Start the tunnel. Only returns on error (or internal timeout).
-		controller.Run(controllerCtx)
+		controller.Run(tunnelCtx)
 
 		// controller.Run does not exit until the goroutine that posts
 		// EstablishTunnelTimeout has terminated; so, if there was a
 		// EstablishTunnelTimeout event, ErrTimeout is guaranteed to be sent to
-		// errord before this next error and will be the StartTunnel return value.
+		// errored before this next error and will be the StartTunnel return value.
 
-		var err error
-		switch ctx.Err() {
+		err := ctx.Err()
+		switch err {
 		case context.DeadlineExceeded:
 			err = ErrTimeout
 		case context.Canceled:
-			err = errors.TraceNew("StartTunnel canceled")
+			err = errors.TraceMsg(err, "StartTunnel canceled")
 		default:
-			err = errors.TraceNew("controller.Run exited unexpectedly")
+			err = errors.TraceMsg(err, "controller.Run exited unexpectedly")
 		}
 		select {
-		case errored <- err:
+		case erroredCh <- err:
 		default:
 		}
 	}()
 
 	// Wait for an active tunnel or error
 	select {
-	case <-connected:
+	case <-connectedSignal:
 		return tunnel, nil
-	case err := <-errored:
-		tunnel.Stop()
+	case err := <-erroredCh:
 		if err != ErrTimeout {
 			err = errors.TraceMsg(err, "tunnel start produced error")
 		}
@@ -321,14 +354,45 @@ func StartTunnel(
 	}
 }
 
-// Stop stops/disconnects/shuts down the tunnel. It is safe to call when not connected.
-// Not safe to call concurrently with Start.
+// Stop stops/disconnects/shuts down the tunnel.
+// It is safe to call Stop multiple times.
+// It is safe to call concurrently with Dial and with itself.
 func (tunnel *PsiphonTunnel) Stop() {
-	if tunnel.stopController == nil {
+	// Holding a lock while calling the stop function ensures that any concurrent call
+	// to Stop will wait for the first call to finish before returning, rather than
+	// returning immediately (because tunnel.stop is nil) and thereby indicating
+	// (erroneously) that the tunnel has been stopped.
+	// Stopping a tunnel happens quickly enough that this processing block shouldn't be
+	// a problem.
+	tunnel.mu.Lock()
+	defer tunnel.mu.Unlock()
+
+	if tunnel.stop == nil {
 		return
 	}
-	tunnel.stopController()
-	tunnel.controllerWaitGroup.Wait()
-	tunnel.embeddedServerListWaitGroup.Wait()
-	psiphon.CloseDataStore()
+
+	tunnel.stop()
+	tunnel.stop = nil
+	tunnel.controllerDial = nil
+
+	// Clear our notice receiver, as it is no longer needed and we should let it be
+	// garbage-collected.
+	psiphon.SetNoticeWriter(io.Discard)
+}
+
+// Dial connects to the specified address through the Psiphon tunnel.
+// It is safe to call Dial after the tunnel has been stopped.
+// It is safe to call Dial concurrently with Stop.
+func (tunnel *PsiphonTunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
+	// Ensure the dial is accessed in a thread-safe manner, without holding the lock
+	// while calling the dial function.
+	// Note that it is safe for controller.Dial to be called even after or during a tunnel
+	// shutdown (i.e., if the context has been canceled).
+	tunnel.mu.Lock()
+	dial := tunnel.controllerDial
+	tunnel.mu.Unlock()
+	if dial == nil {
+		return nil, errors.TraceNew("tunnel not started")
+	}
+	return dial(remoteAddr, nil)
 }

+ 208 - 3
ClientLibrary/clientlib/clientlib_test.go

@@ -22,7 +22,6 @@ package clientlib
 import (
 	"context"
 	"encoding/json"
-	"io/ioutil"
 	"os"
 	"testing"
 	"time"
@@ -37,8 +36,9 @@ func TestStartTunnel(t *testing.T) {
 	networkID := "UNKNOWN"
 	timeout := 60
 	quickTimeout := 1
+	trueVal := true
 
-	configJSON, err := ioutil.ReadFile("../../psiphon/controller_test.config")
+	configJSON, err := os.ReadFile("../../psiphon/controller_test.config")
 	if err != nil {
 		// Skip, don't fail, if config file is not present
 		t.Skipf("error loading configuration file: %s", err)
@@ -47,7 +47,7 @@ func TestStartTunnel(t *testing.T) {
 	// Initialize a fresh datastore and create a modified config which cannot
 	// connect without known servers, to be used in timeout cases.
 
-	testDataDirName, err := ioutil.TempDir("", "psiphon-clientlib-test")
+	testDataDirName, err := os.MkdirTemp("", "psiphon-clientlib-test")
 	if err != nil {
 		t.Fatalf("ioutil.TempDir failed: %v", err)
 	}
@@ -143,6 +143,64 @@ func TestStartTunnel(t *testing.T) {
 			wantTunnel:  true,
 			expectedErr: nil,
 		},
+		{
+			name: "Success: disable SOCKS proxy",
+			args: args{
+				ctxTimeout:              0,
+				configJSON:              configJSON,
+				embeddedServerEntryList: "",
+				params: Parameters{
+					DataRootDirectory:             &testDataDirName,
+					ClientPlatform:                &clientPlatform,
+					NetworkID:                     &networkID,
+					EstablishTunnelTimeoutSeconds: &timeout,
+					DisableLocalSocksProxy:        &trueVal,
+				},
+				paramsDelta:    nil,
+				noticeReceiver: nil,
+			},
+			wantTunnel:  true,
+			expectedErr: nil,
+		},
+		{
+			name: "Success: disable HTTP proxy",
+			args: args{
+				ctxTimeout:              0,
+				configJSON:              configJSON,
+				embeddedServerEntryList: "",
+				params: Parameters{
+					DataRootDirectory:             &testDataDirName,
+					ClientPlatform:                &clientPlatform,
+					NetworkID:                     &networkID,
+					EstablishTunnelTimeoutSeconds: &timeout,
+					DisableLocalHTTPProxy:         &trueVal,
+				},
+				paramsDelta:    nil,
+				noticeReceiver: nil,
+			},
+			wantTunnel:  true,
+			expectedErr: nil,
+		},
+		{
+			name: "Success: disable SOCKS and HTTP proxies",
+			args: args{
+				ctxTimeout:              0,
+				configJSON:              configJSON,
+				embeddedServerEntryList: "",
+				params: Parameters{
+					DataRootDirectory:             &testDataDirName,
+					ClientPlatform:                &clientPlatform,
+					NetworkID:                     &networkID,
+					EstablishTunnelTimeoutSeconds: &timeout,
+					DisableLocalSocksProxy:        &trueVal,
+					DisableLocalHTTPProxy:         &trueVal,
+				},
+				paramsDelta:    nil,
+				noticeReceiver: nil,
+			},
+			wantTunnel:  true,
+			expectedErr: nil,
+		},
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
@@ -180,6 +238,153 @@ func TestStartTunnel(t *testing.T) {
 				return
 			}
 
+			if tunnel == nil {
+				return
+			}
+
+			if tt.args.params.DisableLocalSocksProxy != nil && *tt.args.params.DisableLocalSocksProxy {
+				if tunnel.SOCKSProxyPort != 0 {
+					t.Fatalf("should not have started SOCKS proxy")
+				}
+			} else {
+				if tunnel.SOCKSProxyPort == 0 {
+					t.Fatalf("failed to start SOCKS proxy")
+				}
+			}
+
+			if tt.args.params.DisableLocalHTTPProxy != nil && *tt.args.params.DisableLocalHTTPProxy {
+				if tunnel.HTTPProxyPort != 0 {
+					t.Fatalf("should not have started HTTP proxy")
+				}
+			} else {
+				if tunnel.HTTPProxyPort == 0 {
+					t.Fatalf("failed to start HTTP proxy")
+				}
+			}
+		})
+	}
+}
+
+func TestMultipleStartTunnel(t *testing.T) {
+	configJSON, err := os.ReadFile("../../psiphon/controller_test.config")
+	if err != nil {
+		// What to do if config file is not present?
+		t.Skipf("error loading configuration file: %s", err)
+	}
+
+	testDataDirName, err := os.MkdirTemp("", "psiphon-clientlib-test")
+	if err != nil {
+		t.Fatalf("ioutil.TempDir failed: %v", err)
+	}
+	defer os.RemoveAll(testDataDirName)
+
+	ctx := context.Background()
+
+	tunnel1, err := StartTunnel(
+		ctx,
+		configJSON,
+		"",
+		Parameters{DataRootDirectory: &testDataDirName},
+		nil,
+		nil)
+
+	if err != nil {
+		t.Fatalf("first StartTunnel() error = %v", err)
+	}
+
+	// We have not stopped the tunnel, so a second StartTunnel() should fail
+	_, err = StartTunnel(
+		ctx,
+		configJSON,
+		"",
+		Parameters{DataRootDirectory: &testDataDirName},
+		nil,
+		nil)
+
+	if err != errMultipleStart {
+		t.Fatalf("second StartTunnel() should have failed with errMultipleStart; got %v", err)
+	}
+
+	// Stop the tunnel and try again
+	tunnel1.Stop()
+	tunnel3, err := StartTunnel(
+		ctx,
+		configJSON,
+		"",
+		Parameters{DataRootDirectory: &testDataDirName},
+		nil,
+		nil)
+
+	if err != nil {
+		t.Fatalf("third StartTunnel() error = %v", err)
+	}
+
+	// Stop the tunnel so it doesn't interfere with other tests
+	tunnel3.Stop()
+}
+
+func TestPsiphonTunnel_Dial(t *testing.T) {
+	trueVal := true
+	configJSON, err := os.ReadFile("../../psiphon/controller_test.config")
+	if err != nil {
+		// Skip, don't fail, if config file is not present
+		t.Skipf("error loading configuration file: %s", err)
+	}
+
+	testDataDirName, err := os.MkdirTemp("", "psiphon-clientlib-test")
+	if err != nil {
+		t.Fatalf("ioutil.TempDir failed: %v", err)
+	}
+	defer os.RemoveAll(testDataDirName)
+
+	type args struct {
+		remoteAddr string
+	}
+	tests := []struct {
+		name    string
+		args    args
+		wantErr bool
+	}{
+		{
+			name:    "Success: example.com",
+			args:    args{remoteAddr: "example.com:443"},
+			wantErr: false,
+		},
+		{
+			name:    "Failure: invalid address",
+			args:    args{remoteAddr: "example.com:99999"},
+			wantErr: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			tunnel, err := StartTunnel(
+				context.Background(),
+				configJSON,
+				"",
+				Parameters{
+					DataRootDirectory: &testDataDirName,
+					// Don't need local proxies for dial tests
+					// (and this is likely the configuration that will be used by consumers of the library who utilitize Dial).
+					DisableLocalSocksProxy: &trueVal,
+					DisableLocalHTTPProxy:  &trueVal,
+				},
+				nil,
+				nil)
+			if err != nil {
+				t.Fatalf("StartTunnel() error = %v", err)
+			}
+			defer tunnel.Stop()
+
+			conn, err := tunnel.Dial(tt.args.remoteAddr)
+			if (err != nil) != tt.wantErr {
+				t.Fatalf("PsiphonTunnel.Dial() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+
+			if tt.wantErr != (conn == nil) {
+				t.Fatalf("PsiphonTunnel.Dial() conn = %v, wantConn %v", conn, !tt.wantErr)
+			}
 		})
 	}
 }