Просмотр исходного кода

Do not calculate discovery value for every GeoIP lookup

Rod Hynes 5 лет назад
Родитель
Сommit
fddbc5221d

+ 7 - 6
psiphon/common/api.go

@@ -30,12 +30,13 @@ type APIParameterValidator func(APIParameters) error
 
 
 // GeoIPData is type-compatible with psiphon/server.GeoIPData.
 // GeoIPData is type-compatible with psiphon/server.GeoIPData.
 type GeoIPData struct {
 type GeoIPData struct {
-	Country        string
-	City           string
-	ISP            string
-	ASN            string
-	ASO            string
-	DiscoveryValue int
+	Country           string
+	City              string
+	ISP               string
+	ASN               string
+	ASO               string
+	HasDiscoveryValue bool
+	DiscoveryValue    int
 }
 }
 
 
 // APIParameterLogFieldFormatter is a function that returns formatted
 // APIParameterLogFieldFormatter is a function that returns formatted

+ 3 - 0
psiphon/server/api.go

@@ -305,6 +305,9 @@ func handshakeAPIRequestHandler(
 
 
 	pad_response, _ := getPaddingSizeRequestParam(params, "pad_response")
 	pad_response, _ := getPaddingSizeRequestParam(params, "pad_response")
 
 
+	if !geoIPData.HasDiscoveryValue {
+		return nil, errors.TraceNew("unexpected missing discovery value")
+	}
 	encodedServerList := db.DiscoverServers(geoIPData.DiscoveryValue)
 	encodedServerList := db.DiscoverServers(geoIPData.DiscoveryValue)
 
 
 	// When the client indicates that it used an unsigned server entry for this
 	// When the client indicates that it used an unsigned server entry for this

+ 34 - 25
psiphon/server/geoip.go

@@ -49,12 +49,13 @@ const (
 // a special value derived from the client IP that's used to compartmentalize
 // a special value derived from the client IP that's used to compartmentalize
 // discoverable servers (see calculateDiscoveryValue for details).
 // discoverable servers (see calculateDiscoveryValue for details).
 type GeoIPData struct {
 type GeoIPData struct {
-	Country        string
-	City           string
-	ISP            string
-	ASN            string
-	ASO            string
-	DiscoveryValue int
+	Country           string
+	City              string
+	ISP               string
+	ASN               string
+	ASO               string
+	HasDiscoveryValue bool
+	DiscoveryValue    int
 }
 }
 
 
 // NewGeoIPData returns a GeoIPData initialized with the expected
 // NewGeoIPData returns a GeoIPData initialized with the expected
@@ -201,31 +202,31 @@ func (geoIP *GeoIPService) Reloaders() []common.Reloader {
 	return reloaders
 	return reloaders
 }
 }
 
 
-// Lookup determines a GeoIPData for a given string client IP address. Lookup
-// populates the GeoIPData.DiscoveryValue field.
-func (geoIP *GeoIPService) Lookup(strIP string) GeoIPData {
-	IP := net.ParseIP(strIP)
-	if IP == nil {
-		return NewGeoIPData()
-	}
+// Lookup determines a GeoIPData for a given string client IP address.
+//
+// When addDiscoveryValue is true, GeoIPData.DiscoveryValue is calculated and
+// GeoIPData.HasDiscoveryValue is true.
+func (geoIP *GeoIPService) Lookup(
+	strIP string, addDiscoveryValue bool) GeoIPData {
 
 
-	result := geoIP.LookupIP(IP)
-
-	result.DiscoveryValue = calculateDiscoveryValue(
-		geoIP.discoveryValueHMACKey, strIP)
-
-	return result
+	return geoIP.LookupIP(net.ParseIP(strIP), addDiscoveryValue)
 }
 }
 
 
-// LookupIP determines a GeoIPData for a given client IP address. LookupIP
-// omits the GeoIPData.DiscoveryValue field.
-func (geoIP *GeoIPService) LookupIP(IP net.IP) GeoIPData {
+// LookupIP determines a GeoIPData for a given client IP address.
+//
+// When addDiscoveryValue is true, GeoIPData.DiscoveryValue is calculated and
+// GeoIPData.HasDiscoveryValue is true.
+func (geoIP *GeoIPService) LookupIP(
+	IP net.IP, addDiscoveryValue bool) GeoIPData {
+
 	result := NewGeoIPData()
 	result := NewGeoIPData()
 
 
-	if len(geoIP.databases) == 0 {
+	if IP == nil {
 		return result
 		return result
 	}
 	}
 
 
+	// Populate GeoIP fields.
+
 	var geoIPFields struct {
 	var geoIPFields struct {
 		Country struct {
 		Country struct {
 			ISOCode string `maxminddb:"iso_code"`
 			ISOCode string `maxminddb:"iso_code"`
@@ -273,6 +274,14 @@ func (geoIP *GeoIPService) LookupIP(IP net.IP) GeoIPData {
 		result.ASO = geoIPFields.ASO
 		result.ASO = geoIPFields.ASO
 	}
 	}
 
 
+	// Populate DiscoveryValue fields (even when there's no GeoIP database).
+
+	if addDiscoveryValue {
+		result.HasDiscoveryValue = true
+		result.DiscoveryValue = calculateDiscoveryValue(
+			geoIP.discoveryValueHMACKey, IP)
+	}
+
 	return result
 	return result
 }
 }
 
 
@@ -323,13 +332,13 @@ func (geoIP *GeoIPService) InSessionCache(sessionID string) bool {
 // later use by the discovery algorithm.
 // later use by the discovery algorithm.
 // See https://github.com/Psiphon-Inc/psiphon-automation/tree/master/Automation/psi_ops_discovery.py
 // See https://github.com/Psiphon-Inc/psiphon-automation/tree/master/Automation/psi_ops_discovery.py
 // for full details.
 // for full details.
-func calculateDiscoveryValue(discoveryValueHMACKey, ipAddress string) int {
+func calculateDiscoveryValue(discoveryValueHMACKey string, ipAddress net.IP) int {
 	// From: psi_ops_discovery.calculate_ip_address_strategy_value:
 	// From: psi_ops_discovery.calculate_ip_address_strategy_value:
 	//     # Mix bits from all octets of the client IP address to determine the
 	//     # Mix bits from all octets of the client IP address to determine the
 	//     # bucket. An HMAC is used to prevent pre-calculation of buckets for IPs.
 	//     # bucket. An HMAC is used to prevent pre-calculation of buckets for IPs.
 	//     return ord(hmac.new(HMAC_KEY, ip_address, hashlib.sha256).digest()[0])
 	//     return ord(hmac.new(HMAC_KEY, ip_address, hashlib.sha256).digest()[0])
 	// TODO: use 3-octet algorithm?
 	// TODO: use 3-octet algorithm?
 	hash := hmac.New(sha256.New, []byte(discoveryValueHMACKey))
 	hash := hmac.New(sha256.New, []byte(discoveryValueHMACKey))
-	hash.Write([]byte(ipAddress))
+	hash.Write([]byte(ipAddress.String()))
 	return int(hash.Sum(nil)[0])
 	return int(hash.Sum(nil)[0])
 }
 }

+ 4 - 3
psiphon/server/meek.go

@@ -354,7 +354,7 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 		// Endpoint mode. Currently, this means it's handled by the tactics
 		// Endpoint mode. Currently, this means it's handled by the tactics
 		// request handler.
 		// request handler.
 
 
-		geoIPData := server.support.GeoIPService.Lookup(clientIP)
+		geoIPData := server.support.GeoIPService.Lookup(clientIP, false)
 		handled := server.support.TacticsServer.HandleEndPoint(
 		handled := server.support.TacticsServer.HandleEndPoint(
 			endPoint, common.GeoIPData(geoIPData), responseWriter, request)
 			endPoint, common.GeoIPData(geoIPData), responseWriter, request)
 		if !handled {
 		if !handled {
@@ -601,7 +601,8 @@ func (server *MeekServer) getSessionOrEndpoint(
 				// the client IP.
 				// the client IP.
 				proxyClientIP := strings.Split(value, ",")[0]
 				proxyClientIP := strings.Split(value, ",")[0]
 				if net.ParseIP(proxyClientIP) != nil &&
 				if net.ParseIP(proxyClientIP) != nil &&
-					server.support.GeoIPService.Lookup(proxyClientIP).Country != GEOIP_UNKNOWN_VALUE {
+					server.support.GeoIPService.Lookup(
+						proxyClientIP, false).Country != GEOIP_UNKNOWN_VALUE {
 
 
 					clientIP = proxyClientIP
 					clientIP = proxyClientIP
 					break
 					break
@@ -723,7 +724,7 @@ func (server *MeekServer) rateLimit(clientIP string) bool {
 	if len(regions) > 0 || len(ISPs) > 0 || len(cities) > 0 {
 	if len(regions) > 0 || len(ISPs) > 0 || len(cities) > 0 {
 
 
 		// TODO: avoid redundant GeoIP lookups?
 		// TODO: avoid redundant GeoIP lookups?
-		geoIPData := server.support.GeoIPService.Lookup(clientIP)
+		geoIPData := server.support.GeoIPService.Lookup(clientIP, false)
 
 
 		if len(regions) > 0 {
 		if len(regions) > 0 {
 			if !common.Contains(regions, geoIPData.Country) {
 			if !common.Contains(regions, geoIPData.Country) {

+ 1 - 1
psiphon/server/packetman.go

@@ -141,7 +141,7 @@ func selectPacketManipulationSpec(
 			"packet manipulation protocol port not found: %d", protocolPort)
 			"packet manipulation protocol port not found: %d", protocolPort)
 	}
 	}
 
 
-	geoIPData := support.GeoIPService.Lookup(clientIP.String())
+	geoIPData := support.GeoIPService.LookupIP(clientIP, false)
 
 
 	specName, doReplay := support.ReplayCache.GetReplayPacketManipulation(
 	specName, doReplay := support.ReplayCache.GetReplayPacketManipulation(
 		targetTunnelProtocol, geoIPData)
 		targetTunnelProtocol, geoIPData)

+ 1 - 1
psiphon/server/services.go

@@ -415,7 +415,7 @@ func logIrregularTunnel(
 	logFields["event_name"] = "irregular_tunnel"
 	logFields["event_name"] = "irregular_tunnel"
 	logFields["listener_protocol"] = listenerTunnelProtocol
 	logFields["listener_protocol"] = listenerTunnelProtocol
 	logFields["listener_port_number"] = listenerPort
 	logFields["listener_port_number"] = listenerPort
-	support.GeoIPService.Lookup(clientIP).SetLogFields(logFields)
+	support.GeoIPService.Lookup(clientIP, false).SetLogFields(logFields)
 	logFields["tunnel_error"] = tunnelError.Error()
 	logFields["tunnel_error"] = tunnelError.Error()
 	log.LogRawFieldsWithTimestamp(logFields)
 	log.LogRawFieldsWithTimestamp(logFields)
 }
 }

+ 3 - 3
psiphon/server/tunnelServer.go

@@ -199,7 +199,7 @@ func (server *TunnelServer) Run() error {
 			support,
 			support,
 			listener,
 			listener,
 			tunnelProtocol,
 			tunnelProtocol,
-			func(IP string) GeoIPData { return support.GeoIPService.Lookup(IP) })
+			func(IP string) GeoIPData { return support.GeoIPService.Lookup(IP, false) })
 
 
 		log.WithTraceFields(
 		log.WithTraceFields(
 			LogFields{
 			LogFields{
@@ -1102,7 +1102,7 @@ func (sshServer *sshServer) handleClient(
 	}
 	}
 
 
 	geoIPData := sshServer.support.GeoIPService.Lookup(
 	geoIPData := sshServer.support.GeoIPService.Lookup(
-		common.IPAddressFromAddr(clientAddr))
+		common.IPAddressFromAddr(clientAddr), true)
 
 
 	sshServer.registerAcceptedClient(tunnelProtocol, geoIPData.Country)
 	sshServer.registerAcceptedClient(tunnelProtocol, geoIPData.Country)
 	defer sshServer.unregisterAcceptedClient(tunnelProtocol, geoIPData.Country)
 	defer sshServer.unregisterAcceptedClient(tunnelProtocol, geoIPData.Country)
@@ -3603,7 +3603,7 @@ func (sshClient *sshClient) handleTCPChannel(
 
 
 	if doSplitTunnel {
 	if doSplitTunnel {
 
 
-		destinationGeoIPData := sshClient.sshServer.support.GeoIPService.LookupIP(IP)
+		destinationGeoIPData := sshClient.sshServer.support.GeoIPService.LookupIP(IP, false)
 
 
 		if destinationGeoIPData.Country == sshClient.geoIPData.Country &&
 		if destinationGeoIPData.Country == sshClient.geoIPData.Country &&
 			sshClient.geoIPData.Country != GEOIP_UNKNOWN_VALUE {
 			sshClient.geoIPData.Country != GEOIP_UNKNOWN_VALUE {