Просмотр исходного кода

Decouple DiscoveryValue from GeoIPData

- Simplifies code and fixes at least one
  GeoIPData struct comparison operation in
  setHandshakeState.
Rod Hynes 4 лет назад
Родитель
Сommit
339d72991c

+ 5 - 7
psiphon/common/api.go

@@ -30,13 +30,11 @@ type APIParameterValidator func(APIParameters) error
 
 // GeoIPData is type-compatible with psiphon/server.GeoIPData.
 type GeoIPData struct {
-	Country           string
-	City              string
-	ISP               string
-	ASN               string
-	ASO               string
-	HasDiscoveryValue bool
-	DiscoveryValue    int
+	Country string
+	City    string
+	ISP     string
+	ASN     string
+	ASO     string
 }
 
 // APIParameterLogFieldFormatter is a function that returns formatted

+ 30 - 3
psiphon/server/api.go

@@ -20,6 +20,8 @@
 package server
 
 import (
+	"crypto/hmac"
+	"crypto/sha256"
 	"crypto/subtle"
 	"encoding/base64"
 	"encoding/json"
@@ -301,10 +303,20 @@ func handshakeAPIRequestHandler(
 
 	pad_response, _ := getPaddingSizeRequestParam(params, "pad_response")
 
-	if !geoIPData.HasDiscoveryValue {
-		return nil, errors.TraceNew("unexpected missing discovery value")
+	// Discover new servers
+
+	host, _, err := net.SplitHostPort(clientAddr)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	clientIP := net.ParseIP(host)
+	if clientIP == nil {
+		return nil, errors.TraceNew("missing client IP")
 	}
-	encodedServerList := db.DiscoverServers(geoIPData.DiscoveryValue)
+
+	encodedServerList := db.DiscoverServers(
+		calculateDiscoveryValue(support.Config.DiscoveryValueHMACKey, clientIP))
 
 	// When the client indicates that it used an unsigned server entry for this
 	// connection, return a signed copy of the server entry for the client to
@@ -353,6 +365,21 @@ func handshakeAPIRequestHandler(
 	return responsePayload, nil
 }
 
+// calculateDiscoveryValue derives a value from the client IP address to be
+// used as input in the server discovery algorithm.
+// See https://github.com/Psiphon-Inc/psiphon-automation/tree/master/Automation/psi_ops_discovery.py
+// for full details.
+func calculateDiscoveryValue(discoveryValueHMACKey string, ipAddress net.IP) int {
+	// From: psi_ops_discovery.calculate_ip_address_strategy_value:
+	//     # Mix bits from all octets of the client IP address to determine the
+	//     # bucket. An HMAC is used to prevent pre-calculation of buckets for IPs.
+	//     return ord(hmac.new(HMAC_KEY, ip_address, hashlib.sha256).digest()[0])
+	// TODO: use 3-octet algorithm?
+	hash := hmac.New(sha256.New, []byte(discoveryValueHMACKey))
+	hash.Write([]byte(ipAddress.String()))
+	return int(hash.Sum(nil)[0])
+}
+
 // uniqueUserParams are the connected request parameters which are logged for
 // unique_user events.
 var uniqueUserParams = append(

+ 14 - 58
psiphon/server/geoip.go

@@ -20,8 +20,6 @@
 package server
 
 import (
-	"crypto/hmac"
-	"crypto/sha256"
 	"fmt"
 	"io"
 	"net"
@@ -45,17 +43,13 @@ const (
 // GeoIPData is GeoIP data for a client session. Individual client
 // IP addresses are neither logged nor explicitly referenced during a session.
 // The GeoIP country, city, and ISP corresponding to a client IP address are
-// resolved and then logged along with usage stats. The DiscoveryValue is
-// a special value derived from the client IP that's used to compartmentalize
-// discoverable servers (see calculateDiscoveryValue for details).
+// resolved and then logged along with usage stats.
 type GeoIPData struct {
-	Country           string
-	City              string
-	ISP               string
-	ASN               string
-	ASO               string
-	HasDiscoveryValue bool
-	DiscoveryValue    int
+	Country string
+	City    string
+	ISP     string
+	ASN     string
+	ASO     string
 }
 
 // NewGeoIPData returns a GeoIPData initialized with the expected
@@ -93,9 +87,8 @@ func (g GeoIPData) SetLogFieldsWithPrefix(prefix string, logFields LogFields) {
 // supports hot reloading of MaxMind data while the server is
 // running.
 type GeoIPService struct {
-	databases             []*geoIPDatabase
-	sessionCache          *cache.Cache
-	discoveryValueHMACKey string
+	databases    []*geoIPDatabase
+	sessionCache *cache.Cache
 }
 
 type geoIPDatabase struct {
@@ -107,14 +100,11 @@ type geoIPDatabase struct {
 }
 
 // NewGeoIPService initializes a new GeoIPService.
-func NewGeoIPService(
-	databaseFilenames []string,
-	discoveryValueHMACKey string) (*GeoIPService, error) {
+func NewGeoIPService(databaseFilenames []string) (*GeoIPService, error) {
 
 	geoIP := &GeoIPService{
-		databases:             make([]*geoIPDatabase, len(databaseFilenames)),
-		sessionCache:          cache.New(GEOIP_SESSION_CACHE_TTL, 1*time.Minute),
-		discoveryValueHMACKey: discoveryValueHMACKey,
+		databases:    make([]*geoIPDatabase, len(databaseFilenames)),
+		sessionCache: cache.New(GEOIP_SESSION_CACHE_TTL, 1*time.Minute),
 	}
 
 	for i, filename := range databaseFilenames {
@@ -203,21 +193,12 @@ func (geoIP *GeoIPService) Reloaders() []common.Reloader {
 }
 
 // Lookup determines a GeoIPData for a given string client IP address.
-//
-// When addDiscoveryValue is true, GeoIPData.DiscoveryValue is calculated and
-// GeoIPData.HasDiscoveryValue is true.
-func (geoIP *GeoIPService) Lookup(
-	strIP string, addDiscoveryValue bool) GeoIPData {
-
-	return geoIP.LookupIP(net.ParseIP(strIP), addDiscoveryValue)
+func (geoIP *GeoIPService) Lookup(strIP string) GeoIPData {
+	return geoIP.LookupIP(net.ParseIP(strIP))
 }
 
 // LookupIP determines a GeoIPData for a given client IP address.
-//
-// When addDiscoveryValue is true, GeoIPData.DiscoveryValue is calculated and
-// GeoIPData.HasDiscoveryValue is true.
-func (geoIP *GeoIPService) LookupIP(
-	IP net.IP, addDiscoveryValue bool) GeoIPData {
+func (geoIP *GeoIPService) LookupIP(IP net.IP) GeoIPData {
 
 	result := NewGeoIPData()
 
@@ -274,14 +255,6 @@ func (geoIP *GeoIPService) LookupIP(
 		result.ASO = geoIPFields.ASO
 	}
 
-	// Populate DiscoveryValue fields (even when there's no GeoIP database).
-
-	if addDiscoveryValue {
-		result.HasDiscoveryValue = true
-		result.DiscoveryValue = calculateDiscoveryValue(
-			geoIP.discoveryValueHMACKey, IP)
-	}
-
 	return result
 }
 
@@ -325,20 +298,3 @@ func (geoIP *GeoIPService) InSessionCache(sessionID string) bool {
 	_, found := geoIP.sessionCache.Get(sessionID)
 	return found
 }
-
-// calculateDiscoveryValue derives a value from the client IP address to be
-// used as input in the server discovery algorithm. Since we do not explicitly
-// store the client IP address, we must derive the value here and store it for
-// later use by the discovery algorithm.
-// See https://github.com/Psiphon-Inc/psiphon-automation/tree/master/Automation/psi_ops_discovery.py
-// for full details.
-func calculateDiscoveryValue(discoveryValueHMACKey string, ipAddress net.IP) int {
-	// From: psi_ops_discovery.calculate_ip_address_strategy_value:
-	//     # Mix bits from all octets of the client IP address to determine the
-	//     # bucket. An HMAC is used to prevent pre-calculation of buckets for IPs.
-	//     return ord(hmac.new(HMAC_KEY, ip_address, hashlib.sha256).digest()[0])
-	// TODO: use 3-octet algorithm?
-	hash := hmac.New(sha256.New, []byte(discoveryValueHMACKey))
-	hash.Write([]byte(ipAddress.String()))
-	return int(hash.Sum(nil)[0])
-}

+ 3 - 3
psiphon/server/meek.go

@@ -359,7 +359,7 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 		// Endpoint mode. Currently, this means it's handled by the tactics
 		// request handler.
 
-		geoIPData := server.support.GeoIPService.Lookup(clientIP, false)
+		geoIPData := server.support.GeoIPService.Lookup(clientIP)
 		handled := server.support.TacticsServer.HandleEndPoint(
 			endPoint, common.GeoIPData(geoIPData), responseWriter, request)
 		if !handled {
@@ -621,7 +621,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 				proxyClientIP := strings.Split(value, ",")[0]
 				if net.ParseIP(proxyClientIP) != nil &&
 					server.support.GeoIPService.Lookup(
-						proxyClientIP, false).Country != GEOIP_UNKNOWN_VALUE {
+						proxyClientIP).Country != GEOIP_UNKNOWN_VALUE {
 
 					clientIP = proxyClientIP
 					break
@@ -741,7 +741,7 @@ func (server *MeekServer) rateLimit(clientIP string) bool {
 	if len(regions) > 0 || len(ISPs) > 0 || len(cities) > 0 {
 
 		// TODO: avoid redundant GeoIP lookups?
-		geoIPData := server.support.GeoIPService.Lookup(clientIP, false)
+		geoIPData := server.support.GeoIPService.Lookup(clientIP)
 
 		if len(regions) > 0 {
 			if !common.Contains(regions, geoIPData.Country) {

+ 1 - 1
psiphon/server/packetman.go

@@ -141,7 +141,7 @@ func selectPacketManipulationSpec(
 			"packet manipulation protocol port not found: %d", protocolPort)
 	}
 
-	geoIPData := support.GeoIPService.LookupIP(clientIP, false)
+	geoIPData := support.GeoIPService.LookupIP(clientIP)
 
 	specName, doReplay := support.ReplayCache.GetReplayPacketManipulation(
 		targetTunnelProtocol, geoIPData)

+ 2 - 3
psiphon/server/services.go

@@ -415,7 +415,7 @@ func logIrregularTunnel(
 	logFields["event_name"] = "irregular_tunnel"
 	logFields["listener_protocol"] = listenerTunnelProtocol
 	logFields["listener_port_number"] = listenerPort
-	support.GeoIPService.Lookup(clientIP, false).SetLogFields(logFields)
+	support.GeoIPService.Lookup(clientIP).SetLogFields(logFields)
 	logFields["tunnel_error"] = tunnelError.Error()
 	log.LogRawFieldsWithTimestamp(logFields)
 }
@@ -459,8 +459,7 @@ func NewSupportServices(config *Config) (*SupportServices, error) {
 		return nil, errors.Trace(err)
 	}
 
-	geoIPService, err := NewGeoIPService(
-		config.GeoIPDatabaseFilenames, config.DiscoveryValueHMACKey)
+	geoIPService, err := NewGeoIPService(config.GeoIPDatabaseFilenames)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}

+ 3 - 3
psiphon/server/tunnelServer.go

@@ -199,7 +199,7 @@ func (server *TunnelServer) Run() error {
 			support,
 			listener,
 			tunnelProtocol,
-			func(IP string) GeoIPData { return support.GeoIPService.Lookup(IP, false) })
+			func(IP string) GeoIPData { return support.GeoIPService.Lookup(IP) })
 
 		log.WithTraceFields(
 			LogFields{
@@ -1159,7 +1159,7 @@ func (sshServer *sshServer) handleClient(
 	}
 
 	geoIPData := sshServer.support.GeoIPService.Lookup(
-		common.IPAddressFromAddr(clientAddr), true)
+		common.IPAddressFromAddr(clientAddr))
 
 	sshServer.registerAcceptedClient(tunnelProtocol, geoIPData.Country)
 	defer sshServer.unregisterAcceptedClient(tunnelProtocol, geoIPData.Country)
@@ -3799,7 +3799,7 @@ func (sshClient *sshClient) handleTCPChannel(
 
 	if doSplitTunnel {
 
-		destinationGeoIPData := sshClient.sshServer.support.GeoIPService.LookupIP(IP, false)
+		destinationGeoIPData := sshClient.sshServer.support.GeoIPService.LookupIP(IP)
 
 		if destinationGeoIPData.Country == sshClient.geoIPData.Country &&
 			sshClient.geoIPData.Country != GEOIP_UNKNOWN_VALUE {