Jelajahi Sumber

Additional traffic rule checks

- Extend domain block list checking to packet tunnel and
  udpgw modes (DNS over UDP only, for now).

- Add full bogon check to packet tunnel mode.

- DNS packet validation means QUIC-OSSH-over-Psiphon will
  no longer be transparently forwarded to DNS servers.
Rod Hynes 5 tahun lalu
induk
melakukan
8b01d51ea9

+ 31 - 0
psiphon/common/net.go

@@ -29,8 +29,10 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/Psiphon-Labs/dns"
 	"github.com/Psiphon-Labs/goarista/monotime"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/wader/filtertransport"
 )
 
 // NetDialer mimicks the net.Dialer interface.
@@ -386,3 +388,32 @@ func (conn *ActivityMonitoredConn) IsClosed() bool {
 	}
 	return closer.IsClosed()
 }
+
+// IsBogon checks if the specified IP is a bogon (loopback, private addresses,
+// link-local addresses, etc.)
+func IsBogon(IP net.IP) bool {
+	return filtertransport.FindIPNet(
+		filtertransport.DefaultFilteredNetworks, IP)
+}
+
+// ParseDNSQuestion parses a DNS message. When the message is a query,
+// the first question, a fully-qualified domain name, is returned.
+//
+// For other valid DNS messages, "" is returned. An error is returned only
+// for invalid DNS messages.
+//
+// Limitations:
+// - Only the first Question field is extracted.
+// - ParseDNSQuestion only functions for plaintext DNS and cannot
+//   extract domains from DNS-over-TLS/HTTPS, etc.
+func ParseDNSQuestion(request []byte) (string, error) {
+	m := new(dns.Msg)
+	err := m.Unpack(request)
+	if err != nil {
+		return "", errors.Trace(err)
+	}
+	if len(m.Question) > 0 {
+		return m.Question[0].Name, nil
+	}
+	return "", nil
+}

+ 67 - 0
psiphon/common/net_test.go

@@ -26,6 +26,7 @@ import (
 	"testing/iotest"
 	"time"
 
+	"github.com/Psiphon-Labs/dns"
 	"github.com/Psiphon-Labs/goarista/monotime"
 )
 
@@ -273,3 +274,69 @@ func TestLRUConns(t *testing.T) {
 		t.Fatalf("unexpected IsClosed state")
 	}
 }
+
+func TestIsBogon(t *testing.T) {
+	if IsBogon(net.ParseIP("8.8.8.8")) {
+		t.Errorf("unexpected bogon")
+	}
+	if !IsBogon(net.ParseIP("127.0.0.1")) {
+		t.Errorf("unexpected non-bogon")
+	}
+	if !IsBogon(net.ParseIP("192.168.0.1")) {
+		t.Errorf("unexpected non-bogon")
+	}
+	if !IsBogon(net.ParseIP("::1")) {
+		t.Errorf("unexpected non-bogon")
+	}
+	if !IsBogon(net.ParseIP("fc00::")) {
+		t.Errorf("unexpected non-bogon")
+	}
+}
+
+func BenchmarkIsBogon(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		IsBogon(net.ParseIP("8.8.8.8"))
+	}
+}
+
+func makeDNSQuery(domain string) ([]byte, error) {
+	query := new(dns.Msg)
+	query.SetQuestion(domain, dns.TypeA)
+	query.RecursionDesired = true
+	msg, err := query.Pack()
+	if err != nil {
+		return nil, err
+	}
+	return msg, nil
+}
+
+func TestParseDNSQuestion(t *testing.T) {
+
+	domain := dns.Fqdn("www.example.com")
+	msg, err := makeDNSQuery(domain)
+	if err != nil {
+		t.Fatalf("makeDNSQuery failed: %s", err)
+	}
+
+	checkDomain, err := ParseDNSQuestion(msg)
+	if err != nil {
+		t.Fatalf("ParseDNSQuestion failed: %s", err)
+	}
+
+	if checkDomain != domain {
+		t.Fatalf("unexpected domain")
+	}
+}
+
+func BenchmarkParseDNSQuestion(b *testing.B) {
+
+	domain := dns.Fqdn("www.example.com")
+	msg, err := makeDNSQuery(domain)
+	if err != nil {
+		b.Fatalf("makeDNSQuery failed: %s", err)
+	}
+
+	for i := 0; i < b.N; i++ {
+		ParseDNSQuestion(msg)
+	}
+}

+ 66 - 9
psiphon/common/tun/tun.go

@@ -238,6 +238,10 @@ type ServerConfig struct {
 	// SessionIdleExpirySeconds is also, effectively, the lease
 	// time for assigned IP addresses.
 	SessionIdleExpirySeconds int
+
+	// AllowBogons disables bogon checks. This should be used only
+	// for testing.
+	AllowBogons bool
 }
 
 // Server is a packet tunnel server. A packet tunnel server
@@ -330,6 +334,10 @@ func (server *Server) Stop() {
 // and/or port.
 type AllowedPortChecker func(upstreamIPAddress net.IP, port int) bool
 
+// AllowedDomainChecker is a function which returns true when it is
+// permitted to resolve the specified domain name.
+type AllowedDomainChecker func(string) bool
+
 // FlowActivityUpdater defines an interface for receiving updates for
 // flow activity. Values passed to UpdateProgress are bytes transferred
 // and flow duration since the previous UpdateProgress.
@@ -357,9 +365,11 @@ type MetricsUpdater func(
 // transport provides the channel for relaying packets to and from
 // the client.
 //
-// checkAllowedTCPPortFunc/checkAllowedUDPPortFunc are callbacks used
-// to enforce traffic rules. For each TCP/UDP packet, the corresponding
-// function is called to check if traffic to the packet's port is
+// checkAllowedTCPPortFunc/checkAllowedUDPPortFunc/checkAllowedDomainFunc
+// are callbacks used to enforce traffic rules. For each TCP/UDP flow, the
+// corresponding AllowedPort function is called to check if traffic to the
+// packet's port is permitted. For upstream DNS query packets,
+// checkAllowedDomainFunc is called to check if domain resolution is
 // permitted. These callbacks must be efficient and safe for concurrent
 // calls.
 //
@@ -382,6 +392,7 @@ func (server *Server) ClientConnected(
 	sessionID string,
 	transport io.ReadWriteCloser,
 	checkAllowedTCPPortFunc, checkAllowedUDPPortFunc AllowedPortChecker,
+	checkAllowedDomainFunc AllowedDomainChecker,
 	flowActivityUpdaterMaker FlowActivityUpdaterMaker,
 	metricsUpdater MetricsUpdater) error {
 
@@ -437,6 +448,7 @@ func (server *Server) ClientConnected(
 		}
 
 		clientSession = &session{
+			allowBogons:              server.config.AllowBogons,
 			lastActivity:             int64(monotime.Now()),
 			sessionID:                sessionID,
 			metrics:                  new(packetMetrics),
@@ -465,6 +477,7 @@ func (server *Server) ClientConnected(
 		NewChannel(transport, MTU),
 		checkAllowedTCPPortFunc,
 		checkAllowedUDPPortFunc,
+		checkAllowedDomainFunc,
 		flowActivityUpdaterMaker,
 		metricsUpdater)
 
@@ -503,6 +516,7 @@ func (server *Server) resumeSession(
 	session *session,
 	channel *Channel,
 	checkAllowedTCPPortFunc, checkAllowedUDPPortFunc AllowedPortChecker,
+	checkAllowedDomainFunc AllowedDomainChecker,
 	flowActivityUpdaterMaker FlowActivityUpdaterMaker,
 	metricsUpdater MetricsUpdater) {
 
@@ -540,6 +554,8 @@ func (server *Server) resumeSession(
 
 	session.setCheckAllowedUDPPortFunc(&checkAllowedUDPPortFunc)
 
+	session.setCheckAllowedDomainFunc(&checkAllowedDomainFunc)
+
 	session.setFlowActivityUpdaterMaker(&flowActivityUpdaterMaker)
 
 	session.setMetricsUpdater(&metricsUpdater)
@@ -616,6 +632,8 @@ func (server *Server) interruptSession(session *session) {
 
 	session.setCheckAllowedUDPPortFunc(nil)
 
+	session.setCheckAllowedDomainFunc(nil)
+
 	session.setFlowActivityUpdaterMaker(nil)
 
 	session.setMetricsUpdater(nil)
@@ -948,7 +966,7 @@ func (server *Server) allocateIndex(newSession *session) error {
 		//   address (10.255.255.255) respectively
 		// - 1 is reserved as the server tun device address,
 		//   (10.0.0.1, and IPv6 equivalent)
-		// - 2 is reserver as the transparent DNS target
+		// - 2 is reserved as the transparent DNS target
 		//   address (10.0.0.2, and IPv6 equivalent)
 
 		if index <= 2 {
@@ -1055,10 +1073,12 @@ type session struct {
 	lastFlowReapIndex        int64
 	checkAllowedTCPPortFunc  unsafe.Pointer
 	checkAllowedUDPPortFunc  unsafe.Pointer
+	checkAllowedDomainFunc   unsafe.Pointer
 	flowActivityUpdaterMaker unsafe.Pointer
 	metricsUpdater           unsafe.Pointer
 	downstreamPackets        unsafe.Pointer
 
+	allowBogons              bool
 	metrics                  *packetMetrics
 	sessionID                string
 	index                    int32
@@ -1142,6 +1162,18 @@ func (session *session) getCheckAllowedUDPPortFunc() AllowedPortChecker {
 	return *p
 }
 
+func (session *session) setCheckAllowedDomainFunc(p *AllowedDomainChecker) {
+	atomic.StorePointer(&session.checkAllowedDomainFunc, unsafe.Pointer(p))
+}
+
+func (session *session) getCheckAllowedDomainFunc() AllowedDomainChecker {
+	p := (*AllowedDomainChecker)(atomic.LoadPointer(&session.checkAllowedDomainFunc))
+	if p == nil {
+		return nil
+	}
+	return *p
+}
+
 func (session *session) setFlowActivityUpdaterMaker(p *FlowActivityUpdaterMaker) {
 	atomic.StorePointer(&session.flowActivityUpdaterMaker, unsafe.Pointer(p))
 }
@@ -1995,9 +2027,11 @@ const (
 	packetRejectUDPPort            = 9
 	packetRejectNoOriginalAddress  = 10
 	packetRejectNoDNSResolvers     = 11
-	packetRejectNoClient           = 12
-	packetRejectReasonCount        = 13
-	packetOk                       = 13
+	packetRejectInvalidDNSMessage  = 12
+	packetRejectDisallowedDomain   = 13
+	packetRejectNoClient           = 14
+	packetRejectReasonCount        = 15
+	packetOk                       = 15
 )
 
 type packetDirection int
@@ -2034,6 +2068,10 @@ func packetRejectReasonDescription(reason packetRejectReason) string {
 		return "no_original_address"
 	case packetRejectNoDNSResolvers:
 		return "no_dns_resolvers"
+	case packetRejectInvalidDNSMessage:
+		return "invalid_dns_message"
+	case packetRejectDisallowedDomain:
+		return "disallowed_domain"
 	case packetRejectNoClient:
 		return "no_client"
 	}
@@ -2281,6 +2319,25 @@ func processPacket(
 						return false
 					}
 				}
+
+				// Limitation: checkAllowedDomainFunc is applied only to DNS queries in
+				// UDP; currently DNS-over-TCP will bypass the domain block list check.
+
+				if protocol == internetProtocolUDP {
+
+					domain, err := common.ParseDNSQuestion(applicationData)
+					if err != nil {
+						metrics.rejectedPacket(direction, packetRejectInvalidDNSMessage)
+						return false
+					}
+					if domain != "" {
+						checkAllowedDomainFunc := session.getCheckAllowedDomainFunc()
+						if !checkAllowedDomainFunc(domain) {
+							metrics.rejectedPacket(direction, packetRejectDisallowedDomain)
+							return false
+						}
+					}
+				}
 			}
 
 		} else { // packetDirectionServerDownstream
@@ -2395,9 +2452,9 @@ func processPacket(
 
 		// Enforce no localhost, multicast or broadcast packets; and
 		// no client-to-client packets.
+		if (!session.allowBogons && common.IsBogon(destinationIPAddress)) ||
 
-		if !destinationIPAddress.IsGlobalUnicast() ||
-
+			// The following are disallowed even when other bogons are allowed.
 			(direction == packetDirectionServerUpstream &&
 				((version == 4 &&
 					!destinationIPAddress.Equal(transparentDNSResolverIPv4Address) &&

+ 2 - 0
psiphon/common/tun/tun_test.go

@@ -390,12 +390,14 @@ func (server *testServer) run() {
 			sessionID := prng.HexString(SESSION_ID_LENGTH)
 
 			checkAllowedPortFunc := func(net.IP, int) bool { return true }
+			checkAllowedDomainFunc := func(string) bool { return true }
 
 			server.tunServer.ClientConnected(
 				sessionID,
 				signalConn,
 				checkAllowedPortFunc,
 				checkAllowedPortFunc,
+				checkAllowedDomainFunc,
 				server.updaterMaker,
 				server.metricsUpdater)
 

+ 6 - 0
psiphon/server/blocklist.go

@@ -135,6 +135,12 @@ func (b *Blocklist) LookupDomain(domain string) []BlocklistTag {
 		return nil
 	}
 
+	// Domains parsed out of DNS queries will be fully-qualified domain names,
+	// while list entries do not end in a dot.
+	if len(domain) > 0 && domain[len(domain)-1] == '.' {
+		domain = domain[:len(domain)-1]
+	}
+
 	tags, ok := b.data.Load().(*blocklistData).lookupDomain[domain]
 	if !ok {
 		return nil

+ 0 - 24
psiphon/server/server_test.go

@@ -2345,27 +2345,3 @@ func (v verifyTestCasesStoredLookup) checkStored(t *testing.T, errMessage string
 		t.Fatalf("%s: %+v", errMessage, v)
 	}
 }
-
-func TestIsBogon(t *testing.T) {
-	if IsBogon(net.ParseIP("8.8.8.8")) {
-		t.Errorf("unexpected bogon")
-	}
-	if !IsBogon(net.ParseIP("127.0.0.1")) {
-		t.Errorf("unexpected non-bogon")
-	}
-	if !IsBogon(net.ParseIP("192.168.0.1")) {
-		t.Errorf("unexpected non-bogon")
-	}
-	if !IsBogon(net.ParseIP("::1")) {
-		t.Errorf("unexpected non-bogon")
-	}
-	if !IsBogon(net.ParseIP("fc00::")) {
-		t.Errorf("unexpected non-bogon")
-	}
-}
-
-func BenchmarkIsBogon(b *testing.B) {
-	for i := 0; i < b.N; i++ {
-		IsBogon(net.ParseIP("8.8.8.8"))
-	}
-}

+ 1 - 0
psiphon/server/services.go

@@ -94,6 +94,7 @@ func RunServices(configJSON []byte) error {
 			EgressInterface:             config.PacketTunnelEgressInterface,
 			DownstreamPacketQueueSize:   config.PacketTunnelDownstreamPacketQueueSize,
 			SessionIdleExpirySeconds:    config.PacketTunnelSessionIdleExpirySeconds,
+			AllowBogons:                 config.AllowBogons,
 		})
 		if err != nil {
 			log.WithTraceFields(LogFields{"error": err}).Error("init packet tunnel failed")

+ 42 - 25
psiphon/server/tunnelServer.go

@@ -2084,6 +2084,11 @@ func (sshClient *sshClient) handleNewPacketTunnelChannel(
 		return sshClient.isPortForwardPermitted(portForwardTypeUDP, upstreamIPAddress, port)
 	}
 
+	checkAllowedDomainFunc := func(domain string) bool {
+		ok, _ := sshClient.isDomainPermitted(domain)
+		return ok
+	}
+
 	flowActivityUpdaterMaker := func(
 		upstreamHostname string, upstreamIPAddress net.IP) []tun.FlowActivityUpdater {
 
@@ -2112,6 +2117,7 @@ func (sshClient *sshClient) handleNewPacketTunnelChannel(
 		packetTunnelChannel,
 		checkAllowedTCPPortFunc,
 		checkAllowedUDPPortFunc,
+		checkAllowedDomainFunc,
 		flowActivityUpdaterMaker,
 		metricUpdater)
 	if err != nil {
@@ -2887,7 +2893,7 @@ func (sshClient *sshClient) isPortForwardPermitted(
 	// This check also avoids spurious disallowed traffic alerts for destinations
 	// that are impossible to reach.
 
-	if !sshClient.sshServer.support.Config.AllowBogons && IsBogon(remoteIP) {
+	if !sshClient.sshServer.support.Config.AllowBogons && common.IsBogon(remoteIP) {
 		return false
 	}
 
@@ -2960,6 +2966,34 @@ func (sshClient *sshClient) isPortForwardPermitted(
 	return false
 }
 
+// isDomainPermitted returns true when the specified domain may be resolved
+// and returns false and a reject reason otherwise.
+func (sshClient *sshClient) isDomainPermitted(domain string) (bool, string) {
+
+	// We're not doing comprehensive validation, to avoid overhead per port
+	// forward. This is a simple sanity check to ensure we don't process
+	// blantantly invalid input.
+	//
+	// TODO: validate with dns.IsDomainName?
+	if len(domain) > 255 {
+		return false, "invalid domain name"
+	}
+
+	tags := sshClient.sshServer.support.Blocklist.LookupDomain(domain)
+	if len(tags) > 0 {
+
+		sshClient.logBlocklistHits(nil, domain, tags)
+
+		if sshClient.sshServer.support.Config.BlocklistActive {
+			// Actively alert and block
+			sshClient.enqueueUnsafeTrafficAlertRequest(tags)
+			return false, "port forward not permitted"
+		}
+	}
+
+	return true, ""
+}
+
 func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
 
 	sshClient.Lock()
@@ -3245,37 +3279,20 @@ func (sshClient *sshClient) handleTCPChannel(
 	// performs no actions and next immediate step is the isPortForwardPermitted
 	// check.
 	//
-	// Limitation: at this time, only clients that send domains in hostToConnect
-	// are subject to domain blocklist checks. Both the udpgw and packet tunnel
-	// modes perform tunneled DNS and send only IPs in hostToConnect.
+	// Limitation: this case handles port forwards where the client sends the
+	// destination domain in the SSH port forward request but does not currently
+	// handle DNS-over-TCP; in the DNS-over-TCP case, a client may bypass the
+	// block list check.
 
 	if !isWebServerPortForward &&
 		net.ParseIP(hostToConnect) == nil {
 
-		// We're not doing comprehensive validation, to avoid overhead per port
-		// forward. This is a simple sanity check to ensure we don't process
-		// blantantly invalid input.
-		//
-		// TODO: validate with dns.IsDomainName?
-		if len(hostToConnect) > 255 {
+		ok, rejectMessage := sshClient.isDomainPermitted(hostToConnect)
+		if !ok {
 			// Note: not recording a port forward failure in this case
-			sshClient.rejectNewChannel(newChannel, "invalid domain name")
+			sshClient.rejectNewChannel(newChannel, rejectMessage)
 			return
 		}
-
-		tags := sshClient.sshServer.support.Blocklist.LookupDomain(hostToConnect)
-		if len(tags) > 0 {
-
-			sshClient.logBlocklistHits(nil, hostToConnect, tags)
-
-			if sshClient.sshServer.support.Config.BlocklistActive {
-				// Actively alert and block
-				// Note: not recording a port forward failure in this case
-				sshClient.enqueueUnsafeTrafficAlertRequest(tags)
-				sshClient.rejectNewChannel(newChannel, "port forward not permitted")
-				return
-			}
-		}
 	}
 
 	// Dial the remote address.

+ 23 - 2
psiphon/server/udp.go

@@ -145,9 +145,30 @@ func (mux *udpPortForwardMultiplexer) run() {
 			dialIP := net.IP(message.remoteIP)
 			dialPort := int(message.remotePort)
 
+			// Validate DNS packets and check the domain blocklist both when the client
+			// indicates DNS or when DNS is _not_ indicated and the destination port is
+			// 53.
+			if message.forwardDNS || message.remotePort == 53 {
+
+				domain, err := common.ParseDNSQuestion(message.packet)
+				if err != nil {
+					log.WithTraceFields(LogFields{"error": err}).Debug("ParseDNSQuestion failed")
+					// Drop packet
+					continue
+				}
+
+				if domain != "" {
+					ok, _ := mux.sshClient.isDomainPermitted(domain)
+					if !ok {
+						// Drop packet
+						continue
+					}
+				}
+			}
+
 			if message.forwardDNS {
-				// Transparent DNS forwarding. In this case, traffic rules
-				// checks are bypassed, since DNS is essential.
+				// Transparent DNS forwarding. In this case, isPortForwardPermitted
+				// traffic rules checks are bypassed, since DNS is essential.
 				dialIP = mux.sshClient.sshServer.support.DNSResolver.Get()
 				dialPort = DNS_RESOLVER_PORT
 

+ 0 - 10
psiphon/server/utils.go

@@ -22,11 +22,8 @@ package server
 import (
 	"fmt"
 	"io"
-	"net"
 	"strings"
 	"sync/atomic"
-
-	"github.com/wader/filtertransport"
 )
 
 // IntentionalPanicError is an error type that is used
@@ -123,10 +120,3 @@ func isExpectedTunnelIOError(err error) bool {
 	}
 	return false
 }
-
-// IsBogon checks if the specified IP is a bogon (loopback, private addresses,
-// link-local addresses, etc.)
-func IsBogon(IP net.IP) bool {
-	return filtertransport.FindIPNet(
-		filtertransport.DefaultFilteredNetworks, IP)
-}

+ 1 - 1
vendor/github.com/Psiphon-Labs/tls-tris/handshake_server.go

@@ -89,7 +89,7 @@ func (c *Conn) serverHandshake() error {
 	// changes, in the passthrough case the ownership of Conn.conn, the client
 	// TCP conn, is transferred to the passthrough relay and a closedConn is
 	// substituted for Conn.conn. This allows the remaining `tls` code paths to
-	// continue reference a net.Conn, albiet one that is closed, so Reads and
+	// continue reference a net.Conn, albeit one that is closed, so Reads and
 	// Writes will fail.
 
 	if c.config.PassthroughAddress != "" {

+ 3 - 3
vendor/vendor.json

@@ -147,10 +147,10 @@
 			"revisionTime": "2020-01-16T02:28:06Z"
 		},
 		{
-			"checksumSHA1": "tP8/SZKnStfvqhHMeB5EpgtoGSQ=",
+			"checksumSHA1": "CcrIyMnOlnDAZnt6c9KAzy/IndQ=",
 			"path": "github.com/Psiphon-Labs/tls-tris",
-			"revision": "7ff412878bba4c627909aed23258d42b1f2b14f5",
-			"revisionTime": "2020-03-26T18:33:34Z"
+			"revision": "28577248f5cb7f01bf0570a8f5bd7777359f07da",
+			"revisionTime": "2020-05-01T12:56:31Z"
 		},
 		{
 			"checksumSHA1": "30PBqj9BW03KCVqASvLg3bR+xYc=",