|
|
@@ -20,14 +20,18 @@
|
|
|
package psiphon
|
|
|
|
|
|
import (
|
|
|
+ "bufio"
|
|
|
"bytes"
|
|
|
"compress/zlib"
|
|
|
"encoding/base64"
|
|
|
+ "encoding/binary"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io/ioutil"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
+ "sort"
|
|
|
+ "strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
)
|
|
|
@@ -72,6 +76,7 @@ type SplitTunnelClassifier struct {
|
|
|
fetchRoutesWaitGroup *sync.WaitGroup
|
|
|
isRoutesSet bool
|
|
|
cache map[string]*classification
|
|
|
+ routes networkList
|
|
|
}
|
|
|
|
|
|
type classification struct {
|
|
|
@@ -114,6 +119,7 @@ func (classifier *SplitTunnelClassifier) Start(fetchRoutesTunnel *Tunnel) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ classifier.fetchRoutesWaitGroup.Add(1)
|
|
|
go classifier.setRoutes(fetchRoutesTunnel)
|
|
|
}
|
|
|
|
|
|
@@ -168,6 +174,10 @@ func (classifier *SplitTunnelClassifier) IsUntunneled(targetAddress string) bool
|
|
|
classifier.cache[targetAddress] = &classification{isUntunneled, expiry}
|
|
|
classifier.mutex.Unlock()
|
|
|
|
|
|
+ if isUntunneled {
|
|
|
+ NoticeUntunneledClassification(targetAddress)
|
|
|
+ }
|
|
|
+
|
|
|
return isUntunneled
|
|
|
}
|
|
|
|
|
|
@@ -264,7 +274,7 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
|
|
|
|
|
|
var compressedRoutesData []byte
|
|
|
if !useCachedRoutes {
|
|
|
- routesData, err = base64.StdEncoding.DecodeString(encodedRoutesData)
|
|
|
+ compressedRoutesData, err = base64.StdEncoding.DecodeString(encodedRoutesData)
|
|
|
if err != nil {
|
|
|
NoticeAlert("failed to decode split tunnel routes: %s", ContextError(err))
|
|
|
useCachedRoutes = true
|
|
|
@@ -300,6 +310,9 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
|
|
|
if err != nil {
|
|
|
return nil, ContextError(err)
|
|
|
}
|
|
|
+ if routesData == nil {
|
|
|
+ return nil, ContextError(errors.New("no cached routes"))
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
return routesData, nil
|
|
|
@@ -319,7 +332,10 @@ func (classifier *SplitTunnelClassifier) installRoutes(routesData []byte) (err e
|
|
|
classifier.mutex.Lock()
|
|
|
defer classifier.mutex.Unlock()
|
|
|
|
|
|
- // ***TODO***: implementation
|
|
|
+ classifier.routes, err = NewNetworkList(routesData)
|
|
|
+ if err != nil {
|
|
|
+ return ContextError(err)
|
|
|
+ }
|
|
|
|
|
|
classifier.isRoutesSet = true
|
|
|
|
|
|
@@ -331,9 +347,111 @@ func (classifier *SplitTunnelClassifier) ipAddressInRoutes(ipAddr net.IP) bool {
|
|
|
classifier.mutex.RLock()
|
|
|
defer classifier.mutex.RUnlock()
|
|
|
|
|
|
- // ***TODO***: implementation
|
|
|
+ return classifier.routes.ContainsIpAddress(ipAddr)
|
|
|
+}
|
|
|
+
|
|
|
+// networkList is a sorted list of network ranges. It's used to
|
|
|
+// lookup candidate IP addresses for split tunnel classification.
|
|
|
+// networkList implements Sort.Interface.
|
|
|
+type networkList []net.IPNet
|
|
|
+
|
|
|
+// NewNetworkList parses text routes data and produces a networkList
|
|
|
+// for fast ContainsIpAddress lookup.
|
|
|
+// The input format is expected to be text lines where each line
|
|
|
+// is, e.g., "1.2.3.0\t255.255.255.0\n"
|
|
|
+func NewNetworkList(routesData []byte) (networkList, error) {
|
|
|
+
|
|
|
+ // Parse text routes data
|
|
|
+ var list networkList
|
|
|
+ scanner := bufio.NewScanner(bytes.NewReader(routesData))
|
|
|
+ scanner.Split(bufio.ScanLines)
|
|
|
+ for scanner.Scan() {
|
|
|
+ s := strings.Split(scanner.Text(), "\t")
|
|
|
+ if len(s) != 2 {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ ip := parseIPv4(s[0])
|
|
|
+ mask := parseIPv4Mask(s[1])
|
|
|
+ if ip == nil || mask == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ list = append(list, net.IPNet{IP: ip.Mask(mask), Mask: mask})
|
|
|
+ }
|
|
|
+ if len(list) == 0 {
|
|
|
+ return nil, ContextError(errors.New("Routes data contains no networks"))
|
|
|
+ }
|
|
|
+
|
|
|
+ // Sort data for fast lookup
|
|
|
+ sort.Sort(list)
|
|
|
+
|
|
|
+ return list, nil
|
|
|
+}
|
|
|
+
|
|
|
+func parseIPv4(s string) net.IP {
|
|
|
+ ip := net.ParseIP(s)
|
|
|
+ if ip == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return ip.To4()
|
|
|
+}
|
|
|
+
|
|
|
+func parseIPv4Mask(s string) net.IPMask {
|
|
|
+ ip := parseIPv4(s)
|
|
|
+ if ip == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ mask := net.IPMask(ip)
|
|
|
+ if bits, size := mask.Size(); bits == 0 || size == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return mask
|
|
|
+}
|
|
|
+
|
|
|
+// Len implementes Sort.Interface
|
|
|
+func (list networkList) Len() int {
|
|
|
+ return len(list)
|
|
|
+}
|
|
|
+
|
|
|
+// Swap implementes Sort.Interface
|
|
|
+func (list networkList) Swap(i, j int) {
|
|
|
+ list[i], list[j] = list[j], list[i]
|
|
|
+}
|
|
|
+
|
|
|
+// Less implementes Sort.Interface
|
|
|
+func (list networkList) Less(i, j int) bool {
|
|
|
+ return binary.BigEndian.Uint32(list[i].IP) < binary.BigEndian.Uint32(list[j].IP)
|
|
|
+}
|
|
|
|
|
|
- return false
|
|
|
+// ContainsIpAddress performs a binary search on the networkList to
|
|
|
+// find a network containing the candidate IP address.
|
|
|
+func (list networkList) ContainsIpAddress(addr net.IP) bool {
|
|
|
+
|
|
|
+ // Search criteria
|
|
|
+ //
|
|
|
+ // The following conditions are satisfied when address_IP is in the network:
|
|
|
+ // 1. address_IP ^ network_mask == network_IP ^ network_mask
|
|
|
+ // 2. address_IP >= network_IP.
|
|
|
+ // We are also assuming that network ranges do not overlap.
|
|
|
+ //
|
|
|
+ // For an ascending array of networks, the sort.Search returns the smallest
|
|
|
+ // index idx for which condition network_IP > address_IP is satisfied, so we
|
|
|
+ // are checking whether or not adrress_IP belongs to the network[idx-1].
|
|
|
+
|
|
|
+ // Edge conditions check
|
|
|
+ //
|
|
|
+ // idx == 0 means that address_IP is lesser than the first (smallest) network_IP
|
|
|
+ // thus never satisfies search condition 2.
|
|
|
+ // idx == array_length means that address_IP is larger than the last (largest)
|
|
|
+ // network_IP so we need to check the last element for condition 1.
|
|
|
+
|
|
|
+ addrValue := binary.BigEndian.Uint32(addr.To4())
|
|
|
+ index := sort.Search(len(list), func(i int) bool {
|
|
|
+ networkValue := binary.BigEndian.Uint32(list[i].IP)
|
|
|
+ return networkValue > addrValue
|
|
|
+ })
|
|
|
+ return index > 0 && list[index-1].IP.Equal(addr.Mask(list[index-1].Mask))
|
|
|
}
|
|
|
|
|
|
// tunneledLookupIP resolves a split tunnel candidate hostname with a tunneled
|
|
|
@@ -357,7 +475,8 @@ func tunneledLookupIP(
|
|
|
// is tunneled (also ensures this code path isn't circular).
|
|
|
// Assumes tunnel dialer conn configures timeouts and interruptibility.
|
|
|
|
|
|
- conn, err := dnsTunneler.Dial(dnsServerAddress, true, nil)
|
|
|
+ conn, err := dnsTunneler.Dial(fmt.Sprintf(
|
|
|
+ "%s:%d", dnsServerAddress, DNS_PORT), true, nil)
|
|
|
if err != nil {
|
|
|
return nil, 0, ContextError(err)
|
|
|
}
|