Просмотр исходного кода

Merge pull request #734 from rod-hynes/server-dns-resolver

Add psiphond port forward DNS cache
Rod Hynes 10 месяцев назад
Родитель
Сommit
d1883d736c
3 измененных файлов с 204 добавлено и 26 удалено
  1. 5 0
      psiphon/common/parameters/parameters.go
  2. 74 8
      psiphon/server/server_test.go
  3. 125 18
      psiphon/server/tunnelServer.go

+ 5 - 0
psiphon/common/parameters/parameters.go

@@ -506,6 +506,8 @@ const (
 	InproxyProxyQualityPendingFailedMatchDeadline      = "InproxyProxyQualityPendingFailedMatchDeadline"
 	InproxyProxyQualityFailedMatchThreshold            = "InproxyProxyQualityFailedMatchThreshold"
 	NetworkIDCacheTTL                                  = "NetworkIDCacheTTL"
+	ServerDNSResolverCacheMaxSize                      = "ServerDNSResolverCacheMaxSize"
+	ServerDNSResolverCacheTTL                          = "ServerDNSResolverCacheTTL"
 
 	// Retired parameters
 
@@ -1082,6 +1084,9 @@ var defaultParameters = map[string]struct {
 	InproxyProxyQualityFailedMatchThreshold:          {value: 10, minimum: 1, flags: serverSideOnly},
 
 	NetworkIDCacheTTL: {value: 500 * time.Millisecond, minimum: time.Duration(0)},
+
+	ServerDNSResolverCacheMaxSize: {value: 32, minimum: 0, flags: serverSideOnly},
+	ServerDNSResolverCacheTTL:     {value: 10 * time.Second, minimum: time.Duration(0), flags: serverSideOnly},
 }
 
 // IsServerSideOnly indicates if the parameter specified by name is used

+ 74 - 8
psiphon/server/server_test.go

@@ -716,6 +716,16 @@ func TestLegacyAPIEncoding(t *testing.T) {
 		})
 }
 
+func TestDomainRequest(t *testing.T) {
+	runServer(t,
+		&runServerConfig{
+			tunnelProtocol:          "SSH",
+			requireAuthorization:    true,
+			doTunneledDomainRequest: true,
+			doLogHostProvider:       true,
+		})
+}
+
 type runServerConfig struct {
 	tunnelProtocol           string
 	clientTunnelProtocol     string
@@ -727,6 +737,7 @@ type runServerConfig struct {
 	requireAuthorization     bool
 	omitAuthorization        bool
 	doTunneledWebRequest     bool
+	doTunneledDomainRequest  bool
 	doTunneledNTPRequest     bool
 	applyPrefix              bool
 	forceFragmenting         bool
@@ -850,7 +861,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		runConfig.forceFragmenting ||
 		runConfig.doBurstMonitor ||
 		runConfig.doDestinationBytes ||
-		runConfig.doLegacyDestinationBytes
+		runConfig.doLegacyDestinationBytes ||
+		runConfig.doTunneledDomainRequest
 
 	// All servers require a tactics config with valid keys.
 	tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey, err :=
@@ -1094,6 +1106,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	// 2. hot reload of psinet db (runConfig.doHotReload)
 	// 3. hot reload of server tactics (runConfig.doHotReload && doServerTactics)
 	discoveryLog := make(chan map[string]interface{}, 3)
+	serverLoadLog := make(chan map[string]interface{}, 1)
 
 	inproxyProxyAnnounceLog := make(chan map[string]interface{}, 1)
 	inproxyClientOfferLog := make(chan map[string]interface{}, 1)
@@ -1135,6 +1148,11 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			case serverTunnelLog <- logFields:
 			default:
 			}
+		case "server_load":
+			select {
+			case serverLoadLog <- logFields:
+			default:
+			}
 		case "inproxy_broker":
 
 			event, ok := logFields["broker_event"].(string)
@@ -1298,6 +1316,14 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	p, _ := os.FindProcess(os.Getpid())
 	p.Signal(syscall.SIGUSR2)
 
+	timer := time.NewTimer(1 * time.Second)
+	select {
+	case <-serverLoadLog:
+	case <-timer.C:
+		t.Fatalf("missing server load log")
+	}
+	timer.Stop()
+
 	// configure client
 
 	values.SetSSHClientVersionsSpec(values.NewPickOneSpec(testSSHClientVersions))
@@ -1794,7 +1820,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		// Test: tunneled web site fetch
 
 		err = makeTunneledWebRequest(
-			t, localHTTPProxyPort, mockWebServerURL, mockWebServerExpectedResponse)
+			t, localHTTPProxyPort, mockWebServerURL, true, mockWebServerExpectedResponse)
 
 		if err == nil {
 			if expectTrafficFailure {
@@ -1807,6 +1833,28 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		}
 	}
 
+	if runConfig.doTunneledDomainRequest && !expectTrafficFailure {
+
+		// Test: tunneled web site fetch exercising the handleTCPChannel DNS
+		// resolver and cache
+
+		err = makeTunneledWebRequest(
+			t, localHTTPProxyPort, "https://psiphon.ca", false, "")
+		if err != nil {
+			t.Fatalf("tunneled web request failed: %s", err)
+		}
+
+		// Establish a second port forward to the same domain. The DNS
+		// resolution is expected to be cached. This is checked below via the
+		// dns_count reported in the server_load log.
+
+		err = makeTunneledWebRequest(
+			t, localHTTPProxyPort, "https://psiphon.ca", false, "")
+		if err != nil {
+			t.Fatalf("tunneled web request failed: %s", err)
+		}
+	}
+
 	if runConfig.doTunneledNTPRequest {
 
 		// Test: tunneled UDP packets
@@ -1913,8 +1961,9 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	expectServerBPFField := ServerBPFEnabled() && protocol.TunnelProtocolIsDirect(runConfig.tunnelProtocol) && doServerTactics
 	expectServerPacketManipulationField := runConfig.doPacketManipulation
 	expectBurstFields := runConfig.doBurstMonitor
-	expectTCPPortForwardDial := runConfig.doTunneledWebRequest
-	expectTCPDataTransfer := runConfig.doTunneledWebRequest && !expectTrafficFailure && !runConfig.doSplitTunnel
+	expectTCPPortForwardDial := (runConfig.doTunneledWebRequest || runConfig.doTunneledDomainRequest)
+	expectTCPDataTransfer := (runConfig.doTunneledWebRequest || runConfig.doTunneledDomainRequest) && !expectTrafficFailure && !runConfig.doSplitTunnel
+	expectDomainPortForward := runConfig.doTunneledDomainRequest
 	// Even with expectTrafficFailure, DNS port forwards will succeed
 	expectUDPDataTransfer := runConfig.doTunneledNTPRequest
 	expectQUICVersion := ""
@@ -1954,6 +2003,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			expectTCPPortForwardDial,
 			expectTCPDataTransfer,
 			expectUDPDataTransfer,
+			expectDomainPortForward,
 			expectQUICVersion,
 			expectDestinationBytesFields,
 			expectLegacyDestinationBytesFields,
@@ -2008,6 +2058,19 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		}
 	}
 
+	select {
+	case logFields := <-serverLoadLog:
+		if expectDomainPortForward {
+			dnsCount := int(logFields["dns_count"].(map[string]any)["ALL"].(float64))
+			if dnsCount != 1 {
+				t.Fatalf("unexpected dns_count: %d", dnsCount)
+			}
+
+		}
+	default:
+		t.Fatalf("missing server load log")
+	}
+
 	// Check logs emitted by discovery.
 
 	var expectedDiscoveryStrategy []string
@@ -2252,6 +2315,7 @@ func checkExpectedServerTunnelLogFields(
 	expectTCPPortForwardDial bool,
 	expectTCPDataTransfer bool,
 	expectUDPDataTransfer bool,
+	expectDomainPortForward bool,
 	expectQUICVersion string,
 	expectDestinationBytesFields bool,
 	expectLegacyDestinationBytesFields bool,
@@ -2413,7 +2477,7 @@ func checkExpectedServerTunnelLogFields(
 		}
 	}
 
-	if expectUDPDataTransfer {
+	if expectUDPDataTransfer || expectDomainPortForward {
 
 		if fields["peak_dns_failure_rate"] == nil {
 			return fmt.Errorf("missing expected field 'peak_dns_failure_rate'")
@@ -3051,7 +3115,9 @@ func checkExpectedDomainBytesLogFields(
 func makeTunneledWebRequest(
 	t *testing.T,
 	localHTTPProxyPort int,
-	requestURL, expectedResponseBody string) error {
+	requestURL string,
+	checkResponseBody bool,
+	expectedResponseBody string) error {
 
 	roundTripTimeout := 30 * time.Second
 
@@ -3078,7 +3144,7 @@ func makeTunneledWebRequest(
 	}
 	response.Body.Close()
 
-	if string(body) != expectedResponseBody {
+	if checkResponseBody && string(body) != expectedResponseBody {
 		return fmt.Errorf("unexpected proxied HTTP response")
 	}
 
@@ -3426,7 +3492,7 @@ func paveTrafficRulesFile(
 		t.Fatalf("unexpected intLookupThreshold")
 	}
 
-	TCPPorts := mockWebServerPort
+	TCPPorts := fmt.Sprintf("443, %s", mockWebServerPort)
 	UDPPorts := "53, 123, 10001, 10002, 10003, 10004, 10005, 10006, 10007, 10008, 10009, 10010"
 
 	allowTCPPorts := TCPPorts

+ 125 - 18
psiphon/server/tunnelServer.go

@@ -55,6 +55,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
+	lrucache "github.com/cognusion/go-cache-lru"
 	"github.com/marusama/semaphore"
 	cache "github.com/patrickmn/go-cache"
 )
@@ -1930,6 +1931,8 @@ type sshClient struct {
 	peakMetrics                          peakMetrics
 	destinationBytesMetrics              map[string]*protocolDestinationBytesMetrics
 	inproxyProxyQualityTracker           *inproxyProxyQualityTracker
+	dnsResolver                          *net.Resolver
+	dnsCache                             *lrucache.Cache
 }
 
 type trafficState struct {
@@ -5168,31 +5171,56 @@ func (sshClient *sshClient) handleTCPChannel(
 
 		// Resolve the hostname
 
-		// PreferGo, equivalent to GODEBUG=netdns=go, is specified in order to
-		// avoid any cases where Go's resolver fails over to the cgo-based
-		// resolver (see https://pkg.go.dev/net#hdr-Name_Resolution). Such
-		// cases, if they resolve at all, may be expected to resolve to bogon
-		// IPs that won't be permitted; but the cgo invocation will consume
-		// an OS thread, which is a performance hit we can avoid.
-
 		if IsLogLevelDebug() {
 			log.WithTraceFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
 		}
 
-		ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
-		IPs, err := (&net.Resolver{PreferGo: true}).LookupIPAddr(ctx, hostToConnect)
-		cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
+		// See comments in getDNSResolver regarding DNS cache considerations.
+		// The cached values may be read by concurrent goroutines and must
+		// not be mutated.
+
+		dnsResolver, dnsCache := sshClient.getDNSResolver()
+
+		var IPs []net.IPAddr
+
+		if dnsCache != nil {
+			cachedIPs, ok := dnsCache.Get(hostToConnect)
+			if ok {
+				IPs = cachedIPs.([]net.IPAddr)
+			}
+		}
+
+		var err error
+		var resolveElapsedTime time.Duration
+
+		if len(IPs) == 0 {
+			ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
+			IPs, err = dnsResolver.LookupIPAddr(ctx, hostToConnect)
+			cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
+
+			resolveElapsedTime = time.Since(dialStartTime)
+
+			if err == nil && len(IPs) > 0 {
 
-		resolveElapsedTime := time.Since(dialStartTime)
+				// Add the successful DNS response to the cache. The cache
+				// won't be updated in the "no such host"/IsNotFound case,
+				// and subsequent resolves will try new requests. The "no IP
+				// address" error case in the following IP selection logic
+				// should not be reached when len(IPs) > 0.
+				if dnsCache != nil {
+					dnsCache.Add(hostToConnect, IPs, lrucache.DefaultExpiration)
+				}
+			}
 
-		// Record DNS metrics. If LookupIPAddr returns net.DNSError.IsNotFound, this
-		// is "no such host" and not a DNS failure. Limitation: the resolver IP is
-		// not known.
+			// Record DNS request metrics. If LookupIPAddr returns
+			// net.DNSError.IsNotFound, this is "no such host" and not a DNS
+			// request failure. Limitation: the DNS server IP is not known.
 
-		dnsErr, ok := err.(*net.DNSError)
-		dnsNotFound := ok && dnsErr.IsNotFound
-		dnsSuccess := err == nil || dnsNotFound
-		sshClient.updateQualityMetricsWithDNSResult(dnsSuccess, resolveElapsedTime, nil)
+			dnsErr, ok := err.(*net.DNSError)
+			dnsNotFound := ok && dnsErr.IsNotFound
+			dnsSuccess := err == nil || dnsNotFound
+			sshClient.updateQualityMetricsWithDNSResult(dnsSuccess, resolveElapsedTime, nil)
+		}
 
 		// IPv4 is preferred in case the host has limited IPv6 routing. IPv6 is
 		// selected and attempted only when there's no IPv4 option.
@@ -5427,3 +5455,82 @@ func (sshClient *sshClient) handleTCPChannel(
 				"bytesDown":  atomic.LoadInt64(&bytesDown)}).Debug("exiting")
 	}
 }
+
+func (sshClient *sshClient) getDNSResolver() (*net.Resolver, *lrucache.Cache) {
+
+	// Initialize the DNS resolver and cache used by handleTCPChannel in cases
+	// where the client sends unresolved domains through to psiphond. The
+	// resolver and cache are allocated on demand, to avoid overhead for
+	// clients that don't require this functionality.
+	//
+	// The standard library net.Resolver is used, with one instance per client
+	// to get the advantage of the "singleflight" functionality, where
+	// concurrent DNS lookups for the same domain are coalesced into a single
+	// in-flight DNS request.
+	//
+	// net.Resolver reads its configuration from /etc/resolv.conf, including a
+	// list of DNS servers, the number or retries to attempt, and whether to
+	// rotate the initial DNS server selection.
+	//
+	// In addition, a cache of successful DNS lookups is maintained to avoid
+	// rapid repeats DNS requests for the same domain. Since actual DNS
+	// response TTLs are not exposed by net.Resolver, the cache should be
+	// configured with a conservative TTL -- 10s of seconds.
+	//
+	// Each client has its own singleflight resolver and cache, which avoids
+	// leaking domain access information between clients. The cache should be
+	// configured with a modest max size appropriate for allocating one cache
+	// per client.
+	//
+	// As a potential future enhancement, consider using the custom DNS
+	// resolver, psiphon/common/resolver.Resolver, combined with the existing
+	// DNS server fetcher, SupportServices.DNSResolver. This resolver
+	// includes a cache which will respect the true TTL values in DNS
+	// responses; and randomly distributes load over the available DNS
+	// servers. Note the current limitations documented in
+	// Resolver.ResolveIP, which must be addressed.
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	if sshClient.dnsResolver != nil {
+		return sshClient.dnsResolver, sshClient.dnsCache
+	}
+
+	// PreferGo, equivalent to GODEBUG=netdns=go, is specified in order to
+	// avoid any cases where Go's resolver fails over to the cgo-based
+	// resolver (see https://pkg.go.dev/net#hdr-Name_Resolution). Such
+	// cases, if they resolve at all, may be expected to resolve to bogon
+	// IPs that won't be permitted; but the cgo invocation will consume
+	// an OS thread, which is a performance hit we can avoid.
+
+	sshClient.dnsResolver = &net.Resolver{PreferGo: true}
+
+	// Get the server DNS resolver cache parameters from tactics. In the case
+	// of an error, no tactics, or zero values no cache is initialized and
+	// getDNSResolver initializes only the resolver and returns a nil cache.
+	//
+	// Limitations:
+	// - assumes no GeoIP targeting for server DNS resolver cache parameters
+	// - an individual client's cache is not reconfigured on tactics reloads
+
+	p, err := sshClient.sshServer.support.ServerTacticsParametersCache.Get(NewGeoIPData())
+	if err != nil {
+		log.WithTraceFields(LogFields{"error": err}).Warning("get tactics failed")
+		return sshClient.dnsResolver, nil
+	}
+	if p.IsNil() {
+		return sshClient.dnsResolver, nil
+	}
+
+	maxSize := p.Int(parameters.ServerDNSResolverCacheMaxSize)
+	TTL := p.Duration(parameters.ServerDNSResolverCacheTTL)
+
+	if maxSize == 0 || TTL == 0 {
+		return sshClient.dnsResolver, nil
+	}
+
+	sshClient.dnsCache = lrucache.NewWithLRU(TTL, 1*time.Minute, maxSize)
+
+	return sshClient.dnsResolver, sshClient.dnsCache
+}