Просмотр исходного кода

Add new port forwarding split tunnel scheme

The server now assists the client in classifying split tunnel destinations.

The new scheme improves on the previous scheme by: removing the dependency on
sometimes-unreliable TCP DNS; eliminating an extra round trip for tunneled
port forwards in split tunnel mode; eliminating the need for clients to
download and update split tunnel routing information.

The new scheme logs a split_tunnel flag in server_tunnel for clients that
enable split tunnel mode.

Destination classifications are made using the same underlying GeoIP data
source used by the previous scheme; client traffic will be tunneled or
untunneled just as it was under the previous scheme.
Rod Hynes 5 лет назад
Родитель
Сommit
237bfa7147

+ 18 - 5
psiphon/common/crypto/ssh/tcpip.go

@@ -332,12 +332,18 @@ func (l *tcpListener) Addr() net.Addr {
 	return l.laddr
 }
 
+// [Psiphon]
+// directTCPIPNoSplitTunnel is the same as "direct-tcpip", except it indicates
+// custom split tunnel behavior. It shares the same payload. We allow the
+// Client.Dial network type to optionally specify a channel type instead.
+const directTCPIPNoSplitTunnel = "direct-tcpip-no-split-tunnel@psiphon.ca"
+
 // Dial initiates a connection to the addr from the remote host.
 // The resulting connection has a zero LocalAddr() and RemoteAddr().
 func (c *Client) Dial(n, addr string) (net.Conn, error) {
 	var ch Channel
 	switch n {
-	case "tcp", "tcp4", "tcp6":
+	case "tcp", "tcp4", "tcp6", "direct-tcpip", directTCPIPNoSplitTunnel:
 		// Parse the address into host and numeric port.
 		host, portString, err := net.SplitHostPort(addr)
 		if err != nil {
@@ -347,7 +353,14 @@ func (c *Client) Dial(n, addr string) (net.Conn, error) {
 		if err != nil {
 			return nil, err
 		}
-		ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
+
+		// [Psiphon]
+		channelType := "direct-tcpip"
+		if n == directTCPIPNoSplitTunnel {
+			channelType = directTCPIPNoSplitTunnel
+		}
+
+		ch, err = c.dial(channelType, net.IPv4zero.String(), 0, host, int(port))
 		if err != nil {
 			return nil, err
 		}
@@ -393,7 +406,7 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error)
 			Port: 0,
 		}
 	}
-	ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
+	ch, err := c.dial("direct-tcpip", laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
 	if err != nil {
 		return nil, err
 	}
@@ -412,14 +425,14 @@ type channelOpenDirectMsg struct {
 	lport uint32
 }
 
-func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
+func (c *Client) dial(channelType string, laddr string, lport int, raddr string, rport int) (Channel, error) {
 	msg := channelOpenDirectMsg{
 		raddr: raddr,
 		rport: uint32(rport),
 		laddr: laddr,
 		lport: uint32(lport),
 	}
-	ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
+	ch, in, err := c.OpenChannel(channelType, Marshal(&msg))
 	if err != nil {
 		return nil, err
 	}

+ 9 - 0
psiphon/common/parameters/parameters.go

@@ -168,6 +168,8 @@ const (
 	SplitTunnelRoutesURLFormat                       = "SplitTunnelRoutesURLFormat"
 	SplitTunnelRoutesSignaturePublicKey              = "SplitTunnelRoutesSignaturePublicKey"
 	SplitTunnelDNSServer                             = "SplitTunnelDNSServer"
+	SplitTunnelClassificationTTL                     = "SplitTunnelClassificationTTL"
+	SplitTunnelClassificationMaxEntries              = "SplitTunnelClassificationMaxEntries"
 	FetchUpgradeTimeout                              = "FetchUpgradeTimeout"
 	FetchUpgradeRetryPeriod                          = "FetchUpgradeRetryPeriod"
 	FetchUpgradeStalePeriod                          = "FetchUpgradeStalePeriod"
@@ -433,11 +435,18 @@ var defaultParameters = map[string]struct {
 
 	PsiphonAPIConnectedRequestRetryPeriod: {value: 5 * time.Second, minimum: 1 * time.Millisecond},
 
+	// FetchSplitTunnelRoutesTimeout, SplitTunnelRoutesURLFormat,
+	// SplitTunnelRoutesSignaturePublicKey and SplitTunnelDNSServer are obsoleted
+	// by the server-assisted split tunnel implementation.
+	// TODO: remove once no longer required for older clients.
 	FetchSplitTunnelRoutesTimeout:       {value: 60 * time.Second, minimum: 1 * time.Second, flags: useNetworkLatencyMultiplier},
 	SplitTunnelRoutesURLFormat:          {value: ""},
 	SplitTunnelRoutesSignaturePublicKey: {value: ""},
 	SplitTunnelDNSServer:                {value: ""},
 
+	SplitTunnelClassificationTTL:        {value: 24 * time.Hour, minimum: 0 * time.Second},
+	SplitTunnelClassificationMaxEntries: {value: 65536, minimum: 0},
+
 	FetchUpgradeTimeout:                {value: 60 * time.Second, minimum: 1 * time.Second, flags: useNetworkLatencyMultiplier},
 	FetchUpgradeRetryPeriod:            {value: 30 * time.Second, minimum: 1 * time.Millisecond},
 	FetchUpgradeStalePeriod:            {value: 6 * time.Hour, minimum: 1 * time.Hour},

+ 9 - 2
psiphon/common/protocol/protocol.go

@@ -74,8 +74,15 @@ const (
 	PSIPHON_SSH_API_PROTOCOL = "ssh"
 	PSIPHON_WEB_API_PROTOCOL = "web"
 
-	PACKET_TUNNEL_CHANNEL_TYPE = "tun@psiphon.ca"
-	RANDOM_STREAM_CHANNEL_TYPE = "random@psiphon.ca"
+	PACKET_TUNNEL_CHANNEL_TYPE            = "tun@psiphon.ca"
+	RANDOM_STREAM_CHANNEL_TYPE            = "random@psiphon.ca"
+	TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE = "direct-tcpip-no-split-tunnel@psiphon.ca"
+
+	// Reject reason codes are returned in SSH open channel responses.
+	//
+	// Values 0xFE000000 to 0xFFFFFFFF are reserved for "PRIVATE USE" (see
+	// https://tools.ietf.org/rfc/rfc4254.html#section-5.1).
+	CHANNEL_REJECT_REASON_SPLIT_TUNNEL = 0xFE000000
 
 	PSIPHON_API_HANDSHAKE_AUTHORIZATIONS = "authorizations"
 

+ 5 - 34
psiphon/config.go

@@ -129,6 +129,11 @@ type Config struct {
 	// in any country is selected.
 	EgressRegion string
 
+	// EnableSplitTunnel toggles split tunnel mode. When enabled, TCP port
+	// forward destinations that resolve to the same GeoIP country as the client
+	// are connected to directly, untunneled.
+	EnableSplitTunnel bool
+
 	// ListenInterface specifies which interface to listen on.  If no
 	// interface is provided then listen on 127.0.0.1. If 'any' is provided
 	// then use 0.0.0.0. If there are multiple IP addresses on an interface
@@ -349,27 +354,6 @@ type Config struct {
 	// OnlyAfterAttempts = 0.
 	ObfuscatedServerListRootURLs parameters.TransferURLs
 
-	// SplitTunnelRoutesURLFormat is a URL which specifies the location of a
-	// routes file to use for split tunnel mode. The URL must include a
-	// placeholder for the client region to be supplied. Split tunnel mode
-	// uses the routes file to classify port forward destinations as foreign
-	// or domestic and does not tunnel domestic destinations. Split tunnel
-	// mode is on when all the SplitTunnel parameters are supplied. This value
-	// is supplied by and depends on the Psiphon Network, and is typically
-	// embedded in the client binary.
-	SplitTunnelRoutesURLFormat string
-
-	// SplitTunnelRoutesSignaturePublicKey specifies a public key that's used
-	// to authenticate the split tunnel routes payload. This value is supplied
-	// by and depends on the Psiphon Network, and is typically embedded in the
-	// client binary.
-	SplitTunnelRoutesSignaturePublicKey string
-
-	// SplitTunnelDNSServer specifies a DNS server to use when resolving port
-	// forward target domain names to IP addresses for classification. The DNS
-	// server must support TCP requests.
-	SplitTunnelDNSServer string
-
 	// UpgradeDownloadURLs is list of URLs which specify locations from which
 	// to download a host client upgrade file, when one is available. The core
 	// tunnel controller provides a resumable download facility which
@@ -1036,15 +1020,6 @@ func (config *Config) Commit(migrateFromLegacyFields bool) error {
 		}
 	}
 
-	if config.SplitTunnelRoutesURLFormat != "" {
-		if config.SplitTunnelRoutesSignaturePublicKey == "" {
-			return errors.TraceNew("missing SplitTunnelRoutesSignaturePublicKey")
-		}
-		if config.SplitTunnelDNSServer == "" {
-			return errors.TraceNew("missing SplitTunnelDNSServer")
-		}
-	}
-
 	if config.UpgradeDownloadURLs != nil {
 		if config.UpgradeDownloadClientVersionHeader == "" {
 			return errors.TraceNew("missing UpgradeDownloadClientVersionHeader")
@@ -1456,10 +1431,6 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 
 	}
 
-	applyParameters[parameters.SplitTunnelRoutesURLFormat] = config.SplitTunnelRoutesURLFormat
-	applyParameters[parameters.SplitTunnelRoutesSignaturePublicKey] = config.SplitTunnelRoutesSignaturePublicKey
-	applyParameters[parameters.SplitTunnelDNSServer] = config.SplitTunnelDNSServer
-
 	if config.UpgradeDownloadURLs != nil {
 		applyParameters[parameters.UpgradeDownloadClientVersionHeader] = config.UpgradeDownloadClientVersionHeader
 		applyParameters[parameters.UpgradeDownloadURLs] = config.UpgradeDownloadURLs

+ 96 - 33
psiphon/controller.go

@@ -38,6 +38,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
+	lrucache "github.com/cognusion/go-cache-lru"
 )
 
 // Controller is a tunnel lifecycle coordinator. It manages lists of servers to
@@ -70,7 +71,9 @@ type Controller struct {
 	establishedTunnelsCount                 int32
 	candidateServerEntries                  chan *candidateServerEntry
 	untunneledDialConfig                    *DialConfig
-	splitTunnelClassifier                   *SplitTunnelClassifier
+	untunneledSplitTunnelClassifications    *lrucache.Cache
+	splitTunnelClassificationTTL            time.Duration
+	splitTunnelClassificationMaxEntries     int
 	signalFetchCommonRemoteServerList       chan struct{}
 	signalFetchObfuscatedServerLists        chan struct{}
 	signalDownloadUpgrade                   chan string
@@ -106,6 +109,18 @@ func NewController(config *Config) (controller *Controller, err error) {
 		TrustedCACertificatesFilename: config.TrustedCACertificatesFilename,
 	}
 
+	// Attempt to apply any valid, local stored tactics. The pre-done context
+	// ensures no tactics request is attempted now.
+	doneContext, cancelFunc := context.WithCancel(context.Background())
+	cancelFunc()
+	GetTactics(doneContext, config)
+
+	p := config.GetParameters().Get()
+	splitTunnelClassificationTTL :=
+		p.Duration(parameters.SplitTunnelClassificationTTL)
+	splitTunnelClassificationMaxEntries :=
+		p.Int(parameters.SplitTunnelClassificationMaxEntries)
+
 	controller = &Controller{
 		config:       config,
 		runWaitGroup: new(sync.WaitGroup),
@@ -119,6 +134,11 @@ func NewController(config *Config) (controller *Controller, err error) {
 		isEstablishing:       false,
 		untunneledDialConfig: untunneledDialConfig,
 
+		untunneledSplitTunnelClassifications: lrucache.NewWithLRU(
+			splitTunnelClassificationTTL,
+			1*time.Minute,
+			splitTunnelClassificationMaxEntries),
+
 		// TODO: Add a buffer of 1 so we don't miss a signal while receiver is
 		// starting? Trade-off is potential back-to-back fetch remotes. As-is,
 		// establish will eventually signal another fetch remote.
@@ -138,8 +158,6 @@ func NewController(config *Config) (controller *Controller, err error) {
 		signalRestartEstablishing: make(chan struct{}, 1),
 	}
 
-	controller.splitTunnelClassifier = NewSplitTunnelClassifier(config, controller)
-
 	if config.PacketTunnelTunFileDescriptor > 0 {
 
 		// Run a packet tunnel client. The lifetime of the tun.Client is the
@@ -277,8 +295,6 @@ func (controller *Controller) Run(ctx context.Context) {
 
 	controller.runWaitGroup.Wait()
 
-	controller.splitTunnelClassifier.Shutdown()
-
 	NoticeInfo("exiting controller")
 
 	NoticeExiting()
@@ -945,19 +961,6 @@ loop:
 
 			if isFirstTunnel {
 
-				// The split tunnel classifier is started once the first tunnel is
-				// established. This first tunnel is passed in to be used to make
-				// the routes data request.
-				// A long-running controller may run while the host device is present
-				// in different regions. In this case, we want the split tunnel logic
-				// to switch to routes for new regions and not classify traffic based
-				// on routes installed for older regions.
-				// We assume that when regions change, the host network will also
-				// change, and so all tunnels will fail and be re-established. Under
-				// that assumption, the classifier will be re-Start()-ed here when
-				// the region has changed.
-				controller.splitTunnelClassifier.Start(connectedTunnel)
-
 				// Signal a connected request on each 1st tunnel establishment. For
 				// multi-tunnels, the session is connected as long as at least one
 				// tunnel is established.
@@ -1213,40 +1216,100 @@ func (controller *Controller) getTunnelPoolSize() int {
 // Dial selects an active tunnel and establishes a port forward
 // connection through the selected tunnel. Failure to connect is considered
 // a port forward failure, for the purpose of monitoring tunnel health.
+//
+// When split tunnel mode is enabled, the connection may be untunneled,
+// depending on GeoIP classification of the destination.
+//
+// downstreamConn is an optional parameter which specifies a connection to be
+// explicitly closed when the dialed connection is closed. For instance, this
+// is used to close downstreamConn App<->LocalProxy connections when the
+// related LocalProxy<->SshPortForward connections close.
 func (controller *Controller) Dial(
-	remoteAddr string, alwaysTunnel bool, downstreamConn net.Conn) (conn net.Conn, err error) {
+	remoteAddr string, downstreamConn net.Conn) (conn net.Conn, err error) {
 
 	tunnel := controller.getNextActiveTunnel()
 	if tunnel == nil {
 		return nil, errors.TraceNew("no active tunnels")
 	}
 
-	// Perform split tunnel classification when feature is enabled, and if the remote
-	// address is classified as untunneled, dial directly.
-	if !alwaysTunnel && controller.config.SplitTunnelDNSServer != "" {
+	// In split tunnel mode, TCP port forwards to destinations in the same
+	// country as the client are untunneled.
+	//
+	// Split tunnel is implemented with assistence from the server to classify
+	// destinations as being in the same country as the client. The server knows
+	// the client's public IP GeoIP data, and, for clients with split tunnel mode
+	// enabled, the server resolves the port forward destination address and
+	// checks the destination IP GeoIP data.
+	//
+	// When the countries match, the server "rejects" the port forward with a
+	// distinct response that indicates to the client that an untunneled port
+	// foward should be established locally.
+	//
+	// The client maintains a classification cache that allows it to make
+	// untunneled port forwards without requiring a round trip to the server.
+	// Only destinations classified as untunneled are stored in the cache: a
+	// destination classified as tunneled requires the same round trip as an
+	// unknown destination.
+	//
+	// When the countries do not match, the server establishes a port forward, as
+	// it does for all port forwards in non-split tunnel mode. There is no
+	// additional round trip for tunneled port forwards.
+
+	untunneledCache := controller.untunneledSplitTunnelClassifications
+	var splitTunnelHost string
+	cachedUntunneled := false
+
+	if controller.config.EnableSplitTunnel {
+		var err error
+		splitTunnelHost, _, err = net.SplitHostPort(remoteAddr)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		// If the destination hostname is in the untunneled split tunnel
+		// classifications cache, skip the round trip to the server and do the
+		// direct, untunneled dial immediately.
+		_, cachedUntunneled = untunneledCache.Get(splitTunnelHost)
+	}
+
+	if !cachedUntunneled {
 
-		host, _, err := net.SplitHostPort(remoteAddr)
+		tunneledConn, splitTunnel, err := tunnel.Dial(
+			remoteAddr, false, downstreamConn)
 		if err != nil {
 			return nil, errors.Trace(err)
 		}
 
-		// Note: a possible optimization, when split tunnel is active and IsUntunneled performs
-		// a DNS resolution in order to make its classification, is to reuse that IP address in
-		// the following Dials so they do not need to make their own resolutions. However, the
-		// way this is currently implemented ensures that, e.g., DNS geo load balancing occurs
-		// relative to the outbound network.
+		if !splitTunnel {
+
+			if controller.config.EnableSplitTunnel {
 
-		if controller.splitTunnelClassifier.IsUntunneled(host) {
-			return controller.DirectDial(remoteAddr)
+				// Clear any cached untunneled classification entry for this destination
+				// hostname, as the server is now classifying it as tunneled.
+				untunneledCache.Delete(splitTunnelHost)
+			}
+
+			return tunneledConn, nil
+		}
+
+		if !controller.config.EnableSplitTunnel {
+			return nil, errors.TraceNew(
+				"unexpected split tunnel classification")
 		}
+
+		// The server has indicated that the client should make a direct,
+		// untunneled dial. Cache the classification to avoid this round trip in
+		// the immediate future.
+		untunneledCache.Add(splitTunnelHost, true, 0)
 	}
 
-	tunneledConn, err := tunnel.Dial(remoteAddr, alwaysTunnel, downstreamConn)
+	NoticeUntunneled(splitTunnelHost)
+
+	untunneledConn, err := controller.DirectDial(remoteAddr)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
-
-	return tunneledConn, nil
+	return untunneledConn, nil
 }
 
 // DirectDial dials an untunneled TCP connection within the controller run context.

+ 0 - 70
psiphon/dataStore.go

@@ -41,8 +41,6 @@ var (
 	datastoreServerEntriesBucket                = []byte("serverEntries")
 	datastoreServerEntryTagsBucket              = []byte("serverEntryTags")
 	datastoreServerEntryTombstoneTagsBucket     = []byte("serverEntryTombstoneTags")
-	datastoreSplitTunnelRouteETagsBucket        = []byte("splitTunnelRouteETags")
-	datastoreSplitTunnelRouteDataBucket         = []byte("splitTunnelRouteData")
 	datastoreUrlETagsBucket                     = []byte("urlETags")
 	datastoreKeyValueBucket                     = []byte("keyValues")
 	datastoreRemoteServerListStatsBucket        = []byte("remoteServerListStats")
@@ -1390,74 +1388,6 @@ func CountServerEntries() int {
 	return count
 }
 
-// SetSplitTunnelRoutes updates the cached routes data for
-// the given region. The associated etag is also stored and
-// used to make efficient web requests for updates to the data.
-func SetSplitTunnelRoutes(region, etag string, data []byte) error {
-
-	err := datastoreUpdate(func(tx *datastoreTx) error {
-		bucket := tx.bucket(datastoreSplitTunnelRouteETagsBucket)
-		err := bucket.put([]byte(region), []byte(etag))
-		if err != nil {
-			return errors.Trace(err)
-		}
-
-		bucket = tx.bucket(datastoreSplitTunnelRouteDataBucket)
-		err = bucket.put([]byte(region), data)
-		if err != nil {
-			return errors.Trace(err)
-		}
-
-		return nil
-	})
-
-	if err != nil {
-		return errors.Trace(err)
-	}
-	return nil
-}
-
-// GetSplitTunnelRoutesETag retrieves the etag for cached routes
-// data for the specified region. If not found, it returns an empty string value.
-func GetSplitTunnelRoutesETag(region string) (string, error) {
-
-	var etag string
-
-	err := datastoreView(func(tx *datastoreTx) error {
-		bucket := tx.bucket(datastoreSplitTunnelRouteETagsBucket)
-		etag = string(bucket.get([]byte(region)))
-		return nil
-	})
-
-	if err != nil {
-		return "", errors.Trace(err)
-	}
-	return etag, nil
-}
-
-// GetSplitTunnelRoutesData retrieves the cached routes data
-// for the specified region. If not found, it returns a nil value.
-func GetSplitTunnelRoutesData(region string) ([]byte, error) {
-
-	var data []byte
-
-	err := datastoreView(func(tx *datastoreTx) error {
-		bucket := tx.bucket(datastoreSplitTunnelRouteDataBucket)
-		value := bucket.get([]byte(region))
-		if value != nil {
-			// Must make a copy as slice is only valid within transaction.
-			data = make([]byte, len(value))
-			copy(data, value)
-		}
-		return nil
-	})
-
-	if err != nil {
-		return nil, errors.Trace(err)
-	}
-	return data, nil
-}
-
 // SetUrlETag stores an ETag for the specfied URL.
 // Note: input URL is treated as a string, and is not
 // encoded or decoded or otherwise canonicalized.

+ 2 - 2
psiphon/dataStore_bolt.go

@@ -134,8 +134,6 @@ func tryDatastoreOpenDB(
 			datastoreServerEntriesBucket,
 			datastoreServerEntryTagsBucket,
 			datastoreServerEntryTombstoneTagsBucket,
-			datastoreSplitTunnelRouteETagsBucket,
-			datastoreSplitTunnelRouteDataBucket,
 			datastoreUrlETagsBucket,
 			datastoreKeyValueBucket,
 			datastoreRemoteServerListStatsBucket,
@@ -163,6 +161,8 @@ func tryDatastoreOpenDB(
 		obsoleteBuckets := [][]byte{
 			[]byte("tunnelStats"),
 			[]byte("rankedServerEntries"),
+			[]byte("splitTunnelRouteETags"),
+			[]byte("splitTunnelRouteData"),
 		}
 		for _, obsoleteBucket := range obsoleteBuckets {
 			if tx.Bucket(obsoleteBucket) != nil {

+ 3 - 3
psiphon/httpProxy.go

@@ -112,7 +112,7 @@ func NewHttpProxy(
 		// downstreamConn is not set in this case, as there is not a fixed
 		// association between a downstream client connection and a particular
 		// tunnel.
-		return tunneler.Dial(addr, false, nil)
+		return tunneler.Dial(addr, nil)
 	}
 	directDialer := func(_, addr string) (conn net.Conn, err error) {
 		return tunneler.DirectDial(addr)
@@ -253,7 +253,7 @@ func (proxy *HttpProxy) httpConnectHandler(localConn net.Conn, target string) (e
 	// Setting downstreamConn so localConn.Close() will be called when remoteConn.Close() is called.
 	// This ensures that the downstream client (e.g., web browser) doesn't keep waiting on the
 	// open connection for data which will never arrive.
-	remoteConn, err := proxy.tunneler.Dial(target, false, localConn)
+	remoteConn, err := proxy.tunneler.Dial(target, localConn)
 	if err != nil {
 		return errors.Trace(err)
 	}
@@ -398,7 +398,7 @@ func (proxy *HttpProxy) makeRewriteICYClient() (*http.Client, *rewriteICYStatus)
 
 	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
 		// See comment in NewHttpProxy regarding downstreamConn
-		return proxy.tunneler.Dial(addr, false, nil)
+		return proxy.tunneler.Dial(addr, nil)
 	}
 
 	dial := func(network, address string) (net.Conn, error) {

+ 5 - 2
psiphon/net.go

@@ -374,8 +374,11 @@ func MakeTunneledHTTPClient(
 	// Note: there is no dial context since SSH port forward dials cannot
 	// be interrupted directly. Closing the tunnel will interrupt the dials.
 
-	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
-		return tunnel.sshClient.Dial("tcp", addr)
+	tunneledDialer := func(_, addr string) (net.Conn, error) {
+		// Set alwaysTunneled to ensure the http.Client traffic is always tunneled,
+		// even when split tunnel mode is enabled.
+		conn, _, err := tunnel.Dial(addr, true, nil)
+		return conn, errors.Trace(err)
 	}
 
 	transport := &http.Transport{

+ 20 - 0
psiphon/server/api.go

@@ -214,6 +214,10 @@ func handshakeAPIRequestHandler(
 	// the client, a value of 0 will be used.
 	establishedTunnelsCount, _ := getIntStringRequestParam(params, "established_tunnels_count")
 
+	// splitTunnel indicates if the client is using split tunnel mode. When
+	// omitted by the client, the value will be false.
+	splitTunnel, _ := getBoolStringRequestParam(params, "split_tunnel")
+
 	var authorizations []string
 	if params[protocol.PSIPHON_API_HANDSHAKE_AUTHORIZATIONS] != nil {
 		authorizations, err = getStringArrayRequestParam(params, protocol.PSIPHON_API_HANDSHAKE_AUTHORIZATIONS)
@@ -241,6 +245,7 @@ func handshakeAPIRequestHandler(
 			apiParams:               copyBaseSessionAndDialParams(params),
 			expectDomainBytes:       len(httpsRequestRegexes) > 0,
 			establishedTunnelsCount: establishedTunnelsCount,
+			splitTunnel:             splitTunnel,
 		},
 		authorizations)
 	if err != nil {
@@ -807,6 +812,7 @@ var baseParams = []requestParamSpec{
 	{"client_build_rev", isHexDigits, requestParamOptional},
 	{"tunnel_whole_device", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
 	{"device_region", isAnyString, requestParamOptional},
+	{"split_tunnel", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
 }
 
 // baseSessionParams adds to baseParams the required session_id parameter. For
@@ -1199,6 +1205,20 @@ func getIntStringRequestParam(params common.APIParameters, name string) (int, er
 	return value, nil
 }
 
+func getBoolStringRequestParam(params common.APIParameters, name string) (bool, error) {
+	if params[name] == nil {
+		return false, errors.Tracef("missing param: %s", name)
+	}
+	valueStr, ok := params[name].(string)
+	if !ok {
+		return false, errors.Tracef("invalid param: %s", name)
+	}
+	if valueStr == "1" {
+		return true, nil
+	}
+	return false, nil
+}
+
 func getPaddingSizeRequestParam(params common.APIParameters, name string) (int, error) {
 	value, err := getIntStringRequestParam(params, name)
 	if err != nil {

+ 22 - 10
psiphon/server/geoip.go

@@ -201,13 +201,28 @@ func (geoIP *GeoIPService) Reloaders() []common.Reloader {
 	return reloaders
 }
 
-// Lookup determines a GeoIPData for a given client IP address.
-func (geoIP *GeoIPService) Lookup(ipAddress string) GeoIPData {
-	result := NewGeoIPData()
+// Lookup determines a GeoIPData for a given string client IP address. Lookup
+// populates the GeoIPData.DiscoveryValue field.
+func (geoIP *GeoIPService) Lookup(strIP string) GeoIPData {
+	IP := net.ParseIP(strIP)
+	if IP == nil {
+		return NewGeoIPData()
+	}
+
+	result := geoIP.LookupIP(IP)
+
+	result.DiscoveryValue = calculateDiscoveryValue(
+		geoIP.discoveryValueHMACKey, strIP)
 
-	ip := net.ParseIP(ipAddress)
+	return result
+}
+
+// LookupIP determines a GeoIPData for a given client IP address. LookupIP
+// omits the GeoIPData.DiscoveryValue field.
+func (geoIP *GeoIPService) LookupIP(IP net.IP) GeoIPData {
+	result := NewGeoIPData()
 
-	if ip == nil || len(geoIP.databases) == 0 {
+	if len(geoIP.databases) == 0 {
 		return result
 	}
 
@@ -230,7 +245,7 @@ func (geoIP *GeoIPService) Lookup(ipAddress string) GeoIPData {
 	// the separate ISP database populates ISP.
 	for _, database := range geoIP.databases {
 		database.ReloadableFile.RLock()
-		err := database.maxMindReader.Lookup(ip, &geoIPFields)
+		err := database.maxMindReader.Lookup(IP, &geoIPFields)
 		database.ReloadableFile.RUnlock()
 		if err != nil {
 			log.WithTraceFields(LogFields{"error": err}).Warning("GeoIP lookup failed")
@@ -258,9 +273,6 @@ func (geoIP *GeoIPService) Lookup(ipAddress string) GeoIPData {
 		result.ASO = geoIPFields.ASO
 	}
 
-	result.DiscoveryValue = calculateDiscoveryValue(
-		geoIP.discoveryValueHMACKey, ipAddress)
-
 	return result
 }
 
@@ -309,7 +321,7 @@ func (geoIP *GeoIPService) InSessionCache(sessionID string) bool {
 // used as input in the server discovery algorithm. Since we do not explicitly
 // store the client IP address, we must derive the value here and store it for
 // later use by the discovery algorithm.
-// See https://bitbucket.org/psiphon/psiphon-circumvention-system/src/tip/Automation/psi_ops_discovery.py
+// See https://github.com/Psiphon-Inc/psiphon-automation/tree/master/Automation/psi_ops_discovery.py
 // for full details.
 func calculateDiscoveryValue(discoveryValueHMACKey, ipAddress string) int {
 	// From: psi_ops_discovery.calculate_ip_address_strategy_value:

Разница между файлами не показана из-за своего большого размера
+ 87 - 0
psiphon/server/geoip_test.go


+ 1 - 1
psiphon/server/replay_test.go

@@ -339,7 +339,7 @@ func runServerReplayClient(
 	// Meet tunnel duration critera.
 	for i := 0; i < 20; i++ {
 		time.Sleep(10 * time.Millisecond)
-		_, _ = controller.Dial("127.0.0.1:80", true, nil)
+		_, _ = controller.Dial("127.0.0.1:80", nil)
 	}
 
 	cancelFunc()

+ 81 - 3
psiphon/server/server_test.go

@@ -134,6 +134,7 @@ func TestSSH(t *testing.T) {
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -155,6 +156,7 @@ func TestOSSH(t *testing.T) {
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -176,6 +178,7 @@ func TestFragmentedOSSH(t *testing.T) {
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -197,6 +200,7 @@ func TestUnfrontedMeek(t *testing.T) {
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -219,6 +223,7 @@ func TestUnfrontedMeekHTTPS(t *testing.T) {
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -241,6 +246,7 @@ func TestUnfrontedMeekHTTPSTLS13(t *testing.T) {
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -263,6 +269,7 @@ func TestUnfrontedMeekSessionTicket(t *testing.T) {
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -285,6 +292,7 @@ func TestUnfrontedMeekSessionTicketTLS13(t *testing.T) {
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -309,6 +317,7 @@ func TestQUICOSSH(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -333,6 +342,7 @@ func TestMarionetteOSSH(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -354,6 +364,7 @@ func TestWebTransportAPIRequests(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -375,6 +386,7 @@ func TestHotReload(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -396,6 +408,7 @@ func TestDefaultSponsorID(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -417,6 +430,7 @@ func TestDenyTrafficRules(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -438,6 +452,7 @@ func TestOmitAuthorization(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -459,6 +474,7 @@ func TestNoAuthorization(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -480,6 +496,7 @@ func TestUnusedAuthorization(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -501,6 +518,7 @@ func TestTCPOnlySLOK(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -522,6 +540,7 @@ func TestUDPOnlySLOK(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -543,6 +562,7 @@ func TestLivenessTest(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -564,6 +584,7 @@ func TestPruneServerEntries(t *testing.T) {
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
 			doBurstMonitor:       false,
+			doSplitTunnel:        false,
 		})
 }
 
@@ -585,6 +606,29 @@ func TestBurstMonitor(t *testing.T) {
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
 			doBurstMonitor:       true,
+			doSplitTunnel:        false,
+		})
+}
+
+func TestSplitTunnel(t *testing.T) {
+	runServer(t,
+		&runServerConfig{
+			tunnelProtocol:       "OSSH",
+			enableSSHAPIRequests: true,
+			doHotReload:          false,
+			doDefaultSponsorID:   false,
+			denyTrafficRules:     false,
+			requireAuthorization: true,
+			omitAuthorization:    false,
+			doTunneledWebRequest: true,
+			doTunneledNTPRequest: true,
+			forceFragmenting:     false,
+			forceLivenessTest:    false,
+			doPruneServerEntries: false,
+			doDanglingTCPConn:    true,
+			doPacketManipulation: false,
+			doBurstMonitor:       false,
+			doSplitTunnel:        true,
 		})
 }
 
@@ -605,6 +649,7 @@ type runServerConfig struct {
 	doDanglingTCPConn    bool
 	doPacketManipulation bool
 	doBurstMonitor       bool
+	doSplitTunnel        bool
 }
 
 var (
@@ -750,7 +795,15 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 	var serverConfig map[string]interface{}
 	json.Unmarshal(serverConfigJSON, &serverConfig)
-	serverConfig["GeoIPDatabaseFilename"] = ""
+
+	// The test GeoIP database maps all IPs to a single, non-"None" country. When
+	// split tunnel mode is enabled, this should cause port forwards to be
+	// untunneled. When split tunnel mode is not enabled, port forwards should be
+	// tunneled despite the country match.
+	geoIPDatabaseFilename := filepath.Join(testDataDirName, "geoip_database.mmbd")
+	paveGeoIPDatabaseFile(t, geoIPDatabaseFilename)
+	serverConfig["GeoIPDatabaseFilenames"] = []string{geoIPDatabaseFilename}
+
 	serverConfig["PsinetDatabaseFilename"] = psinetFilename
 	serverConfig["TrafficRulesFilename"] = trafficRulesFilename
 	serverConfig["OSLConfigFilename"] = oslConfigFilename
@@ -945,6 +998,10 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	clientConfig.EmitSLOKs = true
 	clientConfig.EmitServerAlerts = true
 
+	if runConfig.doSplitTunnel {
+		clientConfig.EnableSplitTunnel = true
+	}
+
 	if !runConfig.omitAuthorization {
 		clientConfig.Authorizations = []string{clientAuthorization}
 	}
@@ -1044,11 +1101,10 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	tunnelsEstablished := make(chan struct{}, 1)
 	homepageReceived := make(chan struct{}, 1)
 	slokSeeded := make(chan struct{}, 1)
-
 	numPruneNotices := 0
 	pruneServerEntriesNoticesEmitted := make(chan struct{}, 1)
-
 	serverAlertDisallowedNoticesEmitted := make(chan struct{}, 1)
+	untunneledPortForward := make(chan struct{}, 1)
 
 	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 		func(notice []byte) {
@@ -1093,6 +1149,10 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 				if reason == protocol.PSIPHON_API_ALERT_DISALLOWED_TRAFFIC {
 					sendNotificationReceived(serverAlertDisallowedNoticesEmitted)
 				}
+
+			case "Untunneled":
+				sendNotificationReceived(untunneledPortForward)
+
 			}
 		}))
 
@@ -1225,6 +1285,24 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		defer danglingConn.Close()
 	}
 
+	// Test: check for split tunnel notice
+
+	if runConfig.doSplitTunnel {
+		if !runConfig.doTunneledWebRequest || expectTrafficFailure {
+			t.Fatalf("invalid test run configuration")
+		}
+		waitOnNotification(t, untunneledPortForward, nil, "")
+	} else {
+		// There should be no "Untunneled" notice. This check assumes that any
+		// unexpected Untunneled notice will have been delivered at this point,
+		// after the SLOK notice.
+		select {
+		case <-untunneledPortForward:
+			t.Fatalf("unexpected untunnedl port forward")
+		default:
+		}
+	}
+
 	// Shutdown to ensure logs/notices are flushed
 
 	stopClient()

+ 82 - 12
psiphon/server/tunnelServer.go

@@ -1280,6 +1280,7 @@ type handshakeState struct {
 	authorizationsRevoked   bool
 	expectDomainBytes       bool
 	establishedTunnelsCount int
+	splitTunnel             bool
 }
 
 type handshakeStateInfo struct {
@@ -1926,8 +1927,15 @@ func (sshClient *sshClient) runTunnel(
 			sshClient.handleNewRandomStreamChannel(waitGroup, newChannel)
 		case protocol.PACKET_TUNNEL_CHANNEL_TYPE:
 			sshClient.handleNewPacketTunnelChannel(waitGroup, newChannel)
+		case protocol.TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE:
+			// The protocol.TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE is the same as
+			// "direct-tcpip", except split tunnel channel rejections are disallowed
+			// even if the client has enabled split tunnel. This channel type allows
+			// the client to ensure tunneling for certain cases while split tunnel is
+			// enabled.
+			sshClient.handleNewTCPPortForwardChannel(waitGroup, newChannel, false, newTCPPortForwards)
 		case "direct-tcpip":
-			sshClient.handleNewTCPPortForwardChannel(waitGroup, newChannel, newTCPPortForwards)
+			sshClient.handleNewTCPPortForwardChannel(waitGroup, newChannel, true, newTCPPortForwards)
 		default:
 			sshClient.rejectNewChannel(newChannel,
 				fmt.Sprintf("unknown or unsupported channel type: %s", newChannel.ChannelType()))
@@ -2008,6 +2016,7 @@ type newTCPPortForward struct {
 	enqueueTime   time.Time
 	hostToConnect string
 	portToConnect int
+	doSplitTunnel bool
 	newChannel    ssh.NewChannel
 }
 
@@ -2132,6 +2141,7 @@ func (sshClient *sshClient) handleTCPPortForwards(
 				remainingDialTimeout,
 				newPortForward.hostToConnect,
 				newPortForward.portToConnect,
+				newPortForward.doSplitTunnel,
 				newPortForward.newChannel)
 		}(remainingDialTimeout, newPortForward)
 	}
@@ -2332,7 +2342,9 @@ func (sshClient *sshClient) handleNewPacketTunnelChannel(
 }
 
 func (sshClient *sshClient) handleNewTCPPortForwardChannel(
-	waitGroup *sync.WaitGroup, newChannel ssh.NewChannel,
+	waitGroup *sync.WaitGroup,
+	newChannel ssh.NewChannel,
+	allowSplitTunnel bool,
 	newTCPPortForwards chan *newTCPPortForward) {
 
 	// udpgw client connections are dispatched immediately (clients use this for
@@ -2377,11 +2389,15 @@ func (sshClient *sshClient) handleNewTCPPortForwardChannel(
 
 		// Dispatch via TCP port forward manager. When the queue is full, the channel
 		// is immediately rejected.
+		//
+		// Split tunnel logic is enabled for this TCP port forward when the client
+		// has enabled split tunnel mode and the channel type allows it.
 
 		tcpPortForward := &newTCPPortForward{
 			enqueueTime:   time.Now(),
 			hostToConnect: directTcpipExtraData.HostToConnect,
 			portToConnect: int(directTcpipExtraData.PortToConnect),
+			doSplitTunnel: sshClient.handshakeState.splitTunnel && allowSplitTunnel,
 			newChannel:    newChannel,
 		}
 
@@ -3103,16 +3119,8 @@ func (sshClient *sshClient) isPortForwardPermitted(
 	// cases, a blocklist entry won't be dialed in any case. However, no logs
 	// will be recorded.
 
-	tags := sshClient.sshServer.support.Blocklist.LookupIP(remoteIP)
-	if len(tags) > 0 {
-
-		sshClient.logBlocklistHits(remoteIP, "", tags)
-
-		if sshClient.sshServer.support.Config.BlocklistActive {
-			// Actively alert and block
-			sshClient.enqueueUnsafeTrafficAlertRequest(tags)
-			return false
-		}
+	if !sshClient.isIPPermitted(remoteIP) {
+		return false
 	}
 
 	// Don't lock before calling logBlocklistHits.
@@ -3193,6 +3201,23 @@ func (sshClient *sshClient) isDomainPermitted(domain string) (bool, string) {
 	return true, ""
 }
 
+func (sshClient *sshClient) isIPPermitted(remoteIP net.IP) bool {
+
+	tags := sshClient.sshServer.support.Blocklist.LookupIP(remoteIP)
+	if len(tags) > 0 {
+
+		sshClient.logBlocklistHits(remoteIP, "", tags)
+
+		if sshClient.sshServer.support.Config.BlocklistActive {
+			// Actively alert and block
+			sshClient.enqueueUnsafeTrafficAlertRequest(tags)
+			return false
+		}
+	}
+
+	return true
+}
+
 func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
 
 	sshClient.Lock()
@@ -3439,6 +3464,7 @@ func (sshClient *sshClient) handleTCPChannel(
 	remainingDialTimeout time.Duration,
 	hostToConnect string,
 	portToConnect int,
+	doSplitTunnel bool,
 	newChannel ssh.NewChannel) {
 
 	// Assumptions:
@@ -3549,6 +3575,50 @@ func (sshClient *sshClient) handleTCPChannel(
 		return
 	}
 
+	// When the client has indicated split tunnel mode and when the channel is
+	// not of type protocol.TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE, check if the
+	// client and the port forward destination are in the same GeoIP country. If
+	// so, reject the port forward with a distinct response code that indicates
+	// to the client that this port forward should be performed locally, direct
+	// and untunneled.
+	//
+	// Clients are expected to cache untunneled responses to avoid this round
+	// trip in the immediate future and reduce server load.
+	//
+	// When the countries differ, immediately proceed with the standard port
+	// forward. No additional round trip is required.
+	//
+	// If either GeoIP country is "None", one or both countries are unknown
+	// and there is no match.
+	//
+	// Traffic rules, such as allowed ports, are not enforced for port forward
+	// destinations classified as untunneled.
+	//
+	// Domain and IP blocklists still apply to port forward destinations
+	// classified as untunneled.
+	//
+	// The client's use of split tunnel mode is logged in server_tunnel metrics
+	// as the boolean value split_tunnel. As they may indicate some information
+	// about browsing activity, no other split tunnel metrics are logged.
+
+	if doSplitTunnel {
+
+		destinationGeoIPData := sshClient.sshServer.support.GeoIPService.LookupIP(IP)
+
+		if destinationGeoIPData.Country == sshClient.geoIPData.Country &&
+			sshClient.geoIPData.Country != GEOIP_UNKNOWN_VALUE {
+
+			// Since isPortForwardPermitted is not called in this case, explicitly call
+			// ipBlocklistCheck. The domain blocklist case is handled above.
+			if !sshClient.isIPPermitted(IP) {
+				// Note: not recording a port forward failure in this case
+				sshClient.rejectNewChannel(newChannel, "port forward not permitted")
+			}
+
+			newChannel.Reject(protocol.CHANNEL_REJECT_REASON_SPLIT_TUNNEL, "")
+		}
+	}
+
 	// Enforce traffic rules, using the resolved IP address.
 
 	if !isWebServerPortForward &&

+ 14 - 1
psiphon/serverApi.go

@@ -162,6 +162,14 @@ func (serverContext *ServerContext) doHandshakeRequest(
 		}
 	}
 
+	// When split tunnel mode is enabled, indicate this to the server. When
+	// indicated, the server will perform split tunnel classifications on TCP
+	// port forwards and reject, with a distinct response, port forwards which
+	// the client should connect to directly, untunneled.
+	if serverContext.tunnel.config.EnableSplitTunnel {
+		params["split_tunnel"] = "1"
+	}
+
 	var response []byte
 	if serverContext.psiphonHttpsClient == nil {
 
@@ -1086,7 +1094,12 @@ func makePsiphonHttpsClient(tunnel *Tunnel) (httpsClient *http.Client, err error
 		return nil, errors.Trace(err)
 	}
 
-	tunneledDialer := func(_ context.Context, _, addr string) (conn net.Conn, err error) {
+	tunneledDialer := func(_ context.Context, _, addr string) (net.Conn, error) {
+		// This case bypasses tunnel.Dial, to avoid its check that the tunnel is
+		// already active (it won't be pre-handshake). This bypass won't handle the
+		// server rejecting the port forward due to split tunnel classification, but
+		// we know that the server won't classify the web API destination as
+		// untunneled.
 		return tunnel.sshClient.Dial("tcp", addr)
 	}
 

+ 1 - 1
psiphon/socksProxy.go

@@ -91,7 +91,7 @@ func (proxy *SocksProxy) socksConnectionHandler(localConn *socks.SocksConn) (err
 	// Using downstreamConn so localConn.Close() will be called when remoteConn.Close() is called.
 	// This ensures that the downstream client (e.g., web browser) doesn't keep waiting on the
 	// open connection for data which will never arrive.
-	remoteConn, err := proxy.tunneler.Dial(localConn.Req.Target, false, localConn)
+	remoteConn, err := proxy.tunneler.Dial(localConn.Req.Target, localConn)
 
 	if err != nil {
 		reason := byte(socks.SocksRepGeneralFailure)

+ 0 - 418
psiphon/splitTunnel.go

@@ -1,418 +0,0 @@
-/*
- * Copyright (c) 2015, Psiphon Inc.
- * All rights reserved.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program.  If not, see <http://www.gnu.org/licenses/>.
- *
- */
-
-package psiphon
-
-import (
-	"bytes"
-	"compress/zlib"
-	"encoding/base64"
-	"fmt"
-	"io/ioutil"
-	"net"
-	"net/http"
-	"sync"
-	"time"
-
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
-)
-
-// SplitTunnelClassifier determines whether a network destination
-// should be accessed through a tunnel or accessed directly.
-//
-// The classifier uses tables of IP address data, routes data,
-// to determine if a given IP is to be tunneled or not. If presented
-// with a hostname, the classifier performs a tunneled (uncensored)
-// DNS request to first determine the IP address for that hostname;
-// then a classification is made based on the IP address.
-//
-// Classification results (both the hostname resolution and the
-// following IP address classification) are cached for the duration
-// of the DNS record TTL.
-//
-// Classification is by geographical region (country code). When the
-// split tunnel feature is configured to be on, and if the IP
-// address is within the user's region, it may be accessed untunneled.
-// Otherwise, the IP address must be accessed through a tunnel. The
-// user's current region is revealed to a Tunnel via the Psiphon server
-// API handshake.
-//
-// When a Tunnel has a blank region (e.g., when DisableApi is set and
-// the tunnel registers without performing a handshake) then no routes
-// data is set and all IP addresses are classified as requiring tunneling.
-//
-// Split tunnel is made on a best effort basis. After the classifier is
-// started, but before routes data is available for the given region,
-// all IP addresses will be classified as requiring tunneling.
-//
-// Routes data is fetched asynchronously after Start() is called. Routes
-// data is cached in the data store so it need not be downloaded in full
-// when fresh data is in the cache.
-type SplitTunnelClassifier struct {
-	config               *Config
-	mutex                sync.RWMutex
-	userAgent            string
-	dnsTunneler          Tunneler
-	fetchRoutesWaitGroup *sync.WaitGroup
-	isRoutesSet          bool
-	cache                map[string]*classification
-	routes               common.SubnetLookup
-}
-
-type classification struct {
-	isUntunneled bool
-	expiry       time.Time
-}
-
-func NewSplitTunnelClassifier(config *Config, tunneler Tunneler) *SplitTunnelClassifier {
-	return &SplitTunnelClassifier{
-		config:               config,
-		userAgent:            MakePsiphonUserAgent(config),
-		dnsTunneler:          tunneler,
-		fetchRoutesWaitGroup: new(sync.WaitGroup),
-		isRoutesSet:          false,
-		cache:                make(map[string]*classification),
-	}
-}
-
-// Start resets the state of the classifier. In the default state,
-// all IP addresses are classified as requiring tunneling. With
-// sufficient configuration and region info, this function starts
-// a goroutine to asynchronously fetch and install the routes data.
-func (classifier *SplitTunnelClassifier) Start(fetchRoutesTunnel *Tunnel) {
-
-	classifier.mutex.Lock()
-	defer classifier.mutex.Unlock()
-
-	classifier.isRoutesSet = false
-
-	p := classifier.config.GetParameters().Get()
-	dnsServerAddress := p.String(parameters.SplitTunnelDNSServer)
-	routesSignaturePublicKey := p.String(parameters.SplitTunnelRoutesSignaturePublicKey)
-	fetchRoutesUrlFormat := p.String(parameters.SplitTunnelRoutesURLFormat)
-
-	if dnsServerAddress == "" ||
-		routesSignaturePublicKey == "" ||
-		fetchRoutesUrlFormat == "" {
-		// Split tunnel capability is not configured
-		return
-	}
-
-	if fetchRoutesTunnel.serverContext == nil {
-		// Tunnel has no serverContext
-		return
-	}
-
-	if fetchRoutesTunnel.serverContext.clientRegion == "" {
-		// Split tunnel region is unknown
-		return
-	}
-
-	classifier.fetchRoutesWaitGroup.Add(1)
-	go classifier.setRoutes(fetchRoutesTunnel)
-}
-
-// Shutdown waits until the background setRoutes() goroutine is finished.
-// There is no explicit shutdown signal sent to setRoutes() -- instead
-// we assume that in an overall shutdown situation, the tunnel used for
-// network access in setRoutes() is closed and network events won't delay
-// the completion of the goroutine.
-func (classifier *SplitTunnelClassifier) Shutdown() {
-	classifier.mutex.Lock()
-	defer classifier.mutex.Unlock()
-
-	if classifier.fetchRoutesWaitGroup != nil {
-		classifier.fetchRoutesWaitGroup.Wait()
-		classifier.fetchRoutesWaitGroup = nil
-		classifier.isRoutesSet = false
-	}
-}
-
-// IsUntunneled takes a destination hostname or IP address and determines
-// if it should be accessed through a tunnel. When a hostname is presented, it
-// is first resolved to an IP address which can be matched against the routes data.
-// Multiple goroutines may invoke RequiresTunnel simultaneously. Multi-reader
-// locks are used in the implementation to enable concurrent access, with no locks
-// held during network access.
-func (classifier *SplitTunnelClassifier) IsUntunneled(targetAddress string) bool {
-
-	if !classifier.hasRoutes() {
-		return false
-	}
-
-	dnsServerAddress := classifier.config.GetParameters().Get().String(
-		parameters.SplitTunnelDNSServer)
-	if dnsServerAddress == "" {
-		// Split tunnel has been disabled.
-		return false
-	}
-
-	classifier.mutex.RLock()
-	cachedClassification, ok := classifier.cache[targetAddress]
-	classifier.mutex.RUnlock()
-	if ok && cachedClassification.expiry.After(time.Now()) {
-		return cachedClassification.isUntunneled
-	}
-
-	ipAddr, ttl, err := tunneledLookupIP(
-		dnsServerAddress, classifier.dnsTunneler, targetAddress)
-	if err != nil {
-		NoticeWarning("failed to resolve address for split tunnel classification: %s", err)
-		return false
-	}
-	expiry := time.Now().Add(ttl)
-
-	isUntunneled := classifier.ipAddressInRoutes(ipAddr)
-
-	// TODO: garbage collect expired items from cache?
-
-	classifier.mutex.Lock()
-	classifier.cache[targetAddress] = &classification{isUntunneled, expiry}
-	classifier.mutex.Unlock()
-
-	if isUntunneled {
-		NoticeUntunneled(targetAddress)
-	}
-
-	return isUntunneled
-}
-
-// setRoutes is a background routine that fetches routes data and installs it,
-// which sets the isRoutesSet flag, indicating that IP addresses may now be classified.
-func (classifier *SplitTunnelClassifier) setRoutes(tunnel *Tunnel) {
-	defer classifier.fetchRoutesWaitGroup.Done()
-
-	// Note: a possible optimization is to install cached routes
-	// before making the request. That would ensure some split
-	// tunneling for the duration of the request.
-
-	routesData, err := classifier.getRoutes(tunnel)
-	if err != nil {
-		NoticeWarning("failed to get split tunnel routes: %s", err)
-		return
-	}
-
-	err = classifier.installRoutes(routesData)
-	if err != nil {
-		NoticeWarning("failed to install split tunnel routes: %s", err)
-		return
-	}
-
-	NoticeSplitTunnelRegion(tunnel.serverContext.clientRegion)
-}
-
-// getRoutes makes a web request to download fresh routes data for the
-// given region, as indicated by the tunnel. It uses web caching, If-None-Match/ETag,
-// to save downloading known routes data repeatedly. If the web request
-// fails and cached routes data is present, that cached data is returned.
-func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData []byte, err error) {
-
-	p := classifier.config.GetParameters().Get()
-	routesSignaturePublicKey := p.String(parameters.SplitTunnelRoutesSignaturePublicKey)
-	fetchRoutesUrlFormat := p.String(parameters.SplitTunnelRoutesURLFormat)
-	fetchTimeout := p.Duration(parameters.FetchSplitTunnelRoutesTimeout)
-	p.Close()
-
-	url := fmt.Sprintf(fetchRoutesUrlFormat, tunnel.serverContext.clientRegion)
-	request, err := http.NewRequest("GET", url, nil)
-	if err != nil {
-		return nil, errors.Trace(err)
-	}
-
-	request.Header.Set("User-Agent", classifier.userAgent)
-
-	etag, err := GetSplitTunnelRoutesETag(tunnel.serverContext.clientRegion)
-	if err != nil {
-		return nil, errors.Trace(err)
-	}
-	if etag != "" {
-		request.Header.Add("If-None-Match", etag)
-	}
-
-	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
-		return tunnel.sshClient.Dial("tcp", addr)
-	}
-	transport := &http.Transport{
-		Dial:                  tunneledDialer,
-		ResponseHeaderTimeout: fetchTimeout,
-	}
-	httpClient := &http.Client{
-		Transport: transport,
-		Timeout:   fetchTimeout,
-	}
-
-	// At this time, the largest uncompressed routes data set is ~1MB. For now,
-	// the processing pipeline is done all in-memory.
-
-	useCachedRoutes := false
-
-	response, err := httpClient.Do(request)
-
-	if err == nil &&
-		(response.StatusCode != http.StatusOK && response.StatusCode != http.StatusNotModified) {
-		response.Body.Close()
-		err = fmt.Errorf("unexpected response status code: %d", response.StatusCode)
-	}
-	if err != nil {
-		NoticeWarning("failed to request split tunnel routes package: %s", errors.Trace(err))
-		useCachedRoutes = true
-	}
-
-	if !useCachedRoutes {
-		defer response.Body.Close()
-		if response.StatusCode == http.StatusNotModified {
-			useCachedRoutes = true
-		}
-	}
-
-	var routesDataPackage []byte
-	if !useCachedRoutes {
-		routesDataPackage, err = ioutil.ReadAll(response.Body)
-		if err != nil {
-			NoticeWarning("failed to download split tunnel routes package: %s", errors.Trace(err))
-			useCachedRoutes = true
-		}
-	}
-
-	var encodedRoutesData string
-	if !useCachedRoutes {
-		encodedRoutesData, err = common.ReadAuthenticatedDataPackage(
-			routesDataPackage, false, routesSignaturePublicKey)
-		if err != nil {
-			NoticeWarning("failed to read split tunnel routes package: %s", errors.Trace(err))
-			useCachedRoutes = true
-		}
-	}
-
-	var compressedRoutesData []byte
-	if !useCachedRoutes {
-		compressedRoutesData, err = base64.StdEncoding.DecodeString(encodedRoutesData)
-		if err != nil {
-			NoticeWarning("failed to decode split tunnel routes: %s", errors.Trace(err))
-			useCachedRoutes = true
-		}
-	}
-
-	if !useCachedRoutes {
-		zlibReader, err := zlib.NewReader(bytes.NewReader(compressedRoutesData))
-		if err == nil {
-			routesData, err = ioutil.ReadAll(zlibReader)
-			zlibReader.Close()
-		}
-		if err != nil {
-			NoticeWarning("failed to decompress split tunnel routes: %s", errors.Trace(err))
-			useCachedRoutes = true
-		}
-	}
-
-	if !useCachedRoutes {
-		etag := response.Header.Get("ETag")
-		if etag != "" {
-			err := SetSplitTunnelRoutes(tunnel.serverContext.clientRegion, etag, routesData)
-			if err != nil {
-				NoticeWarning("failed to cache split tunnel routes: %s", errors.Trace(err))
-				// Proceed with fetched data, even when we can't cache it
-			}
-		}
-	}
-
-	if useCachedRoutes {
-		routesData, err = GetSplitTunnelRoutesData(tunnel.serverContext.clientRegion)
-		if err != nil {
-			return nil, errors.Trace(err)
-		}
-		if routesData == nil {
-			return nil, errors.TraceNew("no cached routes")
-		}
-	}
-
-	return routesData, nil
-}
-
-// hasRoutes checks if the classifier has routes installed.
-func (classifier *SplitTunnelClassifier) hasRoutes() bool {
-	classifier.mutex.RLock()
-	defer classifier.mutex.RUnlock()
-
-	return classifier.isRoutesSet
-}
-
-// installRoutes parses the raw routes data and creates data structures
-// for fast in-memory classification.
-func (classifier *SplitTunnelClassifier) installRoutes(routesData []byte) (err error) {
-	classifier.mutex.Lock()
-	defer classifier.mutex.Unlock()
-
-	classifier.routes, err = common.NewSubnetLookupFromRoutes(routesData)
-	if err != nil {
-		return errors.Trace(err)
-	}
-
-	classifier.isRoutesSet = true
-
-	return nil
-}
-
-// ipAddressInRoutes searches for a split tunnel candidate IP address in the routes data.
-func (classifier *SplitTunnelClassifier) ipAddressInRoutes(ipAddr net.IP) bool {
-	classifier.mutex.RLock()
-	defer classifier.mutex.RUnlock()
-
-	return classifier.routes.ContainsIPAddress(ipAddr)
-}
-
-// tunneledLookupIP resolves a split tunnel candidate hostname with a tunneled
-// DNS request.
-func tunneledLookupIP(
-	dnsServerAddress string, dnsTunneler Tunneler, host string) (addr net.IP, ttl time.Duration, err error) {
-
-	ipAddr := net.ParseIP(host)
-	if ipAddr != nil {
-		// maxDuration from golang.org/src/time/time.go
-		return ipAddr, time.Duration(1<<63 - 1), nil
-	}
-
-	// dnsServerAddress must be an IP address
-	ipAddr = net.ParseIP(dnsServerAddress)
-	if ipAddr == nil {
-		return nil, 0, errors.TraceNew("invalid IP address")
-	}
-
-	// Dial's alwaysTunnel is set to true to ensure this connection
-	// is tunneled (also ensures this code path isn't circular).
-	// Assumes tunnel dialer conn configures timeouts and interruptibility.
-
-	conn, err := dnsTunneler.Dial(fmt.Sprintf(
-		"%s:%d", dnsServerAddress, DNS_PORT), true, nil)
-	if err != nil {
-		return nil, 0, errors.Trace(err)
-	}
-
-	ipAddrs, ttls, err := ResolveIP(host, conn)
-	if err != nil {
-		return nil, 0, errors.Trace(err)
-	}
-	if len(ipAddrs) < 1 {
-		return nil, 0, errors.TraceNew("no IP address")
-	}
-
-	return ipAddrs[0], ttls[0], nil
-}

+ 8 - 0
psiphon/tactics.go

@@ -37,6 +37,9 @@ import (
 // immediately return. If no unexpired stored tactics are found, tactics
 // requests are attempted until the input context is cancelled.
 //
+// Callers may pass in a context that is already done. In this case, stored
+// tactics, when available, are applied but no request will be attempted.
+//
 // Callers are responsible for ensuring that the input context eventually
 // cancels, and should synchronize GetTactics calls to ensure no unintended
 // concurrent fetch attempts occur.
@@ -71,6 +74,11 @@ func GetTactics(ctx context.Context, config *Config) {
 		return
 	}
 
+	// If the context is already Done, don't even start the request.
+	if ctx.Err() != nil {
+		return
+	}
+
 	if tacticsRecord == nil {
 
 		iterator, err := NewTacticsServerEntryIterator(config)

+ 1 - 1
psiphon/tactics_test.go

@@ -87,7 +87,7 @@ func TestStandAloneGetTactics(t *testing.T) {
 
 	err = FetchCommonRemoteServerList(ctx, config, 0, nil, untunneledDialConfig)
 	if err != nil {
-		t.Fatalf("error cfetching remote server list: %s", err)
+		t.Fatalf("error fetching remote server list: %s", err)
 	}
 
 	// Close the datastore to exercise the OpenDatastore/CloseDatastore

+ 50 - 18
psiphon/tunnel.go

@@ -25,6 +25,7 @@ import (
 	"crypto/rand"
 	"encoding/base64"
 	"encoding/json"
+	std_errors "errors"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -47,23 +48,19 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/transferstats"
 )
 
-// Tunneler specifies the interface required by components that use a tunnel.
-// Components which use this interface may be serviced by a single Tunnel instance,
-// or a Controller which manages a pool of tunnels, or any other object which
-// implements Tunneler.
+// Tunneler specifies the interface required by components that use tunnels.
 type Tunneler interface {
 
 	// Dial creates a tunneled connection.
 	//
-	// alwaysTunnel indicates that the connection should always be tunneled. If this
-	// is not set, the connection may be made directly, depending on split tunnel
-	// classification, when that feature is supported and active.
+	// When split tunnel mode is enabled, the connection may be untunneled,
+	// depending on GeoIP classification of the destination.
 	//
 	// downstreamConn is an optional parameter which specifies a connection to be
 	// explicitly closed when the Dialed connection is closed. For instance, this
 	// is used to close downstreamConn App<->LocalProxy connections when the related
 	// LocalProxy<->SshPortForward connections close.
-	Dial(remoteAddr string, alwaysTunnel bool, downstreamConn net.Conn) (conn net.Conn, err error)
+	Dial(remoteAddr string, downstreamConn net.Conn) (conn net.Conn, err error)
 
 	DirectDial(remoteAddr string) (conn net.Conn, err error)
 
@@ -433,19 +430,37 @@ func (tunnel *Tunnel) SendAPIRequest(
 	return responsePayload, nil
 }
 
-// Dial establishes a port forward connection through the tunnel
-// This Dial doesn't support split tunnel, so alwaysTunnel is not referenced
+// Dial establishes a port forward connection through the tunnel.
+//
+// When split tunnel mode is enabled, and unless alwaysTunneled is set, the
+// server may reject the port forward and indicate that the client is to make
+// direct, untunneled connection. In this case, the bool return value is true
+// and net.Conn and error are nil.
+//
+// downstreamConn is an optional parameter which specifies a connection to be
+// explicitly closed when the dialed connection is closed.
 func (tunnel *Tunnel) Dial(
-	remoteAddr string, alwaysTunnel bool, downstreamConn net.Conn) (net.Conn, error) {
+	remoteAddr string,
+	alwaysTunneled bool,
+	downstreamConn net.Conn) (net.Conn, bool, error) {
+
+	channelType := "direct-tcpip"
+	if alwaysTunneled && tunnel.config.EnableSplitTunnel {
+		// This channel type is only necessary in split tunnel mode.
+		channelType = protocol.TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE
+	}
 
-	channel, err := tunnel.dialChannel("tcp", remoteAddr)
+	channel, err := tunnel.dialChannel(channelType, remoteAddr)
 	if err != nil {
-		return nil, errors.Trace(err)
+		if isSplitTunnelRejectReason(err) {
+			return nil, true, nil
+		}
+		return nil, false, errors.Trace(err)
 	}
 
 	netConn, ok := channel.(net.Conn)
 	if !ok {
-		return nil, errors.Tracef("unexpected channel type: %T", channel)
+		return nil, false, errors.Tracef("unexpected channel type: %T", channel)
 	}
 
 	conn := &TunneledConn{
@@ -453,7 +468,18 @@ func (tunnel *Tunnel) Dial(
 		tunnel:         tunnel,
 		downstreamConn: downstreamConn}
 
-	return tunnel.wrapWithTransferStats(conn), nil
+	return tunnel.wrapWithTransferStats(conn), false, nil
+}
+
+func isSplitTunnelRejectReason(err error) bool {
+
+	var openChannelErr *ssh.OpenChannelError
+	if std_errors.As(err, &openChannelErr) {
+		return openChannelErr.Reason ==
+			ssh.RejectionReason(protocol.CHANNEL_REJECT_REASON_SPLIT_TUNNEL)
+	}
+
+	return false
 }
 
 func (tunnel *Tunnel) DialPacketTunnelChannel() (net.Conn, error) {
@@ -515,10 +541,16 @@ func (tunnel *Tunnel) dialChannel(channelType, remoteAddr string) (interface{},
 
 	go func() {
 		result := new(channelDialResult)
-		if channelType == "tcp" {
+		switch channelType {
+
+		case "direct-tcpip", protocol.TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE:
+			// The protocol.TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE is the same as
+			// "direct-tcpip", except split tunnel channel rejections are disallowed
+			// even when split tunnel mode is enabled.
 			result.channel, result.err =
-				tunnel.sshClient.Dial("tcp", remoteAddr)
-		} else {
+				tunnel.sshClient.Dial(channelType, remoteAddr)
+
+		default:
 			var sshRequests <-chan *ssh.Request
 			result.channel, sshRequests, result.err =
 				tunnel.sshClient.OpenChannel(channelType, nil)

Некоторые файлы не были показаны из-за большого количества измененных файлов