mirokuratczyk 1 год назад
Родитель
Сommit
b6105581bb
4 измененных файлов с 51 добавлено и 51 удалено
  1. 26 28
      psiphon/common/osl/osl.go
  2. 15 15
      psiphon/common/osl/osl_test.go
  3. 2 1
      psiphon/server/geoip.go
  4. 8 7
      psiphon/server/tunnelServer.go

+ 26 - 28
psiphon/common/osl/osl.go

@@ -103,7 +103,7 @@ type Scheme struct {
 	// SeedSpecs is the set of different client network activity patterns
 	// SeedSpecs is the set of different client network activity patterns
 	// that will result in issuing SLOKs. For a given time period, a distinct
 	// that will result in issuing SLOKs. For a given time period, a distinct
 	// SLOK is issued for each SeedSpec.
 	// SLOK is issued for each SeedSpec.
-	// Duplicate subnets and asns may appear in multiple SeedSpecs.
+	// Duplicate subnets and ASNs may appear in multiple SeedSpecs.
 	SeedSpecs []*SeedSpec
 	SeedSpecs []*SeedSpec
 
 
 	// SeedSpecThreshold is the threshold scheme for combining SLOKs to
 	// SeedSpecThreshold is the threshold scheme for combining SLOKs to
@@ -149,7 +149,6 @@ type Scheme struct {
 
 
 	epoch                 time.Time
 	epoch                 time.Time
 	subnetLookups         []common.SubnetLookup
 	subnetLookups         []common.SubnetLookup
-	asnLookups            [][]string
 	derivedSLOKCacheMutex sync.RWMutex
 	derivedSLOKCacheMutex sync.RWMutex
 	derivedSLOKCache      map[slokReference]*SLOK
 	derivedSLOKCache      map[slokReference]*SLOK
 }
 }
@@ -200,7 +199,6 @@ type ClientSeedState struct {
 	signalIssueSLOKs     chan struct{}
 	signalIssueSLOKs     chan struct{}
 	issuedSLOKs          map[string]*SLOK
 	issuedSLOKs          map[string]*SLOK
 	payloadSLOKs         []*SLOK
 	payloadSLOKs         []*SLOK
-	lookupASN            func(net.IP) string
 }
 }
 
 
 // ClientSeedProgress tracks client progress towards seeding SLOKs for
 // ClientSeedProgress tracks client progress towards seeding SLOKs for
@@ -216,7 +214,7 @@ type ClientSeedProgress struct {
 
 
 // ClientSeedPortForward map a client port forward, which is relaying
 // ClientSeedPortForward map a client port forward, which is relaying
 // traffic to a specific upstream address, to all seed state progress
 // traffic to a specific upstream address, to all seed state progress
-// counters for SeedSpecs with subnets and asns containing the upstream address.
+// counters for SeedSpecs with subnets and ASNs containing the upstream address.
 // As traffic is relayed through the port forwards, the bytes transferred
 // As traffic is relayed through the port forwards, the bytes transferred
 // and duration count towards the progress of these SeedSpecs and
 // and duration count towards the progress of these SeedSpecs and
 // associated SLOKs.
 // associated SLOKs.
@@ -322,7 +320,6 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 
 
 		scheme.epoch = epoch
 		scheme.epoch = epoch
 		scheme.subnetLookups = make([]common.SubnetLookup, len(scheme.SeedSpecs))
 		scheme.subnetLookups = make([]common.SubnetLookup, len(scheme.SeedSpecs))
-		scheme.asnLookups = make([][]string, len(scheme.SeedSpecs))
 		scheme.derivedSLOKCache = make(map[slokReference]*SLOK)
 		scheme.derivedSLOKCache = make(map[slokReference]*SLOK)
 
 
 		if len(scheme.MasterKey) != KEY_LENGTH_BYTES {
 		if len(scheme.MasterKey) != KEY_LENGTH_BYTES {
@@ -348,16 +345,14 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 			scheme.subnetLookups[index] = subnetLookup
 			scheme.subnetLookups[index] = subnetLookup
 
 
 			// Ensure there are no duplicates.
 			// Ensure there are no duplicates.
-			asns := make(map[string]struct{}, len(seedSpec.UpstreamASNs))
-			for _, asn := range seedSpec.UpstreamASNs {
-				if _, ok := asns[asn]; ok {
-					return nil, errors.Tracef("invalid upstream asns, duplicate asn: %s", asn)
+			ASNs := make(map[string]struct{}, len(seedSpec.UpstreamASNs))
+			for _, ASN := range seedSpec.UpstreamASNs {
+				if _, ok := ASNs[ASN]; ok {
+					return nil, errors.Tracef("invalid upstream ASNs, duplicate ASN: %s", ASN)
 				} else {
 				} else {
-					asns[asn] = struct{}{}
+					ASNs[ASN] = struct{}{}
 				}
 				}
 			}
 			}
-
-			scheme.asnLookups[index] = seedSpec.UpstreamASNs
 		}
 		}
 
 
 		if !isValidShamirSplit(len(scheme.SeedSpecs), scheme.SeedSpecThreshold) {
 		if !isValidShamirSplit(len(scheme.SeedSpecs), scheme.SeedSpecThreshold) {
@@ -390,8 +385,7 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 // should be appropriately buffered.
 // should be appropriately buffered.
 func (config *Config) NewClientSeedState(
 func (config *Config) NewClientSeedState(
 	clientRegion, propagationChannelID string,
 	clientRegion, propagationChannelID string,
-	signalIssueSLOKs chan struct{},
-	lookupASN func(net.IP) string) *ClientSeedState {
+	signalIssueSLOKs chan struct{}) *ClientSeedState {
 
 
 	config.ReloadableFile.RLock()
 	config.ReloadableFile.RLock()
 	defer config.ReloadableFile.RUnlock()
 	defer config.ReloadableFile.RUnlock()
@@ -401,7 +395,6 @@ func (config *Config) NewClientSeedState(
 		signalIssueSLOKs:     signalIssueSLOKs,
 		signalIssueSLOKs:     signalIssueSLOKs,
 		issuedSLOKs:          make(map[string]*SLOK),
 		issuedSLOKs:          make(map[string]*SLOK),
 		payloadSLOKs:         nil,
 		payloadSLOKs:         nil,
-		lookupASN:            lookupASN,
 	}
 	}
 
 
 	for _, scheme := range config.Schemes {
 	for _, scheme := range config.Schemes {
@@ -468,13 +461,14 @@ func (state *ClientSeedState) Resume(
 // NewClientSeedPortForward creates a new client port forward
 // NewClientSeedPortForward creates a new client port forward
 // traffic progress tracker. Port forward progress reported to the
 // traffic progress tracker. Port forward progress reported to the
 // ClientSeedPortForward is added to seed state progress for all
 // ClientSeedPortForward is added to seed state progress for all
-// seed specs containing upstreamIPAddress in their subnets or asns.
+// seed specs containing upstreamIPAddress in their subnets or ASNs.
 // The return value will be nil when activity for upstreamIPAddress
 // The return value will be nil when activity for upstreamIPAddress
 // does not count towards any progress.
 // does not count towards any progress.
 // NewClientSeedPortForward may be invoked concurrently by many
 // NewClientSeedPortForward may be invoked concurrently by many
 // psiphond port forward establishment goroutines.
 // psiphond port forward establishment goroutines.
 func (state *ClientSeedState) NewClientSeedPortForward(
 func (state *ClientSeedState) NewClientSeedPortForward(
-	upstreamIPAddress net.IP) *ClientSeedPortForward {
+	upstreamIPAddress net.IP,
+	lookupASN func(net.IP) string) *ClientSeedPortForward {
 
 
 	// Concurrency: access to ClientSeedState is unsynchronized
 	// Concurrency: access to ClientSeedState is unsynchronized
 	// but references only read-only fields.
 	// but references only read-only fields.
@@ -485,20 +479,21 @@ func (state *ClientSeedState) NewClientSeedPortForward(
 
 
 	var progressReferences []progressReference
 	var progressReferences []progressReference
 
 
-	// Determine which seed spec subnets and asns contain upstreamIPAddress
+	// Determine which seed spec subnets and ASNs contain upstreamIPAddress
 	// and point to the progress for each. When progress is reported,
 	// and point to the progress for each. When progress is reported,
 	// it is added directly to all of these TrafficValues instances.
 	// it is added directly to all of these TrafficValues instances.
 	// Assumes state.seedProgress entries correspond 1-to-1 with
 	// Assumes state.seedProgress entries correspond 1-to-1 with
-	// state.scheme.subnetLookups and state.scheme.asnLookups.
+	// state.scheme.subnetLookups.
 	// Note: this implementation assumes a small number of schemes and
 	// Note: this implementation assumes a small number of schemes and
 	// seed specs. For larger numbers, instead of N SubnetLookups, create
 	// seed specs. For larger numbers, instead of N SubnetLookups, create
 	// a single SubnetLookup which returns, for a given IP address, all
 	// a single SubnetLookup which returns, for a given IP address, all
 	// matching subnets and associated seed specs.
 	// matching subnets and associated seed specs.
 	for seedProgressIndex, seedProgress := range state.seedProgress {
 	for seedProgressIndex, seedProgress := range state.seedProgress {
 
 
-		var upstreamASN *string
+		var upstreamASN string
+		var upstreamASNSet bool
 
 
-		for trafficProgressIndex := range seedProgress.scheme.SeedSpecs {
+		for trafficProgressIndex, seedSpec := range seedProgress.scheme.SeedSpecs {
 
 
 			matchesSeedSpec := false
 			matchesSeedSpec := false
 
 
@@ -507,16 +502,19 @@ func (state *ClientSeedState) NewClientSeedPortForward(
 			subnetLookup := seedProgress.scheme.subnetLookups[trafficProgressIndex]
 			subnetLookup := seedProgress.scheme.subnetLookups[trafficProgressIndex]
 			matchesSeedSpec = subnetLookup.ContainsIPAddress(upstreamIPAddress)
 			matchesSeedSpec = subnetLookup.ContainsIPAddress(upstreamIPAddress)
 
 
-			if !matchesSeedSpec && state.lookupASN != nil {
+			if !matchesSeedSpec && lookupASN != nil {
 				// No subnet match. Check for ASN match.
 				// No subnet match. Check for ASN match.
-				asnLookup := seedProgress.scheme.asnLookups[trafficProgressIndex]
-				if len(asnLookup) > 0 {
+				if len(seedSpec.UpstreamASNs) > 0 {
 					// Lookup ASN on demand and only once.
 					// Lookup ASN on demand and only once.
-					if upstreamASN == nil {
-						upstreamASN = new(string)
-						*upstreamASN = state.lookupASN(upstreamIPAddress)
+					if !upstreamASNSet {
+						upstreamASN = lookupASN(upstreamIPAddress)
+						upstreamASNSet = true
 					}
 					}
-					matchesSeedSpec = common.Contains(asnLookup, *upstreamASN)
+					// TODO: use a map for faster lookups when the number of
+					// string values to compare against exceeds a threshold
+					// where benchmarks show maps are faster than looping
+					// through a string slice.
+					matchesSeedSpec = common.Contains(seedSpec.UpstreamASNs, upstreamASN)
 				}
 				}
 			}
 			}
 
 

+ 15 - 15
psiphon/common/osl/osl_test.go

@@ -179,9 +179,9 @@ func TestOSL(t *testing.T) {
 
 
 	t.Run("ineligible client, sufficient transfer", func(t *testing.T) {
 	t.Run("ineligible client, sufficient transfer", func(t *testing.T) {
 
 
-		clientSeedState := config.NewClientSeedState("US", "C5E8D2EDFD093B50D8D65CF59D0263CA", nil, lookupASN)
+		clientSeedState := config.NewClientSeedState("US", "C5E8D2EDFD093B50D8D65CF59D0263CA", nil)
 
 
-		seedPortForward := clientSeedState.NewClientSeedPortForward(net.ParseIP("192.168.0.1"))
+		seedPortForward := clientSeedState.NewClientSeedPortForward(net.ParseIP("192.168.0.1"), lookupASN)
 
 
 		if seedPortForward != nil {
 		if seedPortForward != nil {
 			t.Fatalf("expected nil client seed port forward")
 			t.Fatalf("expected nil client seed port forward")
@@ -190,7 +190,7 @@ func TestOSL(t *testing.T) {
 
 
 	// This clientSeedState is used across multiple tests.
 	// This clientSeedState is used across multiple tests.
 	signalIssueSLOKs := make(chan struct{}, 1)
 	signalIssueSLOKs := make(chan struct{}, 1)
-	clientSeedState := config.NewClientSeedState("US", "2995DB0C968C59C4F23E87988D9C0D41", signalIssueSLOKs, lookupASN)
+	clientSeedState := config.NewClientSeedState("US", "2995DB0C968C59C4F23E87988D9C0D41", signalIssueSLOKs)
 
 
 	t.Run("eligible client, no transfer", func(t *testing.T) {
 	t.Run("eligible client, no transfer", func(t *testing.T) {
 
 
@@ -201,7 +201,7 @@ func TestOSL(t *testing.T) {
 
 
 	t.Run("eligible client, insufficient transfer", func(t *testing.T) {
 	t.Run("eligible client, insufficient transfer", func(t *testing.T) {
 
 
-		clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1")).UpdateProgress(5, 5, 5)
+		clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1"), lookupASN).UpdateProgress(5, 5, 5)
 
 
 		if len(clientSeedState.GetSeedPayload().SLOKs) != 0 {
 		if len(clientSeedState.GetSeedPayload().SLOKs) != 0 {
 			t.Fatalf("expected 0 SLOKs, got %d", len(clientSeedState.GetSeedPayload().SLOKs))
 			t.Fatalf("expected 0 SLOKs, got %d", len(clientSeedState.GetSeedPayload().SLOKs))
@@ -218,7 +218,7 @@ func TestOSL(t *testing.T) {
 
 
 		rolloverToNextSLOKTime()
 		rolloverToNextSLOKTime()
 
 
-		clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1")).UpdateProgress(5, 5, 5)
+		clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1"), lookupASN).UpdateProgress(5, 5, 5)
 
 
 		if len(clientSeedState.GetSeedPayload().SLOKs) != 0 {
 		if len(clientSeedState.GetSeedPayload().SLOKs) != 0 {
 			t.Fatalf("expected 0 SLOKs, got %d", len(clientSeedState.GetSeedPayload().SLOKs))
 			t.Fatalf("expected 0 SLOKs, got %d", len(clientSeedState.GetSeedPayload().SLOKs))
@@ -229,7 +229,7 @@ func TestOSL(t *testing.T) {
 
 
 		rolloverToNextSLOKTime()
 		rolloverToNextSLOKTime()
 
 
-		clientSeedPortForward := clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1"))
+		clientSeedPortForward := clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1"), lookupASN)
 
 
 		clientSeedPortForward.UpdateProgress(5, 5, 5)
 		clientSeedPortForward.UpdateProgress(5, 5, 5)
 
 
@@ -252,7 +252,7 @@ func TestOSL(t *testing.T) {
 
 
 		*portForwardASN = "0000"
 		*portForwardASN = "0000"
 
 
-		clientSeedPortForward := clientSeedState.NewClientSeedPortForward(net.ParseIP("11.0.0.1"))
+		clientSeedPortForward := clientSeedState.NewClientSeedPortForward(net.ParseIP("11.0.0.1"), lookupASN)
 
 
 		clientSeedPortForward.UpdateProgress(5, 5, 5)
 		clientSeedPortForward.UpdateProgress(5, 5, 5)
 
 
@@ -278,7 +278,7 @@ func TestOSL(t *testing.T) {
 
 
 		*portForwardASN = "0000"
 		*portForwardASN = "0000"
 
 
-		clientSeedPortForward := clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1"))
+		clientSeedPortForward := clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1"), lookupASN)
 
 
 		clientSeedPortForward.UpdateProgress(5, 5, 5)
 		clientSeedPortForward.UpdateProgress(5, 5, 5)
 
 
@@ -307,9 +307,9 @@ func TestOSL(t *testing.T) {
 
 
 		rolloverToNextSLOKTime()
 		rolloverToNextSLOKTime()
 
 
-		clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1")).UpdateProgress(5, 5, 5)
+		clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1"), lookupASN).UpdateProgress(5, 5, 5)
 
 
-		clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1")).UpdateProgress(5, 5, 5)
+		clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1"), lookupASN).UpdateProgress(5, 5, 5)
 
 
 		select {
 		select {
 		case <-signalIssueSLOKs:
 		case <-signalIssueSLOKs:
@@ -327,9 +327,9 @@ func TestOSL(t *testing.T) {
 
 
 		rolloverToNextSLOKTime()
 		rolloverToNextSLOKTime()
 
 
-		clientSeedState.NewClientSeedPortForward(net.ParseIP("192.168.0.1")).UpdateProgress(5, 5, 5)
+		clientSeedState.NewClientSeedPortForward(net.ParseIP("192.168.0.1"), lookupASN).UpdateProgress(5, 5, 5)
 
 
-		clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1")).UpdateProgress(5, 5, 5)
+		clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1"), lookupASN).UpdateProgress(5, 5, 5)
 
 
 		select {
 		select {
 		case <-signalIssueSLOKs:
 		case <-signalIssueSLOKs:
@@ -355,7 +355,7 @@ func TestOSL(t *testing.T) {
 
 
 		rolloverToNextSLOKTime()
 		rolloverToNextSLOKTime()
 
 
-		clientSeedState := config.NewClientSeedState("US", "36F1CF2DF1250BF0C7BA0629CE3DC657", nil, lookupASN)
+		clientSeedState := config.NewClientSeedState("US", "36F1CF2DF1250BF0C7BA0629CE3DC657", nil)
 
 
 		if len(clientSeedState.GetSeedPayload().SLOKs) != 1 {
 		if len(clientSeedState.GetSeedPayload().SLOKs) != 1 {
 			t.Fatalf("expected 1 SLOKs, got %d", len(clientSeedState.GetSeedPayload().SLOKs))
 			t.Fatalf("expected 1 SLOKs, got %d", len(clientSeedState.GetSeedPayload().SLOKs))
@@ -366,9 +366,9 @@ func TestOSL(t *testing.T) {
 
 
 		rolloverToNextSLOKTime()
 		rolloverToNextSLOKTime()
 
 
-		clientSeedState := config.NewClientSeedState("US", "B4A780E67695595FA486E9B900EA7335", nil, lookupASN)
+		clientSeedState := config.NewClientSeedState("US", "B4A780E67695595FA486E9B900EA7335", nil)
 
 
-		clientSeedPortForward := clientSeedState.NewClientSeedPortForward(net.ParseIP("192.168.0.1"))
+		clientSeedPortForward := clientSeedState.NewClientSeedPortForward(net.ParseIP("192.168.0.1"), lookupASN)
 
 
 		clientSeedPortForward.UpdateProgress(10, 10, 10)
 		clientSeedPortForward.UpdateProgress(10, 10, 10)
 
 

+ 2 - 1
psiphon/server/geoip.go

@@ -209,7 +209,8 @@ func (geoIP *GeoIPService) LookupIP(IP net.IP) GeoIPData {
 
 
 // LookupISPForIP determines a GeoIPData for a given client IP address. Only
 // LookupISPForIP determines a GeoIPData for a given client IP address. Only
 // ISP, ASN, and ASO fields will be populated. This lookup is faster than a
 // ISP, ASN, and ASO fields will be populated. This lookup is faster than a
-// full lookup.
+// full lookup. Benchmarks show this lookup is <= ~1 microsecond against the
+// production geo IP database.
 func (geoIP *GeoIPService) LookupISPForIP(IP net.IP) GeoIPData {
 func (geoIP *GeoIPService) LookupISPForIP(IP net.IP) GeoIPData {
 	return geoIP.lookupIP(IP, true)
 	return geoIP.lookupIP(IP, true)
 }
 }

+ 8 - 7
psiphon/server/tunnelServer.go

@@ -3753,15 +3753,10 @@ func (sshClient *sshClient) setOSLConfig() {
 	//    port forwards will not send progress to the new client
 	//    port forwards will not send progress to the new client
 	//    seed state.
 	//    seed state.
 
 
-	lookupASN := func(IPAddress net.IP) string {
-		return sshClient.sshServer.support.GeoIPService.LookupISPForIP(IPAddress).ASN
-	}
-
 	sshClient.oslClientSeedState = sshClient.sshServer.support.OSLConfig.NewClientSeedState(
 	sshClient.oslClientSeedState = sshClient.sshServer.support.OSLConfig.NewClientSeedState(
 		sshClient.geoIPData.Country,
 		sshClient.geoIPData.Country,
 		propagationChannelID,
 		propagationChannelID,
-		sshClient.signalIssueSLOKs,
-		lookupASN)
+		sshClient.signalIssueSLOKs)
 }
 }
 
 
 // newClientSeedPortForward will return nil when no seeding is
 // newClientSeedPortForward will return nil when no seeding is
@@ -3775,7 +3770,13 @@ func (sshClient *sshClient) newClientSeedPortForward(IPAddress net.IP) *osl.Clie
 		return nil
 		return nil
 	}
 	}
 
 
-	return sshClient.oslClientSeedState.NewClientSeedPortForward(IPAddress)
+	lookupASN := func(IP net.IP) string {
+		// TODO: there are potentially multiple identical geo IP lookups per new
+		// port forward and flow, cache and use result of first lookup.
+		return sshClient.sshServer.support.GeoIPService.LookupISPForIP(IP).ASN
+	}
+
+	return sshClient.oslClientSeedState.NewClientSeedPortForward(IPAddress, lookupASN)
 }
 }
 
 
 // getOSLSeedPayload returns a payload containing all seeded SLOKs for
 // getOSLSeedPayload returns a payload containing all seeded SLOKs for