Просмотр исходного кода

Add option to disable discovery

Rod Hynes 4 лет назад
Родитель
Сommit
9efdd08348
3 измененных файлов с 58 добавлено и 7 удалено
  1. 17 7
      psiphon/server/api.go
  2. 12 0
      psiphon/server/trafficRules.go
  3. 29 0
      psiphon/server/tunnelServer.go

+ 17 - 7
psiphon/server/api.go

@@ -323,18 +323,28 @@ func handshakeAPIRequestHandler(
 
 	// Discover new servers
 
-	host, _, err := net.SplitHostPort(clientAddr)
+	disableDiscovery, err := support.TunnelServer.GetClientDisableDiscovery(sessionID)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
 
-	clientIP := net.ParseIP(host)
-	if clientIP == nil {
-		return nil, errors.TraceNew("missing client IP")
-	}
+	var encodedServerList []string
+
+	if !disableDiscovery {
 
-	encodedServerList := db.DiscoverServers(
-		calculateDiscoveryValue(support.Config.DiscoveryValueHMACKey, clientIP))
+		host, _, err := net.SplitHostPort(clientAddr)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		clientIP := net.ParseIP(host)
+		if clientIP == nil {
+			return nil, errors.TraceNew("missing client IP")
+		}
+
+		encodedServerList = db.DiscoverServers(
+			calculateDiscoveryValue(support.Config.DiscoveryValueHMACKey, clientIP))
+	}
 
 	// When the client indicates that it used an unsigned server entry for this
 	// connection, return a signed copy of the server entry for the client to

+ 12 - 0
psiphon/server/trafficRules.go

@@ -273,6 +273,10 @@ type TrafficRules struct {
 	// client sends an IP address. Domain names are not resolved before checking
 	// AllowSubnets.
 	AllowSubnets []string
+
+	// DisableDiscovery specifies whether to disable server entry discovery,
+	// to manage load on discovery servers.
+	DisableDiscovery *bool
 }
 
 // RateLimits is a clone of common.RateLimits with pointers
@@ -589,6 +593,10 @@ func (set *TrafficRulesSet) GetTrafficRules(
 		trafficRules.AllowSubnets = make([]string, 0)
 	}
 
+	if trafficRules.DisableDiscovery == nil {
+		trafficRules.DisableDiscovery = new(bool)
+	}
+
 	// TODO: faster lookup?
 	for _, filteredRules := range set.FilteredRules {
 
@@ -795,6 +803,10 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			trafficRules.AllowSubnets = filteredRules.Rules.AllowSubnets
 		}
 
+		if filteredRules.Rules.DisableDiscovery != nil {
+			trafficRules.DisableDiscovery = filteredRules.Rules.DisableDiscovery
+		}
+
 		break
 	}
 

+ 29 - 0
psiphon/server/tunnelServer.go

@@ -319,6 +319,14 @@ func (server *TunnelServer) GetClientHandshaked(
 	return server.sshServer.getClientHandshaked(sessionID)
 }
 
+// GetClientDisableDiscovery indicates whether discovery is disabled for the
+// client corresponding to sessionID.
+func (server *TunnelServer) GetClientDisableDiscovery(
+	sessionID string) (bool, error) {
+
+	return server.sshServer.getClientDisableDiscovery(sessionID)
+}
+
 // UpdateClientAPIParameters updates the recorded handshake API parameters for
 // the client corresponding to sessionID.
 func (server *TunnelServer) UpdateClientAPIParameters(
@@ -1130,6 +1138,20 @@ func (sshServer *sshServer) getClientHandshaked(
 	return completed, exhausted, nil
 }
 
+func (sshServer *sshServer) getClientDisableDiscovery(
+	sessionID string) (bool, error) {
+
+	sshServer.clientsMutex.Lock()
+	client := sshServer.clients[sessionID]
+	sshServer.clientsMutex.Unlock()
+
+	if client == nil {
+		return false, errors.TraceNew("unknown session ID")
+	}
+
+	return client.getDisableDiscovery(), nil
+}
+
 func (sshServer *sshServer) updateClientAPIParameters(
 	sessionID string,
 	apiParams common.APIParameters) error {
@@ -3320,6 +3342,13 @@ func (sshClient *sshClient) getHandshaked() (bool, bool) {
 	return completed, exhausted
 }
 
+func (sshClient *sshClient) getDisableDiscovery() bool {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	return *sshClient.trafficRules.DisableDiscovery
+}
+
 func (sshClient *sshClient) updateAPIParameters(
 	apiParams common.APIParameters) {