Просмотр исходного кода

Add blocklist functionality

- Make udpgw transparent DNS skip traffic
  rules (including new blocklist check),
  as packet tunnel does.

- Add large file option to hot reloader:
  stream CRC and don't load full content
  into memory.
Rod Hynes 7 лет назад
Родитель
Сommit
02fc6a246e

+ 1 - 0
psiphon/common/osl/osl.go

@@ -259,6 +259,7 @@ func NewConfig(filename string) (*Config, error) {
 
 	config.ReloadableFile = common.NewReloadableFile(
 		filename,
+		true,
 		func(fileContent []byte) error {
 			newConfig, err := LoadConfig(fileContent)
 			if err != nil {

+ 50 - 13
psiphon/common/reloader.go

@@ -21,7 +21,9 @@ package common
 
 import (
 	"hash/crc64"
+	"io"
 	"io/ioutil"
+	"os"
 	"sync"
 )
 
@@ -60,26 +62,34 @@ type Reloader interface {
 //
 type ReloadableFile struct {
 	sync.RWMutex
-	fileName     string
-	checksum     uint64
-	reloadAction func([]byte) error
+	filename        string
+	loadFileContent bool
+	checksum        uint64
+	reloadAction    func([]byte) error
 }
 
-// NewReloadableFile initializes a new ReloadableFile
+// NewReloadableFile initializes a new ReloadableFile.
+//
+// When loadFileContent is true, the file content is loaded and passed to
+// reloadAction; otherwise, reloadAction receives a nil argument and is
+// responsible for loading the file. The latter option allows for cases where
+// the file contents must be streamed, memory mapped, etc.
 func NewReloadableFile(
-	fileName string,
+	filename string,
+	loadFileContent bool,
 	reloadAction func([]byte) error) ReloadableFile {
 
 	return ReloadableFile{
-		fileName:     fileName,
-		reloadAction: reloadAction,
+		filename:        filename,
+		loadFileContent: loadFileContent,
+		reloadAction:    reloadAction,
 	}
 }
 
 // WillReload indicates whether the ReloadableFile is capable
 // of reloading.
 func (reloadable *ReloadableFile) WillReload() bool {
-	return reloadable.fileName != ""
+	return reloadable.filename != ""
 }
 
 var crc64table = crc64.MakeTable(crc64.ISO)
@@ -108,22 +118,49 @@ func (reloadable *ReloadableFile) Reload() (bool, error) {
 	// Check whether the file has changed _before_ blocking readers
 
 	reloadable.RLock()
-	fileName := reloadable.fileName
+	filename := reloadable.filename
 	previousChecksum := reloadable.checksum
 	reloadable.RUnlock()
 
-	content, err := ioutil.ReadFile(fileName)
+	file, err := os.Open(filename)
 	if err != nil {
 		return false, ContextError(err)
 	}
+	defer file.Close()
+
+	hash := crc64.New(crc64table)
 
-	checksum := crc64.Checksum(content, crc64table)
+	_, err = io.Copy(hash, file)
+	if err != nil {
+		return false, ContextError(err)
+	}
+
+	checksum := hash.Sum64()
 
 	if checksum == previousChecksum {
 		return false, nil
 	}
 
-	// ...now block readers
+	// It's possible for the file content to revert to its previous value
+	// between the checksum operation and subsequent content load. We accept
+	// the false positive in this unlikely case.
+
+	var content []byte
+	if reloadable.loadFileContent {
+		_, err = file.Seek(0, 0)
+		if err != nil {
+			return false, ContextError(err)
+		}
+		content, err = ioutil.ReadAll(file)
+		if err != nil {
+			return false, ContextError(err)
+		}
+	}
+
+	// Don't keep file open during reloadAction call.
+	file.Close()
+
+	// ...now block readers and reload
 
 	reloadable.Lock()
 	defer reloadable.Unlock()
@@ -139,5 +176,5 @@ func (reloadable *ReloadableFile) Reload() (bool, error) {
 }
 
 func (reloadable *ReloadableFile) LogDescription() string {
-	return reloadable.fileName
+	return reloadable.filename
 }

+ 5 - 4
psiphon/common/reloader_test.go

@@ -35,7 +35,7 @@ func TestReloader(t *testing.T) {
 	}
 	defer os.RemoveAll(dirname)
 
-	fileName := filepath.Join(dirname, "reloader_test.dat")
+	filename := filepath.Join(dirname, "reloader_test.dat")
 
 	initialContents := []byte("contents1\n")
 	modifiedContents := []byte("contents2\n")
@@ -46,7 +46,8 @@ func TestReloader(t *testing.T) {
 	}
 
 	file.ReloadableFile = NewReloadableFile(
-		fileName,
+		filename,
+		true,
 		func(fileContent []byte) error {
 			file.contents = fileContent
 			return nil
@@ -54,7 +55,7 @@ func TestReloader(t *testing.T) {
 
 	// Test: initial load
 
-	err = ioutil.WriteFile(fileName, initialContents, 0600)
+	err = ioutil.WriteFile(filename, initialContents, 0600)
 	if err != nil {
 		t.Fatalf("WriteFile failed: %s", err)
 	}
@@ -89,7 +90,7 @@ func TestReloader(t *testing.T) {
 
 	// Test: reload changed file
 
-	err = ioutil.WriteFile(fileName, modifiedContents, 0600)
+	err = ioutil.WriteFile(filename, modifiedContents, 0600)
 	if err != nil {
 		t.Fatalf("WriteFile failed: %s", err)
 	}

+ 1 - 0
psiphon/common/tactics/tactics.go

@@ -436,6 +436,7 @@ func NewServer(
 
 	server.ReloadableFile = common.NewReloadableFile(
 		configFilename,
+		true,
 		func(fileContent []byte) error {
 
 			var newServer Server

+ 203 - 0
psiphon/server/blocklist.go

@@ -0,0 +1,203 @@
+/*
+ * Copyright (c) 2019, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package server
+
+import (
+	"encoding/csv"
+	"fmt"
+	"io"
+	"net"
+	"os"
+	"sync/atomic"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+)
+
+// Blocklist provides a fast lookup of IP addresses that are candidates for
+// egress blocking. This is intended to be used to block malware and other
+// malicious traffic.
+//
+// The Reload function supports hot reloading of rules data while the server
+// is running.
+//
+// Limitations: currently supports only IPv4 addresses, and is implemented
+// with an in-memory Go map, which limits the practical size of the blocklist.
+type Blocklist struct {
+	common.ReloadableFile
+	loaded int32
+	data   atomic.Value
+}
+
+// BlocklistTag indicates the source containing an IP address and the subject,
+// or name of the suspected malicious traffic.
+type BlocklistTag struct {
+	Source  string
+	Subject string
+}
+
+type blocklistData struct {
+	lookup          map[[net.IPv4len]byte][]BlocklistTag
+	internedStrings map[string]string
+}
+
+// NewBlocklist creates a new block list.
+//
+// The input file must be a 3 field comma-delimited and optional quote-escaped
+// CSV. Fields: <IPv4 address>,<source>,<subject>.
+//
+// IP addresses may appear multiple times in the input file; each distinct
+// source/subject is associated with the IP address and returned in the Lookup
+// tag list.
+func NewBlocklist(filename string) (*Blocklist, error) {
+
+	blocklist := &Blocklist{}
+
+	blocklist.ReloadableFile = common.NewReloadableFile(
+		filename,
+		false,
+		func(_ []byte) error {
+
+			newData, err := loadBlocklistFromFile(filename)
+			if err != nil {
+				return common.ContextError(err)
+			}
+
+			blocklist.data.Store(newData)
+			atomic.StoreInt32(&blocklist.loaded, 1)
+
+			return nil
+		})
+
+	_, err := blocklist.Reload()
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	return blocklist, nil
+}
+
+// Lookup returns the blocklist tags for any IP address that is on the
+// blocklist, or returns nil for any IP address not on the blocklist. Lookup
+// may be called oncurrently. The caller must not modify the return value.
+func (b *Blocklist) Lookup(IPAddress net.IP) []BlocklistTag {
+
+	// When not configured, no blocklist is loaded/initialized.
+	if atomic.LoadInt32(&b.loaded) != 1 {
+		return nil
+	}
+
+	var key [net.IPv4len]byte
+	IPv4Address := IPAddress.To4()
+	if IPv4Address == nil {
+		return nil
+	}
+	copy(key[:], IPv4Address)
+
+	// As data is an atomic.Value, it's not necessary to call
+	// ReloadableFile.RLock/ReloadableFile.RUnlock in this case.
+
+	tags, ok := b.data.Load().(*blocklistData).lookup[key]
+	if !ok {
+		return nil
+	}
+	return tags
+}
+
+func loadBlocklistFromFile(filename string) (*blocklistData, error) {
+
+	data := newBlocklistData()
+
+	file, err := os.Open(filename)
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+	defer file.Close()
+
+	reader := csv.NewReader(file)
+
+	reader.FieldsPerRecord = 3
+	reader.Comment = '#'
+	reader.ReuseRecord = true
+
+	for {
+		record, err := reader.Read()
+
+		if err == io.EOF {
+			break
+		} else if err != nil {
+			return nil, common.ContextError(err)
+		}
+
+		IPAddress := net.ParseIP(record[0])
+		if IPAddress == nil {
+			return nil, common.ContextError(
+				fmt.Errorf("invalid IP address: %s", record[0]))
+		}
+		IPv4Address := IPAddress.To4()
+		if IPAddress == nil {
+			return nil, common.ContextError(
+				fmt.Errorf("invalid IPv4 address: %s", record[0]))
+		}
+
+		var key [net.IPv4len]byte
+		copy(key[:], IPv4Address)
+
+		// Intern the source and subject strings so we only store one copy of
+		// each in memory. These values are expected to repeat often.
+		source := data.internString(record[1])
+		subject := data.internString(record[2])
+
+		tag := BlocklistTag{
+			Source:  source,
+			Subject: subject,
+		}
+
+		tags := data.lookup[key]
+
+		found := false
+		for _, existingTag := range tags {
+			if tag == existingTag {
+				found = true
+				break
+			}
+		}
+
+		if !found {
+			data.lookup[key] = append(tags, tag)
+		}
+	}
+
+	return data, nil
+}
+
+func newBlocklistData() *blocklistData {
+	return &blocklistData{
+		lookup:          make(map[[net.IPv4len]byte][]BlocklistTag),
+		internedStrings: make(map[string]string),
+	}
+}
+
+func (data *blocklistData) internString(str string) string {
+	if internedStr, ok := data.internedStrings[str]; ok {
+		return internedStr
+	}
+	data.internedStrings[str] = str
+	return str
+}

+ 141 - 0
psiphon/server/blocklist_test.go

@@ -0,0 +1,141 @@
+/*
+ * Copyright (c) 2019, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package server
+
+import (
+	"fmt"
+	"io/ioutil"
+	"net"
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+)
+
+func TestBlocklist(t *testing.T) {
+
+	testDataDirName, err := ioutil.TempDir("", "psiphon-blocklist-test")
+	if err != nil {
+		t.Fatalf("TempDir failed: %s", err)
+	}
+	defer os.RemoveAll(testDataDirName)
+
+	filename := filepath.Join(testDataDirName, "blocklist")
+
+	hit := net.ParseIP("0.0.0.0")
+	miss := net.ParseIP("255.255.255.255")
+	sources := []string{"source1", "source2", "source3", "source4", "source4"}
+	subjects := []string{"subject1", "subject2", "subject3", "subject4", "subject4"}
+	hitPresent := []int{0, 1}
+	entriesPerSource := 100000
+
+	file, err := os.Create(filename)
+	if err != nil {
+		t.Fatalf("Open failed: %s", err)
+	}
+	defer file.Close()
+
+	for i := 0; i < len(sources); i++ {
+		_, err := fmt.Fprintf(file, "# comment\n# comment\n# comment\n")
+		if err != nil {
+			t.Fatalf("Fprintf failed: %s", err)
+		}
+		for j := 0; j < entriesPerSource; j++ {
+			var IPAddress string
+			if j == entriesPerSource/2 && common.ContainsInt(hitPresent, i) {
+				IPAddress = hit.String()
+			} else {
+				IPAddress = fmt.Sprintf(
+					"%d.%d.%d.%d",
+					prng.Range(1, 254), prng.Range(1, 254),
+					prng.Range(1, 254), prng.Range(1, 254))
+			}
+			_, err := fmt.Fprintf(file, "%s,%s,%s\n",
+				IPAddress, sources[i], subjects[i])
+			if err != nil {
+				t.Fatalf("Fprintf failed: %s", err)
+			}
+		}
+	}
+
+	file.Close()
+
+	b, err := NewBlocklist(filename)
+	if err != nil {
+		t.Fatalf("NewBlocklist failed: %s", err)
+	}
+
+	tags := b.Lookup(hit)
+
+	if tags == nil {
+		t.Fatalf("unexpected miss")
+	}
+
+	if len(tags) != len(hitPresent) {
+		t.Fatalf("unexpected hit tag count")
+	}
+
+	for _, tag := range tags {
+		sourceFound := false
+		subjectFound := false
+		for _, i := range hitPresent {
+			if tag.Source == sources[i] {
+				sourceFound = true
+			}
+			if tag.Subject == subjects[i] {
+				subjectFound = true
+			}
+		}
+		if !sourceFound || !subjectFound {
+			t.Fatalf("unexpected hit tag")
+		}
+	}
+
+	if b.Lookup(miss) != nil {
+		t.Fatalf("unexpected hit")
+	}
+
+	numLookups := 10
+	numIterations := 1000000
+
+	lookups := make([]net.IP, numLookups)
+
+	for i := 0; i < numLookups; i++ {
+		lookups[i] = net.ParseIP(
+			fmt.Sprintf(
+				"%d.%d.%d.%d",
+				prng.Range(1, 254), prng.Range(1, 254),
+				prng.Range(1, 254), prng.Range(1, 254)))
+	}
+
+	start := time.Now()
+
+	for i := 0; i < numIterations; i++ {
+		_ = b.Lookup(lookups[i%numLookups])
+	}
+
+	t.Logf(
+		"average time per lookup in %d entries: %s",
+		len(sources)*entriesPerSource,
+		time.Since(start)/time.Duration(numIterations))
+}

+ 9 - 0
psiphon/server/config.go

@@ -316,6 +316,15 @@ type Config struct {
 	// MARIONETTE-OSSH tunnel protocol. The format specifies the network
 	// protocol port to listen on.
 	MarionetteFormat string
+
+	// BlocklistFilename is the path of a file containing a CSV-encoded
+	// blocklist configuration. See NewBlocklist for more file format
+	// documentation.
+	BlocklistFilename string
+
+	// BlocklistActive indicates whether to actively prevent blocklist hits in
+	// addition to logging events.
+	BlocklistActive bool
 }
 
 // RunWebServer indicates whether to run a web server component.

+ 1 - 0
psiphon/server/dns.go

@@ -78,6 +78,7 @@ func NewDNSResolver(defaultResolver string) (*DNSResolver, error) {
 
 	dns.ReloadableFile = common.NewReloadableFile(
 		DNS_SYSTEM_CONFIG_FILENAME,
+		true,
 		func(fileContent []byte) error {
 
 			resolvers, err := parseResolveConf(fileContent)

+ 3 - 2
psiphon/server/geoip.go

@@ -89,8 +89,9 @@ func NewGeoIPService(
 		database := &geoIPDatabase{}
 		database.ReloadableFile = common.NewReloadableFile(
 			filename,
-			func(fileContent []byte) error {
-				maxMindReader, err := maxminddb.FromBytes(fileContent)
+			false,
+			func(_ []byte) error {
+				maxMindReader, err := maxminddb.Open(filename)
 				if err != nil {
 					// On error, database state remains the same
 					return common.ContextError(err)

+ 1 - 0
psiphon/server/psinet/psinet.go

@@ -132,6 +132,7 @@ func NewDatabase(filename string) (*Database, error) {
 
 	database.ReloadableFile = common.NewReloadableFile(
 		filename,
+		true,
 		func(fileContent []byte) error {
 			var newDatabase Database
 			err := json.Unmarshal(fileContent, &newDatabase)

+ 16 - 1
psiphon/server/server_test.go

@@ -329,7 +329,7 @@ func TestHotReload(t *testing.T) {
 		})
 }
 
-func TestDefaultSessionID(t *testing.T) {
+func TestDefaultSponsorID(t *testing.T) {
 	runServer(t,
 		&runServerConfig{
 			tunnelProtocol:       "OSSH",
@@ -590,6 +590,9 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			livenessTestSize)
 	}
 
+	blocklistFilename := filepath.Join(testDataDirName, "blocklist.csv")
+	paveBlocklistFile(t, blocklistFilename)
+
 	var serverConfig map[string]interface{}
 	json.Unmarshal(serverConfigJSON, &serverConfig)
 	serverConfig["GeoIPDatabaseFilename"] = ""
@@ -599,6 +602,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	if doServerTactics {
 		serverConfig["TacticsConfigFilename"] = tacticsConfigFilename
 	}
+	serverConfig["BlocklistFilename"] = blocklistFilename
+
 	serverConfig["LogFilename"] = filepath.Join(testDataDirName, "psiphond.log")
 	serverConfig["LogLevel"] = "debug"
 
@@ -1442,6 +1447,16 @@ func paveTacticsConfigFile(
 	}
 }
 
+func paveBlocklistFile(t *testing.T, blocklistFilename string) {
+
+	blocklistContent := "255.255.255.255,test-source,test-subject\n"
+
+	err := ioutil.WriteFile(blocklistFilename, []byte(blocklistContent), 0600)
+	if err != nil {
+		t.Fatalf("error paving blocklist file: %s", err)
+	}
+}
+
 func sendNotificationReceived(c chan<- struct{}) {
 	select {
 	case c <- *new(struct{}):

+ 4 - 0
psiphon/server/services.go

@@ -335,6 +335,7 @@ type SupportServices struct {
 	TunnelServer       *TunnelServer
 	PacketTunnelServer *tun.Server
 	TacticsServer      *tactics.Server
+	Blocklist          *Blocklist
 }
 
 // NewSupportServices initializes a new SupportServices.
@@ -366,6 +367,8 @@ func NewSupportServices(config *Config) (*SupportServices, error) {
 		return nil, common.ContextError(err)
 	}
 
+	blocklist, err := NewBlocklist(config.BlocklistFilename)
+
 	tacticsServer, err := tactics.NewServer(
 		CommonLogger(log),
 		getTacticsAPIParameterLogFieldFormatter(),
@@ -383,6 +386,7 @@ func NewSupportServices(config *Config) (*SupportServices, error) {
 		GeoIPService:    geoIPService,
 		DNSResolver:     dnsResolver,
 		TacticsServer:   tacticsServer,
+		Blocklist:       blocklist,
 	}, nil
 }
 

+ 1 - 0
psiphon/server/trafficRules.go

@@ -257,6 +257,7 @@ func NewTrafficRulesSet(filename string) (*TrafficRulesSet, error) {
 
 	set.ReloadableFile = common.NewReloadableFile(
 		filename,
+		true,
 		func(fileContent []byte) error {
 			var newSet TrafficRulesSet
 			err := json.Unmarshal(fileContent, &newSet)

+ 44 - 7
psiphon/server/tunnelServer.go

@@ -1811,11 +1811,11 @@ func (sshClient *sshClient) handleNewPacketTunnelChannel(
 	// will stop packet tunnel workers for any previous packet tunnel channel.
 
 	checkAllowedTCPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
-		return sshClient.isPortForwardPermitted(portForwardTypeTCP, false, upstreamIPAddress, port)
+		return sshClient.isPortForwardPermitted(portForwardTypeTCP, upstreamIPAddress, port)
 	}
 
 	checkAllowedUDPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
-		return sshClient.isPortForwardPermitted(portForwardTypeUDP, false, upstreamIPAddress, port)
+		return sshClient.isPortForwardPermitted(portForwardTypeUDP, upstreamIPAddress, port)
 	}
 
 	flowActivityUpdaterMaker := func(
@@ -2039,6 +2039,30 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
 	log.LogRawFieldsWithTimestamp(logFields)
 }
 
+func (sshClient *sshClient) logBlocklistHits(remoteIP net.IP, tags []BlocklistTag) {
+
+	sshClient.Lock()
+
+	logFields := getRequestLogFields(
+		"blocklist_hit",
+		sshClient.geoIPData,
+		sshClient.handshakeState.authorizedAccessTypes,
+		sshClient.handshakeState.apiParams,
+		baseRequestParams)
+
+	// Note: see comment in logTunnel regarding unlock and concurrent access.
+
+	sshClient.Unlock()
+
+	for _, tag := range tags {
+		logFields["ip_address"] = remoteIP.String()
+		logFields["blocklist_source"] = tag.Source
+		logFields["blocklist_subject"] = tag.Subject
+
+		log.LogRawFieldsWithTimestamp(logFields)
+	}
+}
+
 func (sshClient *sshClient) runOSLSender() {
 
 	for {
@@ -2478,7 +2502,6 @@ const (
 
 func (sshClient *sshClient) isPortForwardPermitted(
 	portForwardType int,
-	isTransparentDNSForwarding bool,
 	remoteIP net.IP,
 	port int) bool {
 
@@ -2491,12 +2514,27 @@ func (sshClient *sshClient) isPortForwardPermitted(
 
 	// Disallow connection to loopback. This is a failsafe. The server
 	// should be run on a host with correctly configured firewall rules.
-	// An exception is made in the case of transparent DNS forwarding,
-	// where the remoteIP has been rewritten.
-	if !isTransparentDNSForwarding && remoteIP.IsLoopback() {
+	if remoteIP.IsLoopback() {
 		return false
 	}
 
+	// Blocklist check.
+	//
+	// Limitation: isPortForwardPermitted is not called in transparent DNS
+	// forwarding cases. As the destination IP address is rewritten in these
+	// cases, a blocklist entry won't be dialed in any case. However, no logs
+	// will be recorded.
+
+	tags := sshClient.sshServer.support.Blocklist.Lookup(remoteIP)
+	if len(tags) > 0 {
+		sshClient.logBlocklistHits(remoteIP, tags)
+		if sshClient.sshServer.support.Config.BlocklistActive {
+			return false
+		}
+	}
+
+	// Traffic rules checks.
+
 	var allowPorts []int
 	if portForwardType == portForwardTypeTCP {
 		allowPorts = sshClient.trafficRules.AllowTCPPorts
@@ -2834,7 +2872,6 @@ func (sshClient *sshClient) handleTCPChannel(
 	if !isWebServerPortForward &&
 		!sshClient.isPortForwardPermitted(
 			portForwardTypeTCP,
-			false,
 			IP,
 			portToConnect) {
 

+ 4 - 4
psiphon/server/udp.go

@@ -146,14 +146,14 @@ func (mux *udpPortForwardMultiplexer) run() {
 			dialIP := net.IP(message.remoteIP)
 			dialPort := int(message.remotePort)
 
-			// Transparent DNS forwarding
 			if message.forwardDNS {
+				// Transparent DNS forwarding. In this case, traffic rules
+				// checks are bypassed, since DNS is essential.
 				dialIP = mux.sshClient.sshServer.support.DNSResolver.Get()
 				dialPort = DNS_RESOLVER_PORT
-			}
 
-			if !mux.sshClient.isPortForwardPermitted(
-				portForwardTypeUDP, message.forwardDNS, dialIP, int(message.remotePort)) {
+			} else if !mux.sshClient.isPortForwardPermitted(
+				portForwardTypeUDP, dialIP, int(message.remotePort)) {
 				// The udpgw protocol has no error response, so
 				// we just discard the message and read another.
 				continue