瀏覽代碼

Add domain and IPv6 support to blocklist

Rod Hynes 6 年之前
父節點
當前提交
d822d40b15
共有 4 個文件被更改,包括 171 次插入73 次删除
  1. 75 36
      psiphon/server/blocklist.go
  2. 54 26
      psiphon/server/blocklist_test.go
  3. 2 1
      psiphon/server/server_test.go
  4. 40 10
      psiphon/server/tunnelServer.go

+ 75 - 36
psiphon/server/blocklist.go

@@ -31,15 +31,15 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 )
 
-// Blocklist provides a fast lookup of IP addresses that are candidates for
-// egress blocking. This is intended to be used to block malware and other
-// malicious traffic.
+// Blocklist provides a fast lookup of IP addresses and domains that are
+// candidates for egress blocking. This is intended to be used to block
+// malware and other malicious traffic.
 //
 // The Reload function supports hot reloading of rules data while the server
 // is running.
 //
-// Limitations: currently supports only IPv4 addresses, and is implemented
-// with an in-memory Go map, which limits the practical size of the blocklist.
+// Limitations: the blocklist is implemented with in-memory Go maps, which
+// limits the practical size of the blocklist.
 type Blocklist struct {
 	common.ReloadableFile
 	loaded int32
@@ -54,7 +54,8 @@ type BlocklistTag struct {
 }
 
 type blocklistData struct {
-	lookup          map[[net.IPv4len]byte][]BlocklistTag
+	lookupIP        map[[net.IPv6len]byte][]BlocklistTag
+	lookupDomain    map[string][]BlocklistTag
 	internedStrings map[string]string
 }
 
@@ -94,27 +95,46 @@ func NewBlocklist(filename string) (*Blocklist, error) {
 	return blocklist, nil
 }
 
-// Lookup returns the blocklist tags for any IP address that is on the
+// LookupIP returns the blocklist tags for any IP address that is on the
 // blocklist, or returns nil for any IP address not on the blocklist. Lookup
-// may be called oncurrently. The caller must not modify the return value.
-func (b *Blocklist) Lookup(IPAddress net.IP) []BlocklistTag {
+// may be called concurrently. The caller must not modify the return value.
+func (b *Blocklist) LookupIP(IPAddress net.IP) []BlocklistTag {
 
 	// When not configured, no blocklist is loaded/initialized.
 	if atomic.LoadInt32(&b.loaded) != 1 {
 		return nil
 	}
 
-	var key [net.IPv4len]byte
-	IPv4Address := IPAddress.To4()
-	if IPv4Address == nil {
+	// IPAddress may be an IPv4 or IPv6 address. To16 will return the 16-byte
+	// representation of an IPv4 address, with the net.v4InV6Prefix prefix.
+
+	var key [net.IPv6len]byte
+	IPAddress16 := IPAddress.To16()
+	if IPAddress16 == nil {
 		return nil
 	}
-	copy(key[:], IPv4Address)
+	copy(key[:], IPAddress16)
 
 	// As data is an atomic.Value, it's not necessary to call
 	// ReloadableFile.RLock/ReloadableFile.RUnlock in this case.
 
-	tags, ok := b.data.Load().(*blocklistData).lookup[key]
+	tags, ok := b.data.Load().(*blocklistData).lookupIP[key]
+	if !ok {
+		return nil
+	}
+	return tags
+}
+
+// LookupDomain returns the blocklist tags for any domain that is on the
+// blocklist, or returns nil for any domain not on the blocklist. Lookup may
+// be called concurrently. The caller must not modify the return value.
+func (b *Blocklist) LookupDomain(domain string) []BlocklistTag {
+
+	if atomic.LoadInt32(&b.loaded) != 1 {
+		return nil
+	}
+
+	tags, ok := b.data.Load().(*blocklistData).lookupDomain[domain]
 	if !ok {
 		return nil
 	}
@@ -146,18 +166,6 @@ func loadBlocklistFromFile(filename string) (*blocklistData, error) {
 			return nil, errors.Trace(err)
 		}
 
-		IPAddress := net.ParseIP(record[0])
-		if IPAddress == nil {
-			return nil, errors.Tracef("invalid IP address: %s", record[0])
-		}
-		IPv4Address := IPAddress.To4()
-		if IPAddress == nil {
-			return nil, errors.Tracef("invalid IPv4 address: %s", record[0])
-		}
-
-		var key [net.IPv4len]byte
-		copy(key[:], IPv4Address)
-
 		// Intern the source and subject strings so we only store one copy of
 		// each in memory. These values are expected to repeat often.
 		source := data.internString(record[1])
@@ -168,18 +176,48 @@ func loadBlocklistFromFile(filename string) (*blocklistData, error) {
 			Subject: subject,
 		}
 
-		tags := data.lookup[key]
+		IPAddress := net.ParseIP(record[0])
+		if IPAddress != nil {
 
-		found := false
-		for _, existingTag := range tags {
-			if tag == existingTag {
-				found = true
-				break
+			IPAddress16 := IPAddress.To16()
+			if IPAddress16 == nil {
+				return nil, errors.Tracef("invalid IP address: %s", record[0])
 			}
-		}
 
-		if !found {
-			data.lookup[key] = append(tags, tag)
+			var key [net.IPv6len]byte
+			copy(key[:], IPAddress16)
+
+			tags := data.lookupIP[key]
+
+			found := false
+			for _, existingTag := range tags {
+				if tag == existingTag {
+					found = true
+					break
+				}
+			}
+
+			if !found {
+				data.lookupIP[key] = append(tags, tag)
+			}
+
+		} else {
+
+			key := record[0]
+
+			tags := data.lookupDomain[key]
+
+			found := false
+			for _, existingTag := range tags {
+				if tag == existingTag {
+					found = true
+					break
+				}
+			}
+
+			if !found {
+				data.lookupDomain[key] = append(tags, tag)
+			}
 		}
 	}
 
@@ -188,7 +226,8 @@ func loadBlocklistFromFile(filename string) (*blocklistData, error) {
 
 func newBlocklistData() *blocklistData {
 	return &blocklistData{
-		lookup:          make(map[[net.IPv4len]byte][]BlocklistTag),
+		lookupIP:        make(map[[net.IPv6len]byte][]BlocklistTag),
+		lookupDomain:    make(map[string][]BlocklistTag),
 		internedStrings: make(map[string]string),
 	}
 }

+ 54 - 26
psiphon/server/blocklist_test.go

@@ -42,8 +42,10 @@ func TestBlocklist(t *testing.T) {
 
 	filename := filepath.Join(testDataDirName, "blocklist")
 
-	hit := net.ParseIP("0.0.0.0")
-	miss := net.ParseIP("255.255.255.255")
+	hitIPv4 := net.ParseIP("0.0.0.0")
+	hitIPv6 := net.ParseIP("2001:db8:f75c::0951:58bc:ef22")
+	hitDomain := "example.org"
+	missIPv4 := net.ParseIP("255.255.255.255")
 	sources := []string{"source1", "source2", "source3", "source4", "source4"}
 	subjects := []string{"subject1", "subject2", "subject3", "subject4", "subject4"}
 	hitPresent := []int{0, 1}
@@ -60,18 +62,31 @@ func TestBlocklist(t *testing.T) {
 		if err != nil {
 			t.Fatalf("Fprintf failed: %s", err)
 		}
+		hitIPv4Index := -1
+		hitIPv6Index := -1
+		hitDomainIndex := -1
+		if common.ContainsInt(hitPresent, i) {
+			indices := prng.Perm(entriesPerSource)
+			hitIPv4Index = indices[0] - 1
+			hitIPv6Index = indices[1] - 1
+			hitDomainIndex = indices[2] - 1
+		}
 		for j := 0; j < entriesPerSource; j++ {
-			var IPAddress string
-			if j == entriesPerSource/2 && common.ContainsInt(hitPresent, i) {
-				IPAddress = hit.String()
+			var address string
+			if j == hitIPv4Index {
+				address = hitIPv4.String()
+			} else if j == hitIPv6Index {
+				address = hitIPv6.String()
+			} else if j == hitDomainIndex {
+				address = hitDomain
 			} else {
-				IPAddress = fmt.Sprintf(
+				address = fmt.Sprintf(
 					"%d.%d.%d.%d",
 					prng.Range(1, 254), prng.Range(1, 254),
 					prng.Range(1, 254), prng.Range(1, 254))
 			}
 			_, err := fmt.Fprintf(file, "%s,%s,%s\n",
-				IPAddress, sources[i], subjects[i])
+				address, sources[i], subjects[i])
 			if err != nil {
 				t.Fatalf("Fprintf failed: %s", err)
 			}
@@ -85,7 +100,36 @@ func TestBlocklist(t *testing.T) {
 		t.Fatalf("NewBlocklist failed: %s", err)
 	}
 
-	tags := b.Lookup(hit)
+	for _, hitIP := range []net.IP{hitIPv4, hitIPv6} {
+
+		tags := b.LookupIP(hitIP)
+
+		if tags == nil {
+			t.Fatalf("unexpected miss")
+		}
+
+		if len(tags) != len(hitPresent) {
+			t.Fatalf("unexpected hit tag count")
+		}
+
+		for _, tag := range tags {
+			sourceFound := false
+			subjectFound := false
+			for _, i := range hitPresent {
+				if tag.Source == sources[i] {
+					sourceFound = true
+				}
+				if tag.Subject == subjects[i] {
+					subjectFound = true
+				}
+			}
+			if !sourceFound || !subjectFound {
+				t.Fatalf("unexpected hit tag")
+			}
+		}
+	}
+
+	tags := b.LookupDomain(hitDomain)
 
 	if tags == nil {
 		t.Fatalf("unexpected miss")
@@ -95,23 +139,7 @@ func TestBlocklist(t *testing.T) {
 		t.Fatalf("unexpected hit tag count")
 	}
 
-	for _, tag := range tags {
-		sourceFound := false
-		subjectFound := false
-		for _, i := range hitPresent {
-			if tag.Source == sources[i] {
-				sourceFound = true
-			}
-			if tag.Subject == subjects[i] {
-				subjectFound = true
-			}
-		}
-		if !sourceFound || !subjectFound {
-			t.Fatalf("unexpected hit tag")
-		}
-	}
-
-	if b.Lookup(miss) != nil {
+	if b.LookupIP(missIPv4) != nil {
 		t.Fatalf("unexpected hit")
 	}
 
@@ -131,7 +159,7 @@ func TestBlocklist(t *testing.T) {
 	start := time.Now()
 
 	for i := 0; i < numIterations; i++ {
-		_ = b.Lookup(lookups[i%numLookups])
+		_ = b.LookupIP(lookups[i%numLookups])
 	}
 
 	t.Logf(

+ 2 - 1
psiphon/server/server_test.go

@@ -1874,7 +1874,8 @@ func paveTacticsConfigFile(
 
 func paveBlocklistFile(t *testing.T, blocklistFilename string) {
 
-	blocklistContent := "255.255.255.255,test-source,test-subject\n"
+	blocklistContent :=
+		"255.255.255.255,test-source,test-subject\n2001:db8:f75c::0951:58bc:ef22,test-source,test-subject\nexample.org,test-source,test-subject\n"
 
 	err := ioutil.WriteFile(blocklistFilename, []byte(blocklistContent), 0600)
 	if err != nil {

+ 40 - 10
psiphon/server/tunnelServer.go

@@ -2271,7 +2271,7 @@ var blocklistHitsStatParams = []requestParamSpec{
 	{"last_connected", isLastConnected, requestParamOptional},
 }
 
-func (sshClient *sshClient) logBlocklistHits(remoteIP net.IP, tags []BlocklistTag) {
+func (sshClient *sshClient) logBlocklistHits(IP net.IP, domain string, tags []BlocklistTag) {
 
 	sshClient.Lock()
 
@@ -2289,7 +2289,12 @@ func (sshClient *sshClient) logBlocklistHits(remoteIP net.IP, tags []BlocklistTa
 	sshClient.Unlock()
 
 	for _, tag := range tags {
-		logFields["blocklist_ip_address"] = remoteIP.String()
+		if IP != nil {
+			logFields["blocklist_ip_address"] = IP.String()
+		}
+		if domain != "" {
+			logFields["blocklist_domain"] = domain
+		}
 		logFields["blocklist_source"] = tag.Source
 		logFields["blocklist_subject"] = tag.Subject
 
@@ -2752,9 +2757,9 @@ func (sshClient *sshClient) isPortForwardPermitted(
 	// cases, a blocklist entry won't be dialed in any case. However, no logs
 	// will be recorded.
 
-	tags := sshClient.sshServer.support.Blocklist.Lookup(remoteIP)
+	tags := sshClient.sshServer.support.Blocklist.LookupIP(remoteIP)
 	if len(tags) > 0 {
-		sshClient.logBlocklistHits(remoteIP, tags)
+		sshClient.logBlocklistHits(remoteIP, "", tags)
 		if sshClient.sshServer.support.Config.BlocklistActive {
 			return false
 		}
@@ -3052,13 +3057,40 @@ func (sshClient *sshClient) handleTCPChannel(
 		}
 	}
 
+	// Check the domain blocklist before dialing.
+	//
+	// The IP blocklist is checked in isPortForwardPermitted, which also provides
+	// IP blocklist checking for the packet tunnel code path. When hostToConnect
+	// is an IP address, the following hostname resolution step effectively
+	// performs no actions and next immediate step is the isPortForwardPermitted
+	// check.
+	//
+	// Limitation: at this time, only clients that send domains in hostToConnect
+	// are subject to domain blocklist checks. Both the udpgw and packet tunnel
+	// modes perform tunneled DNS and send only IPs in hostToConnect.
+
+	if !isWebServerPortForward &&
+		net.ParseIP(hostToConnect) == nil {
+
+		tags := sshClient.sshServer.support.Blocklist.LookupDomain(hostToConnect)
+		if len(tags) > 0 {
+			sshClient.logBlocklistHits(nil, hostToConnect, tags)
+			if sshClient.sshServer.support.Config.BlocklistActive {
+				// Note: not recording a port forward failure in this case
+				sshClient.rejectNewChannel(newChannel, "port forward not permitted")
+				return
+			}
+		}
+	}
+
 	// Dial the remote address.
 	//
-	// Hostname resolution is performed explicitly, as a separate step, as the target IP
-	// address is used for traffic rules (AllowSubnets) and OSL seed progress.
+	// Hostname resolution is performed explicitly, as a separate step, as the
+	// target IP address is used for traffic rules (AllowSubnets), OSL seed
+	// progress, and IP address blocklists.
 	//
-	// Contexts are used for cancellation (via sshClient.runCtx, which is cancelled
-	// when the client is stopping) and timeouts.
+	// Contexts are used for cancellation (via sshClient.runCtx, which is
+	// cancelled when the client is stopping) and timeouts.
 
 	dialStartTime := time.Now()
 
@@ -3113,9 +3145,7 @@ func (sshClient *sshClient) handleTCPChannel(
 			portForwardTypeTCP,
 			IP,
 			portToConnect) {
-
 		// Note: not recording a port forward failure in this case
-
 		sshClient.rejectNewChannel(newChannel, "port forward not permitted")
 		return
 	}