Просмотр исходного кода

Split tunnel code complete

* Added routes processing and IP address lookup code (submitted by @efryntov)
* Fixed SQL syntax error
* Fixed missing DNS port in tunneled DNS dial
* Fixed missing fetchRoutesWaitGroup.Add
* Fixed uninitialized buffer passed into zlib decompress
* Added NoticeUntunneledClassification to enable alerting users to which
  destinations are excluded from tunnels
Rod Hynes 11 лет назад
Родитель
Сommit
d543d4d910
5 измененных файлов с 134 добавлено и 9 удалено
  1. 0 2
      psiphon/LookupIP.go
  2. 5 2
      psiphon/dataStore.go
  3. 2 0
      psiphon/net.go
  4. 3 0
      psiphon/serverApi.go
  5. 124 5
      psiphon/splitTunnel.go

+ 0 - 2
psiphon/LookupIP.go

@@ -30,8 +30,6 @@ import (
 	"time"
 )
 
-const DNS_PORT = 53
-
 // LookupIP resolves a hostname. When BindToDevice is not required, it
 // simply uses net.LookupIP.
 // When BindToDevice is required, LookupIP explicitly creates a UDP

+ 5 - 2
psiphon/dataStore.go

@@ -531,7 +531,7 @@ func SetSplitTunnelRoutes(region, etag string, data []byte) error {
 	return transactionWithRetry(func(transaction *sql.Tx) error {
 		_, err := transaction.Exec(`
             insert or replace into splitTunnelRoutes (region, etag, data)
-            values (?, ?. ?);
+            values (?, ?, ?);
             `, region, etag, data)
 		if err != nil {
 			// Note: ContextError() would break canRetry()
@@ -557,11 +557,14 @@ func GetSplitTunnelRoutesETag(region string) (etag string, err error) {
 }
 
 // GetSplitTunnelRoutesData retrieves the cached routes data
-// for the specified region. It returns an error if not found.
+// for the specified region. If not found, it returns a nil value.
 func GetSplitTunnelRoutesData(region string) (data []byte, err error) {
 	checkInitDataStore()
 	rows := singleton.db.QueryRow("select data from splitTunnelRoutes where region = ?;", region)
 	err = rows.Scan(&data)
+	if err == sql.ErrNoRows {
+		return nil, nil
+	}
 	if err != nil {
 		return nil, ContextError(err)
 	}

+ 2 - 0
psiphon/net.go

@@ -33,6 +33,8 @@ import (
 	"github.com/Psiphon-Inc/dns"
 )
 
+const DNS_PORT = 53
+
 // DialConfig contains parameters to determine the behavior
 // of a Psiphon dialer (TCPDial, MeekDial, etc.)
 type DialConfig struct {

+ 3 - 0
psiphon/serverApi.go

@@ -214,6 +214,9 @@ func (session *Session) doHandshakeRequest() error {
 
 	session.clientRegion = handshakeConfig.ClientRegion
 
+	// ***TEMP***
+	session.clientRegion = "CA"
+
 	// Store discovered server entries
 	for _, encodedServerEntry := range handshakeConfig.EncodedServerList {
 		serverEntry, err := DecodeServerEntry(encodedServerEntry)

+ 124 - 5
psiphon/splitTunnel.go

@@ -20,14 +20,18 @@
 package psiphon
 
 import (
+	"bufio"
 	"bytes"
 	"compress/zlib"
 	"encoding/base64"
+	"encoding/binary"
 	"errors"
 	"fmt"
 	"io/ioutil"
 	"net"
 	"net/http"
+	"sort"
+	"strings"
 	"sync"
 	"time"
 )
@@ -72,6 +76,7 @@ type SplitTunnelClassifier struct {
 	fetchRoutesWaitGroup     *sync.WaitGroup
 	isRoutesSet              bool
 	cache                    map[string]*classification
+	routes                   networkList
 }
 
 type classification struct {
@@ -114,6 +119,7 @@ func (classifier *SplitTunnelClassifier) Start(fetchRoutesTunnel *Tunnel) {
 		return
 	}
 
+	classifier.fetchRoutesWaitGroup.Add(1)
 	go classifier.setRoutes(fetchRoutesTunnel)
 }
 
@@ -168,6 +174,10 @@ func (classifier *SplitTunnelClassifier) IsUntunneled(targetAddress string) bool
 	classifier.cache[targetAddress] = &classification{isUntunneled, expiry}
 	classifier.mutex.Unlock()
 
+	if isUntunneled {
+		NoticeUntunneledClassification(targetAddress)
+	}
+
 	return isUntunneled
 }
 
@@ -264,7 +274,7 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
 
 	var compressedRoutesData []byte
 	if !useCachedRoutes {
-		routesData, err = base64.StdEncoding.DecodeString(encodedRoutesData)
+		compressedRoutesData, err = base64.StdEncoding.DecodeString(encodedRoutesData)
 		if err != nil {
 			NoticeAlert("failed to decode split tunnel routes: %s", ContextError(err))
 			useCachedRoutes = true
@@ -300,6 +310,9 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
 		if err != nil {
 			return nil, ContextError(err)
 		}
+		if routesData == nil {
+			return nil, ContextError(errors.New("no cached routes"))
+		}
 	}
 
 	return routesData, nil
@@ -319,7 +332,10 @@ func (classifier *SplitTunnelClassifier) installRoutes(routesData []byte) (err e
 	classifier.mutex.Lock()
 	defer classifier.mutex.Unlock()
 
-	// ***TODO***: implementation
+	classifier.routes, err = NewNetworkList(routesData)
+	if err != nil {
+		return ContextError(err)
+	}
 
 	classifier.isRoutesSet = true
 
@@ -331,9 +347,111 @@ func (classifier *SplitTunnelClassifier) ipAddressInRoutes(ipAddr net.IP) bool {
 	classifier.mutex.RLock()
 	defer classifier.mutex.RUnlock()
 
-	// ***TODO***: implementation
+	return classifier.routes.ContainsIpAddress(ipAddr)
+}
+
+// networkList is a sorted list of network ranges. It's used to
+// lookup candidate IP addresses for split tunnel classification.
+// networkList implements Sort.Interface.
+type networkList []net.IPNet
+
+// NewNetworkList parses text routes data and produces a networkList
+// for fast ContainsIpAddress lookup.
+// The input format is expected to be text lines where each line
+// is, e.g., "1.2.3.0\t255.255.255.0\n"
+func NewNetworkList(routesData []byte) (networkList, error) {
+
+	// Parse text routes data
+	var list networkList
+	scanner := bufio.NewScanner(bytes.NewReader(routesData))
+	scanner.Split(bufio.ScanLines)
+	for scanner.Scan() {
+		s := strings.Split(scanner.Text(), "\t")
+		if len(s) != 2 {
+			continue
+		}
+
+		ip := parseIPv4(s[0])
+		mask := parseIPv4Mask(s[1])
+		if ip == nil || mask == nil {
+			continue
+		}
+
+		list = append(list, net.IPNet{IP: ip.Mask(mask), Mask: mask})
+	}
+	if len(list) == 0 {
+		return nil, ContextError(errors.New("Routes data contains no networks"))
+	}
+
+	// Sort data for fast lookup
+	sort.Sort(list)
+
+	return list, nil
+}
+
+func parseIPv4(s string) net.IP {
+	ip := net.ParseIP(s)
+	if ip == nil {
+		return nil
+	}
+	return ip.To4()
+}
+
+func parseIPv4Mask(s string) net.IPMask {
+	ip := parseIPv4(s)
+	if ip == nil {
+		return nil
+	}
+	mask := net.IPMask(ip)
+	if bits, size := mask.Size(); bits == 0 || size == 0 {
+		return nil
+	}
+	return mask
+}
+
+// Len implementes Sort.Interface
+func (list networkList) Len() int {
+	return len(list)
+}
+
+// Swap implementes Sort.Interface
+func (list networkList) Swap(i, j int) {
+	list[i], list[j] = list[j], list[i]
+}
+
+// Less implementes Sort.Interface
+func (list networkList) Less(i, j int) bool {
+	return binary.BigEndian.Uint32(list[i].IP) < binary.BigEndian.Uint32(list[j].IP)
+}
 
-	return false
+// ContainsIpAddress performs a binary search on the networkList to
+// find a network containing the candidate IP address.
+func (list networkList) ContainsIpAddress(addr net.IP) bool {
+
+	// Search criteria
+	//
+	// The following conditions are satisfied when address_IP is in the network:
+	// 1. address_IP ^ network_mask == network_IP ^ network_mask
+	// 2. address_IP >= network_IP.
+	// We are also assuming that network ranges do not overlap.
+	//
+	// For an ascending array of networks, the sort.Search returns the smallest
+	// index idx for which condition network_IP > address_IP is satisfied, so we
+	// are checking whether or not adrress_IP belongs to the network[idx-1].
+
+	// Edge conditions check
+	//
+	// idx == 0 means that address_IP is  lesser than the first (smallest) network_IP
+	// thus never satisfies search condition 2.
+	// idx == array_length means that address_IP is larger than the last (largest)
+	// network_IP so we need to check the last element for condition 1.
+
+	addrValue := binary.BigEndian.Uint32(addr.To4())
+	index := sort.Search(len(list), func(i int) bool {
+		networkValue := binary.BigEndian.Uint32(list[i].IP)
+		return networkValue > addrValue
+	})
+	return index > 0 && list[index-1].IP.Equal(addr.Mask(list[index-1].Mask))
 }
 
 // tunneledLookupIP resolves a split tunnel candidate hostname with a tunneled
@@ -357,7 +475,8 @@ func tunneledLookupIP(
 	// is tunneled (also ensures this code path isn't circular).
 	// Assumes tunnel dialer conn configures timeouts and interruptibility.
 
-	conn, err := dnsTunneler.Dial(dnsServerAddress, true, nil)
+	conn, err := dnsTunneler.Dial(fmt.Sprintf(
+		"%s:%d", dnsServerAddress, DNS_PORT), true, nil)
 	if err != nil {
 		return nil, 0, ContextError(err)
 	}