Przeglądaj źródła

Switch from special cases to general HandshakeParameters

Rod Hynes 9 lat temu
rodzic
commit
dfd9a39d55
2 zmienionych plików z 17 dodań i 55 usunięć
  1. 5 4
      psiphon/server/config.go
  2. 12 51
      psiphon/server/trafficRules.go

+ 5 - 4
psiphon/server/config.go

@@ -386,7 +386,8 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 
 	// Web server config
 
-	var webServerSecret, webServerCertificate, webServerPrivateKey string
+	var webServerSecret, webServerCertificate,
+		webServerPrivateKey, webServerPortForwardAddress string
 
 	if params.WebServerPort != 0 {
 		var err error
@@ -399,10 +400,10 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 		if err != nil {
 			return nil, nil, nil, common.ContextError(err)
 		}
-	}
 
-	webServerPortForwardAddress := net.JoinHostPort(
-		params.ServerIPAddress, strconv.Itoa(params.WebServerPort))
+		webServerPortForwardAddress = net.JoinHostPort(
+			params.ServerIPAddress, strconv.Itoa(params.WebServerPort))
+	}
 
 	// SSH config
 

+ 12 - 51
psiphon/server/trafficRules.go

@@ -22,7 +22,6 @@ package server
 import (
 	"encoding/json"
 	"io/ioutil"
-	"strconv"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
@@ -46,7 +45,7 @@ type TrafficRulesSet struct {
 	// For each client, the first matching Filter in FilteredTrafficRules
 	// determines the additional Rules that are selected and applied
 	// on top of DefaultRules.
-	FilteredTrafficRules []struct {
+	FilteredRules []struct {
 		Filter TrafficRulesFilter
 		Rules  TrafficRules
 	}
@@ -69,25 +68,10 @@ type TrafficRulesFilter struct {
 	// When omitted or blank, any API protocol matches.
 	APIProtocol string
 
-	// SponsorIDs is a list of client handshake sponsor IDs that must be
-	// specified to match this filter. When omitted or empty, any client
-	// sponsor ID matches.
-	SponsorIDs []string
-
-	// PropagationChannelIDs is a list of client handshake propagation
-	// channel IDs that must be specified to match this filter. When
-	// omitted or empty, any propagation channel ID matches.
-	PropagationChannelIDs []string
-
-	// MinClientVersion is a minimum client handshake version number that
-	// must be specified to match this filter. When omitted or empty, any
-	// client version matches.
-	MinClientVersion *int
-
-	// MaxClientVersion is a maximum client handshake version number that
-	// must be specified to match this filter. When omitted or empty, any
-	// client version matches.
-	MaxClientVersion *int
+	// HandshakeParameters specifies handshake API parameter names and
+	// a list of values, one of which must be specified to match this
+	// filter. Only scalar string API parameters may be filtered.
+	HandshakeParameters map[string][]string
 }
 
 // TrafficRules specify the limits placed on client traffic.
@@ -275,7 +259,7 @@ func (set *TrafficRulesSet) GetTrafficRules(
 	}
 
 	// TODO: faster lookup?
-	for _, filteredRules := range set.FilteredTrafficRules {
+	for _, filteredRules := range set.FilteredRules {
 
 		if len(filteredRules.Filter.Protocols) > 0 {
 			if !common.Contains(filteredRules.Filter.Protocols, tunnelProtocol) {
@@ -298,39 +282,16 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			}
 		}
 
-		// Note: ignoring param format errors as params have been validated
-
-		if len(filteredRules.Filter.SponsorIDs) > 0 {
+		if filteredRules.Filter.HandshakeParameters != nil {
 			if !state.completed {
 				continue
 			}
-			sponsorID, _ := getStringRequestParam(state.apiParams, "sponsor_id")
-			if !common.Contains(filteredRules.Filter.SponsorIDs, sponsorID) {
-				continue
-			}
-		}
 
-		if len(filteredRules.Filter.PropagationChannelIDs) > 0 {
-			if !state.completed {
-				continue
-			}
-			propagationChannelID, _ := getStringRequestParam(state.apiParams, "propagation_channel_id")
-			if !common.Contains(filteredRules.Filter.PropagationChannelIDs, propagationChannelID) {
-				continue
-			}
-		}
-
-		if filteredRules.Filter.MinClientVersion != nil || filteredRules.Filter.MaxClientVersion != nil {
-			if !state.completed {
-				continue
-			}
-			clientVersionStr, _ := getStringRequestParam(state.apiParams, "client_version")
-			clientVersion, _ := strconv.Atoi(clientVersionStr)
-			if filteredRules.Filter.MinClientVersion != nil && clientVersion < *filteredRules.Filter.MinClientVersion {
-				continue
-			}
-			if filteredRules.Filter.MaxClientVersion != nil && clientVersion > *filteredRules.Filter.MaxClientVersion {
-				continue
+			for name, values := range filteredRules.Filter.HandshakeParameters {
+				clientValue, err := getStringRequestParam(state.apiParams, name)
+				if err != nil || !common.Contains(values, clientValue) {
+					continue
+				}
 			}
 		}