Explorar o código

Add config params for custom behavior for temporary tunnels

* TargetServerEntry param specifies a specific tunnel to connect
to. When it is set, no other servers are used.

* When DisableApi is set, the Psiphon API is not used (no
handshake, no stats, etc.)

* When DisableRemoteServerListFetcher is set, the remote
server list fetch is not run.
Rod Hynes %!s(int64=11) %!d(string=hai) anos
pai
achega
2f775545c8
Modificáronse 4 ficheiros con 143 adicións e 59 borrados
  1. 3 0
      psiphon/config.go
  2. 51 36
      psiphon/controller.go
  3. 63 15
      psiphon/dataStore.go
  4. 26 8
      psiphon/tunnel.go

+ 3 - 0
psiphon/config.go

@@ -46,6 +46,9 @@ type Config struct {
 	UpstreamHttpProxyAddress           string
 	BindToDeviceProvider               DeviceBinder
 	BindToDeviceDnsServer              string
+	TargetServerEntry                  string
+	DisableApi                         bool
+	DisableRemoteServerListFetcher     bool
 }
 
 // LoadConfig parses and validates a JSON format Psiphon config JSON

+ 51 - 36
psiphon/controller.go

@@ -96,6 +96,8 @@ func NewController(config *Config) (controller *Controller, err error) {
 func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
 	Notice(NOTICE_VERSION, VERSION)
 
+	// Start components
+
 	socksProxy, err := NewSocksProxy(controller.config, controller)
 	if err != nil {
 		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
@@ -110,12 +112,22 @@ func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
 	}
 	defer httpProxy.Close()
 
-	/// Note: the connected reporter isn't started until a tunnel is established
+	// Note: unlike legacy Psiphon clients, this code always makes the
+	// fetch remote server list request
+
+	if !controller.config.DisableRemoteServerListFetcher {
+		controller.runWaitGroup.Add(1)
+		go controller.remoteServerListFetcher()
+	}
+
+	/// Note: the connected reporter isn't started until a tunnel is
+	// established
 
-	controller.runWaitGroup.Add(2)
-	go controller.remoteServerListFetcher()
+	controller.runWaitGroup.Add(1)
 	go controller.runTunnels()
 
+	// Wait while running
+
 	select {
 	case <-shutdownBroadcast:
 		Notice(NOTICE_INFO, "controller shutdown by request")
@@ -146,8 +158,6 @@ func (controller *Controller) SignalComponentFailure() {
 func (controller *Controller) remoteServerListFetcher() {
 	defer controller.runWaitGroup.Done()
 
-	// Note: unlike legacy Psiphon clients, this code
-	// always makes the fetch remote server list request
 loop:
 	for {
 		err := FetchRemoteServerList(
@@ -215,32 +225,42 @@ loop:
 	Notice(NOTICE_INFO, "exiting connected reporter")
 }
 
+func (controller *Controller) startConnectedReporter() {
+	if controller.config.DisableApi {
+		return
+	}
+
+	// Start the connected reporter after the first tunnel is established.
+	// Concurrency note: only the runTunnels goroutine may access startedConnectedReporter.
+	if !controller.startedConnectedReporter {
+		controller.startedConnectedReporter = true
+		controller.runWaitGroup.Add(1)
+		go controller.connectedReporter()
+	}
+}
+
 // runTunnels is the controller tunnel management main loop. It starts and stops
 // establishing tunnels based on the target tunnel pool size and the current size
 // of the pool. Tunnels are established asynchronously using worker goroutines.
+//
+// When there are no server entries for the target region/protocol, the
+// establishCandidateGenerator will yield no candidates and wait before
+// trying again. In the meantime, a remote server entry fetch may supply
+// valid candidates.
+//
 // When a tunnel is established, it's added to the active pool. The tunnel's
 // operateTunnel goroutine monitors the tunnel.
+//
 // When a tunnel fails, it's removed from the pool and the establish process is
 // restarted to fill the pool.
 func (controller *Controller) runTunnels() {
 	defer controller.runWaitGroup.Done()
 
-	// Don't start establishing until there are some server candidates. The
-	// typical case is a client with no server entries which will wait for
-	// the first successful FetchRemoteServerList to populate the data store.
-	for {
-		if HasServerEntries(
-			controller.config.EgressRegion, controller.config.TunnelProtocol) {
-			break
-		}
-		// TODO: replace polling with signal
-		timeout := time.After(5 * time.Second)
-		select {
-		case <-timeout:
-		case <-controller.shutdownBroadcast:
-			return
-		}
-	}
+	// Note: calling Count for its logging side-effect.
+	_ = CountServerEntries(controller.config.EgressRegion, controller.config.TunnelProtocol)
+
+	// Start running
+
 	controller.startEstablishing()
 loop:
 	for {
@@ -266,20 +286,15 @@ loop:
 			if controller.isFullyEstablished() {
 				controller.stopEstablishing()
 			}
-
-			// Start the connected reporter after the first tunnel is established.
-			// Concurrency note: only this goroutine may access startedConnectedReporter.
-			// isEstablishing.
-			if !controller.startedConnectedReporter {
-				controller.startedConnectedReporter = true
-				controller.runWaitGroup.Add(1)
-				go controller.connectedReporter()
-			}
+			controller.startConnectedReporter()
 
 		case <-controller.shutdownBroadcast:
 			break loop
 		}
 	}
+
+	// Stop running
+
 	controller.stopEstablishing()
 	controller.terminateAllTunnels()
 
@@ -428,10 +443,7 @@ func (controller *Controller) Dial(remoteAddr string) (conn net.Conn, err error)
 		return nil, ContextError(err)
 	}
 
-	statsConn := NewStatsConn(
-		tunneledConn, tunnel.session.StatsServerID(), tunnel.session.StatsRegexps())
-
-	return statsConn, nil
+	return tunneledConn, nil
 }
 
 // startEstablishing creates a pool of worker goroutines which will
@@ -485,8 +497,7 @@ func (controller *Controller) stopEstablishing() {
 func (controller *Controller) establishCandidateGenerator() {
 	defer controller.establishWaitGroup.Done()
 
-	iterator, err := NewServerEntryIterator(
-		controller.config.EgressRegion, controller.config.TunnelProtocol)
+	iterator, err := NewServerEntryIterator(controller.config)
 	if err != nil {
 		Notice(NOTICE_ALERT, "failed to iterate over candidates: %s", err)
 		controller.SignalComponentFailure()
@@ -495,7 +506,10 @@ func (controller *Controller) establishCandidateGenerator() {
 	defer iterator.Close()
 
 loop:
+	// Repeat until stopped
 	for {
+
+		// Yield each server entry returned by the iterator
 		for {
 			serverEntry, err := iterator.Next()
 			if err != nil {
@@ -530,6 +544,7 @@ loop:
 			break loop
 		}
 	}
+
 	close(controller.candidateServerEntries)
 	Notice(NOTICE_INFO, "stopped candidate generator")
 }

+ 63 - 15
psiphon/dataStore.go

@@ -264,19 +264,28 @@ func PromoteServerEntry(ipAddress string) error {
 // ServerEntryIterator is used to iterate over
 // stored server entries in rank order.
 type ServerEntryIterator struct {
-	region      string
-	protocol    string
-	excludeIds  []string
-	transaction *sql.Tx
-	cursor      *sql.Rows
+	region                      string
+	protocol                    string
+	transaction                 *sql.Tx
+	cursor                      *sql.Rows
+	isTargetServerEntryIterator bool
+	hasNextTargetServerEntry    bool
+	targetServerEntry           *ServerEntry
 }
 
 // NewServerEntryIterator creates a new NewServerEntryIterator
-func NewServerEntryIterator(region, protocol string) (iterator *ServerEntryIterator, err error) {
+func NewServerEntryIterator(config *Config) (iterator *ServerEntryIterator, err error) {
+
+	// When configured, this target server entry is the only candidate
+	if config.TargetServerEntry != "" {
+		return newTargetServerEntryIterator(config)
+	}
+
 	checkInitDataStore()
 	iterator = &ServerEntryIterator{
-		region:   region,
-		protocol: protocol,
+		region:                      config.EgressRegion,
+		protocol:                    config.TunnelProtocol,
+		isTargetServerEntryIterator: false,
 	}
 	err = iterator.Reset()
 	if err != nil {
@@ -285,10 +294,41 @@ func NewServerEntryIterator(region, protocol string) (iterator *ServerEntryItera
 	return iterator, nil
 }
 
+// newTargetServerEntryIterator is a helper for initializing the TargetServerEntry case
+func newTargetServerEntryIterator(config *Config) (iterator *ServerEntryIterator, err error) {
+	serverEntry, err := DecodeServerEntry(config.TargetServerEntry)
+	if err != nil {
+		return nil, err
+	}
+	if config.EgressRegion != "" && serverEntry.Region != config.EgressRegion {
+		return nil, errors.New("TargetServerEntry does not support EgressRegion")
+	}
+	if config.TunnelProtocol != "" {
+		// Note: same capability/protocol mapping as in StoreServerEntry
+		requiredCapability := strings.TrimSuffix(config.TunnelProtocol, "-OSSH")
+		if !Contains(serverEntry.Capabilities, requiredCapability) {
+			return nil, errors.New("TargetServerEntry does not support TunnelProtocol")
+		}
+	}
+	iterator = &ServerEntryIterator{
+		isTargetServerEntryIterator: true,
+		hasNextTargetServerEntry:    true,
+		targetServerEntry:           serverEntry,
+	}
+	Notice(NOTICE_INFO, "using TargetServerEntry: %s", serverEntry.IpAddress)
+	return iterator, nil
+}
+
 // Reset a NewServerEntryIterator to the start of its cycle. The next
 // call to Next will return the first server entry.
 func (iterator *ServerEntryIterator) Reset() error {
 	iterator.Close()
+
+	if iterator.isTargetServerEntryIterator {
+		iterator.hasNextTargetServerEntry = true
+		return nil
+	}
+
 	transaction, err := singleton.db.Begin()
 	if err != nil {
 		return ContextError(err)
@@ -347,6 +387,15 @@ func (iterator *ServerEntryIterator) Next() (serverEntry *ServerEntry, err error
 			iterator.Close()
 		}
 	}()
+
+	if iterator.isTargetServerEntryIterator {
+		if iterator.hasNextTargetServerEntry {
+			iterator.hasNextTargetServerEntry = false
+			return iterator.targetServerEntry, nil
+		}
+		return nil, nil
+	}
+
 	if !iterator.cursor.Next() {
 		err = iterator.cursor.Err()
 		if err != nil {
@@ -406,10 +455,9 @@ func makeServerEntryWhereClause(
 	return whereClause, whereParams
 }
 
-// HasServerEntries returns true if the data store contains at
-// least one server entry (for the specified region and/or protocol,
-// when not blank).
-func HasServerEntries(region, protocol string) bool {
+// CountServerEntries returns a count of stored servers for the
+// specified region and protocol.
+func CountServerEntries(region, protocol string) int {
 	checkInitDataStore()
 	var count int
 	whereClause, whereParams := makeServerEntryWhereClause(region, protocol, nil)
@@ -417,8 +465,8 @@ func HasServerEntries(region, protocol string) bool {
 	err := singleton.db.QueryRow(query, whereParams...).Scan(&count)
 
 	if err != nil {
-		Notice(NOTICE_ALERT, "HasServerEntries failed: %s", err)
-		return false
+		Notice(NOTICE_ALERT, "CountServerEntries failed: %s", err)
+		return 0
 	}
 
 	if region == "" {
@@ -430,7 +478,7 @@ func HasServerEntries(region, protocol string) bool {
 	Notice(NOTICE_INFO, "servers for region %s and protocol %s: %d",
 		region, protocol, count)
 
-	return count > 0
+	return count
 }
 
 // GetServerEntryIpAddresses returns an array containing

+ 26 - 8
psiphon/tunnel.go

@@ -142,10 +142,12 @@ func EstablishTunnel(
 	// TODO: as long as the servers are not enforcing that a client perform a handshake,
 	// proceed with this tunnel as long as at least one previous handhake succeeded?
 	//
-	Notice(NOTICE_INFO, "starting session for %s", tunnel.serverEntry.IpAddress)
-	tunnel.session, err = NewSession(config, tunnel, sessionId)
-	if err != nil {
-		return nil, ContextError(fmt.Errorf("error starting session for %s: %s", tunnel.serverEntry.IpAddress, err))
+	if !config.DisableApi {
+		Notice(NOTICE_INFO, "starting session for %s", tunnel.serverEntry.IpAddress)
+		tunnel.session, err = NewSession(config, tunnel, sessionId)
+		if err != nil {
+			return nil, ContextError(fmt.Errorf("error starting session for %s: %s", tunnel.serverEntry.IpAddress, err))
+		}
 	}
 
 	// Now that network operations are complete, cancel interruptibility
@@ -194,10 +196,17 @@ func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
 		return nil, ContextError(err)
 	}
 
-	return &TunneledConn{
-			Conn:   sshPortForwardConn,
-			tunnel: tunnel},
-		nil
+	conn = &TunneledConn{
+		Conn:   sshPortForwardConn,
+		tunnel: tunnel}
+
+	// Tunnel does not have a session when DisableApi is set
+	if tunnel.session != nil {
+		conn = NewStatsConn(
+			conn, tunnel.session.StatsServerID(), tunnel.session.StatsRegexps())
+	}
+
+	return conn, nil
 }
 
 // TunneledConn implements net.Conn and wraps a port foward connection.
@@ -409,6 +418,9 @@ func dialSsh(
 // some typical error messages to consider matching against (or ignoring):
 //
 // - "ssh: rejected: administratively prohibited (open failed)"
+//   (this error message is reported in both actual and false cases: when a server
+//    is overloaded and has no free ephemeral ports; and when the user mistypes
+//    a domain in a browser address bar and name resolution fails)
 // - "ssh: rejected: connect failed (Connection timed out)"
 // - "write tcp ... broken pipe"
 // - "read tcp ... connection reset by peer"
@@ -464,6 +476,12 @@ func (tunnel *Tunnel) operateTunnel(config *Config, tunnelOwner TunnelOwner) {
 
 // sendStats is a helper for sending session stats to the server.
 func sendStats(tunnel *Tunnel) {
+
+	// Tunnel does not have a session when DisableApi is set
+	if tunnel.session == nil {
+		return
+	}
+
 	payload := GetForServer(tunnel.serverEntry.IpAddress)
 	if payload != nil {
 		err := tunnel.session.DoStatusRequest(payload)