Quellcode durchsuchen

Add domain name validation

Rod Hynes vor 6 Jahren
Ursprung
Commit
5933f43af1
2 geänderte Dateien mit 17 neuen und 1 gelöschten Zeilen
  1. 5 0
      psiphon/server/blocklist.go
  2. 12 1
      psiphon/server/tunnelServer.go

+ 5 - 0
psiphon/server/blocklist.go

@@ -27,6 +27,7 @@ import (
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
 
 
+	"github.com/Psiphon-Labs/dns"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 )
 )
@@ -203,6 +204,10 @@ func loadBlocklistFromFile(filename string) (*blocklistData, error) {
 
 
 		} else {
 		} else {
 
 
+			if _, ok := dns.IsDomainName(record[0]); !ok {
+				return nil, errors.Tracef("invalid domain name: %s", record[0])
+			}
+
 			key := record[0]
 			key := record[0]
 
 
 			tags := data.lookupDomain[key]
 			tags := data.lookupDomain[key]

+ 12 - 1
psiphon/server/tunnelServer.go

@@ -3057,7 +3057,7 @@ func (sshClient *sshClient) handleTCPChannel(
 		}
 		}
 	}
 	}
 
 
-	// Check the domain blocklist before dialing.
+	// Validate the domain name and check the domain blocklist before dialing.
 	//
 	//
 	// The IP blocklist is checked in isPortForwardPermitted, which also provides
 	// The IP blocklist is checked in isPortForwardPermitted, which also provides
 	// IP blocklist checking for the packet tunnel code path. When hostToConnect
 	// IP blocklist checking for the packet tunnel code path. When hostToConnect
@@ -3072,6 +3072,17 @@ func (sshClient *sshClient) handleTCPChannel(
 	if !isWebServerPortForward &&
 	if !isWebServerPortForward &&
 		net.ParseIP(hostToConnect) == nil {
 		net.ParseIP(hostToConnect) == nil {
 
 
+		// We're not doing comprehensive validation, to avoid overhead per port
+		// forward. This is a simple sanity check to ensure we don't process
+		// blantantly invalid input.
+		//
+		// TODO: validate with dns.IsDomainName?
+		if len(hostToConnect) > 255 {
+			// Note: not recording a port forward failure in this case
+			sshClient.rejectNewChannel(newChannel, "invalid domain name")
+			return
+		}
+
 		tags := sshClient.sshServer.support.Blocklist.LookupDomain(hostToConnect)
 		tags := sshClient.sshServer.support.Blocklist.LookupDomain(hostToConnect)
 		if len(tags) > 0 {
 		if len(tags) > 0 {
 			sshClient.logBlocklistHits(nil, hostToConnect, tags)
 			sshClient.logBlocklistHits(nil, hostToConnect, tags)