Procházet zdrojové kódy

New traffic rules logic
- Determine client rules by starting with default rules and applying
a filter to select first matching specific case rules.
- Filter may use handshake API params in addition to client geoIP
and tunnel protocol attributes.

Rod Hynes před 9 roky
rodič
revize
1134323109

+ 2 - 1
psiphon/server/api.go

@@ -133,7 +133,8 @@ func handshakeAPIRequestHandler(
 			params,
 			baseRequestParams))
 
-	// Note: ignoring errors as params are validated
+	// Note: ignoring param format errors as params have been validated
+
 	sessionID, _ := getStringRequestParam(params, "client_session_id")
 	sponsorID, _ := getStringRequestParam(params, "sponsor_id")
 	clientVersion, _ := getStringRequestParam(params, "client_version")

+ 13 - 9
psiphon/server/config.go

@@ -514,18 +514,22 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 		return nil, nil, nil, common.ContextError(err)
 	}
 
+	intPtr := func(i int) *int {
+		return &i
+	}
+
 	trafficRulesSet := &TrafficRulesSet{
 		DefaultRules: TrafficRules{
-			DefaultLimits: common.RateLimits{
-				ReadUnthrottledBytes:  0,
-				ReadBytesPerSecond:    0,
-				WriteUnthrottledBytes: 0,
-				WriteBytesPerSecond:   0,
+			RateLimits: RateLimits{
+				ReadUnthrottledBytes:  new(int64),
+				ReadBytesPerSecond:    new(int64),
+				WriteUnthrottledBytes: new(int64),
+				WriteBytesPerSecond:   new(int64),
 			},
-			IdleTCPPortForwardTimeoutMilliseconds: 30000,
-			IdleUDPPortForwardTimeoutMilliseconds: 30000,
-			MaxTCPPortForwardCount:                1024,
-			MaxUDPPortForwardCount:                32,
+			IdleTCPPortForwardTimeoutMilliseconds: intPtr(30000),
+			IdleUDPPortForwardTimeoutMilliseconds: intPtr(30000),
+			MaxTCPPortForwardCount:                intPtr(1024),
+			MaxUDPPortForwardCount:                intPtr(32),
 			AllowTCPPorts:                         nil,
 			AllowUDPPorts:                         nil,
 			DenyTCPPorts:                          nil,

+ 2 - 2
psiphon/server/services.go

@@ -132,9 +132,9 @@ loop:
 		select {
 		case <-reloadSupportServicesSignal:
 			supportServices.Reload()
-			// Reselect traffic rules for established clients to reflect reloaded config
+			// Reset traffic rules for established clients to reflect reloaded config
 			// TODO: only update when traffic rules config has changed
-			tunnelServer.SelectAllClientTrafficRules()
+			tunnelServer.ResetAllClientTrafficRules()
 		case <-logServerLoadSignal:
 			logServerLoad(tunnelServer)
 		case <-systemStopSignal:

+ 253 - 42
psiphon/server/trafficRules.go

@@ -22,7 +22,7 @@ package server
 import (
 	"encoding/json"
 	"io/ioutil"
-	"strings"
+	"strconv"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
@@ -30,55 +30,89 @@ import (
 // TrafficRulesSet represents the various traffic rules to
 // apply to Psiphon client tunnels. The Reload function supports
 // hot reloading of rules data while the server is running.
+//
+// For a given client, the traffic rules are determined by starting
+// with DefaultRules, then finding the first (if any)
+// FilteredTrafficRules match and overriding the defaults with fields
+// set in the selected FilteredTrafficRules.
 type TrafficRulesSet struct {
 	common.ReloadableFile
 
-	// DefaultRules specifies the traffic rules to be used when no
-	// regional-specific rules are set or apply to a particular
-	// client.
+	// DefaultRules are the base values to use as defaults for all
+	// clients.
 	DefaultRules TrafficRules
 
-	// RegionalRules specifies the traffic rules for particular client
-	// regions (countries) as determined by GeoIP lookup of the client
-	// IP address. The key for each regional traffic rule entry is one
-	// or more space delimited ISO 3166-1 alpha-2 country codes.
-	RegionalRules map[string]TrafficRules
+	// FilteredTrafficRules is an ordered list of filter/rules pairs.
+	// For each client, the first matching Filter in FilteredTrafficRules
+	// determines the additional Rules that are selected and applied
+	// on top of DefaultRules.
+	FilteredTrafficRules []struct {
+		Filter TrafficRulesFilter
+		Rules  TrafficRules
+	}
+}
+
+// TrafficRulesFilter defines a filter to match against client attributes.
+type TrafficRulesFilter struct {
+
+	// Protocols is a list of client tunnel protocols that must be in use
+	// to match this filter. When omitted or empty, any protocol matches.
+	Protocols []string
+
+	// Regions is a list of client GeoIP countries that the client must
+	// reolve to to match this filter. When omitted or empty, any client
+	// region matches.
+	Regions []string
+
+	// SponsorIDs is a list of client handshake sponsor IDs that must be
+	// specified to match this filter. When omitted or empty, any client
+	// sponsor ID matches.
+	SponsorIDs []string
+
+	// PropagationChannelIDs is a list of client handshake propagation
+	// channel IDs that must be specified to match this filter. When
+	// omitted or empty, any propagation channel ID matches.
+	PropagationChannelIDs []string
+
+	// MinClientVersion is a minimum client handshake version number that
+	// must be specified to match this filter. When omitted or empty, any
+	// client version matches.
+	MinClientVersion *int
+
+	// MaxClientVersion is a maximum client handshake version number that
+	// must be specified to match this filter. When omitted or empty, any
+	// client version matches.
+	MaxClientVersion *int
 }
 
 // TrafficRules specify the limits placed on client traffic.
 type TrafficRules struct {
-	// DefaultLimits are the rate limits to be applied when
-	// no protocol-specific rates are set.
-	DefaultLimits common.RateLimits
 
-	// ProtocolLimits specifies the rate limits for particular
-	// tunnel protocols. The key for each rate limit entry is one
-	// or more space delimited Psiphon tunnel protocol names. Valid
-	// tunnel protocols includes the same list as for
-	// TunnelProtocolPorts.
-	ProtocolLimits map[string]common.RateLimits
+	// RateLimits specifies data transfer rate limits for the
+	// client traffic.
+	RateLimits RateLimits
 
 	// IdleTCPPortForwardTimeoutMilliseconds is the timeout period
 	// after which idle (no bytes flowing in either direction)
 	// client TCP port forwards are preemptively closed.
 	// The default, 0, is no idle timeout.
-	IdleTCPPortForwardTimeoutMilliseconds int
+	IdleTCPPortForwardTimeoutMilliseconds *int
 
 	// IdleUDPPortForwardTimeoutMilliseconds is the timeout period
 	// after which idle (no bytes flowing in either direction)
 	// client UDP port forwards are preemptively closed.
 	// The default, 0, is no idle timeout.
-	IdleUDPPortForwardTimeoutMilliseconds int
+	IdleUDPPortForwardTimeoutMilliseconds *int
 
 	// MaxTCPPortForwardCount is the maximum number of TCP port
 	// forwards each client may have open concurrently.
 	// The default, 0, is no maximum.
-	MaxTCPPortForwardCount int
+	MaxTCPPortForwardCount *int
 
 	// MaxUDPPortForwardCount is the maximum number of UDP port
 	// forwards each client may have open concurrently.
 	// The default, 0, is no maximum.
-	MaxUDPPortForwardCount int
+	MaxUDPPortForwardCount *int
 
 	// AllowTCPPorts specifies a whitelist of TCP ports that
 	// are permitted for port forwarding. When set, only ports
@@ -101,6 +135,29 @@ type TrafficRules struct {
 	DenyUDPPorts []int
 }
 
+// RateLimits is a clone of common.RateLimits with pointers
+// to fields to enable distinguishing between zero values and
+// omitted values in JSON serialized traffic rules.
+// See common.RateLimits for field descriptions.
+type RateLimits struct {
+	ReadUnthrottledBytes  *int64
+	ReadBytesPerSecond    *int64
+	WriteUnthrottledBytes *int64
+	WriteBytesPerSecond   *int64
+	CloseAfterExhausted   *bool
+}
+
+// CommonRateLimits converts a RateLimits to a common.RateLimits.
+func (rateLimits *RateLimits) CommonRateLimits() common.RateLimits {
+	return common.RateLimits{
+		ReadUnthrottledBytes:  *rateLimits.ReadUnthrottledBytes,
+		ReadBytesPerSecond:    *rateLimits.ReadBytesPerSecond,
+		WriteUnthrottledBytes: *rateLimits.WriteUnthrottledBytes,
+		WriteBytesPerSecond:   *rateLimits.WriteBytesPerSecond,
+		CloseAfterExhausted:   *rateLimits.CloseAfterExhausted,
+	}
+}
+
 // NewTrafficRulesSet initializes a TrafficRulesSet with
 // the rules data in the specified config file.
 func NewTrafficRulesSet(filename string) (*TrafficRulesSet, error) {
@@ -133,35 +190,189 @@ func NewTrafficRulesSet(filename string) (*TrafficRulesSet, error) {
 	return set, nil
 }
 
-// GetTrafficRules looks up the traffic rules for the specified country. If there
-// are no regional TrafficRules for the country, default TrafficRules are returned.
-func (set *TrafficRulesSet) GetTrafficRules(clientCountryCode string) TrafficRules {
+// GetTrafficRules determines the traffic rules for a client based on its attributes.
+// For the return value TrafficRules, all pointer and slice fields are initialized,
+// so nil checks are not required. The caller must not modify the returned TrafficRules.
+func (set *TrafficRulesSet) GetTrafficRules(
+	tunnelProtocol string, geoIPData GeoIPData, state handshakeState) TrafficRules {
+
 	set.ReloadableFile.RLock()
 	defer set.ReloadableFile.RUnlock()
 
+	// Start with a copy of the DefaultRules, and then select the first
+	// matches Rules from FilteredTrafficRules, taking only the explicitly
+	// specified fields from that Rules.
+	//
+	// Notes:
+	// - Scalar pointers are used in TrafficRules and RateLimits to distinguish between
+	//   omitted fields (in serialized JSON) and default values. For example, if a filtered
+	//   Rules specifies a field value of 0, this will override the default; but if the
+	//   serialized filtered rule omits the field, the default is to be retained.
+	// - We use shallow copies and slices and scalar pointers are shared between the
+	//   return value TrafficRules, so callers must treat the return value as immutable.
+	//   This also means that these slices and pointers can remain referenced in memory even
+	//   after a hot reload.
+
+	trafficRules := set.DefaultRules
+
+	// Populate defaults for omitted DefaultRules fields
+
+	if trafficRules.RateLimits.ReadUnthrottledBytes == nil {
+		trafficRules.RateLimits.ReadUnthrottledBytes = new(int64)
+	}
+
+	if trafficRules.RateLimits.ReadBytesPerSecond == nil {
+		trafficRules.RateLimits.ReadBytesPerSecond = new(int64)
+	}
+
+	if trafficRules.RateLimits.WriteUnthrottledBytes == nil {
+		trafficRules.RateLimits.WriteUnthrottledBytes = new(int64)
+	}
+
+	if trafficRules.RateLimits.WriteBytesPerSecond == nil {
+		trafficRules.RateLimits.WriteBytesPerSecond = new(int64)
+	}
+
+	if trafficRules.RateLimits.CloseAfterExhausted == nil {
+		trafficRules.RateLimits.CloseAfterExhausted = new(bool)
+	}
+
+	if trafficRules.IdleTCPPortForwardTimeoutMilliseconds == nil {
+		trafficRules.IdleTCPPortForwardTimeoutMilliseconds = new(int)
+	}
+
+	if trafficRules.IdleUDPPortForwardTimeoutMilliseconds == nil {
+		trafficRules.IdleUDPPortForwardTimeoutMilliseconds = new(int)
+	}
+
+	if trafficRules.MaxTCPPortForwardCount == nil {
+		trafficRules.MaxTCPPortForwardCount = new(int)
+	}
+
+	if trafficRules.MaxUDPPortForwardCount == nil {
+		trafficRules.MaxUDPPortForwardCount = new(int)
+	}
+
+	if trafficRules.AllowTCPPorts == nil {
+		trafficRules.AllowTCPPorts = make([]int, 0)
+	}
+
+	if trafficRules.AllowUDPPorts == nil {
+		trafficRules.AllowUDPPorts = make([]int, 0)
+	}
+
+	if trafficRules.DenyTCPPorts == nil {
+		trafficRules.DenyTCPPorts = make([]int, 0)
+	}
+
+	if trafficRules.DenyUDPPorts == nil {
+		trafficRules.DenyUDPPorts = make([]int, 0)
+	}
+
 	// TODO: faster lookup?
-	for countryCodes, trafficRules := range set.RegionalRules {
-		for _, countryCode := range strings.Split(countryCodes, " ") {
-			if countryCode == clientCountryCode {
-				return trafficRules
+	for _, filteredRules := range set.FilteredTrafficRules {
+
+		if len(filteredRules.Filter.Protocols) > 0 {
+			if !common.Contains(filteredRules.Filter.Protocols, tunnelProtocol) {
+				continue
+			}
+		}
+
+		if len(filteredRules.Filter.Regions) > 0 {
+			if !common.Contains(filteredRules.Filter.Regions, geoIPData.Country) {
+				continue
 			}
 		}
-	}
-	return set.DefaultRules
-}
 
-// GetRateLimits looks up the rate limits for the specified tunnel protocol.
-// If there are no specific RateLimits for the protocol, default RateLimits are
-// returned.
-func (rules *TrafficRules) GetRateLimits(clientTunnelProtocol string) common.RateLimits {
+		// Note: ignoring param format errors as params have been validated
 
-	// TODO: faster lookup?
-	for tunnelProtocols, rateLimits := range rules.ProtocolLimits {
-		for _, tunnelProtocol := range strings.Split(tunnelProtocols, " ") {
-			if tunnelProtocol == clientTunnelProtocol {
-				return rateLimits
+		if len(filteredRules.Filter.SponsorIDs) > 0 {
+			if !state.completed {
+				continue
+			}
+			sponsorID, _ := getStringRequestParam(state.apiParams, "sponsor_id")
+			if !common.Contains(filteredRules.Filter.SponsorIDs, sponsorID) {
+				continue
+			}
+		}
+
+		if len(filteredRules.Filter.PropagationChannelIDs) > 0 {
+			if !state.completed {
+				continue
+			}
+			propagationChannelID, _ := getStringRequestParam(state.apiParams, "propagation_channel_id")
+			if !common.Contains(filteredRules.Filter.PropagationChannelIDs, propagationChannelID) {
+				continue
 			}
 		}
+
+		if filteredRules.Filter.MinClientVersion != nil || filteredRules.Filter.MaxClientVersion != nil {
+			clientVersionStr, _ := getStringRequestParam(state.apiParams, "client_version")
+			clientVersion, _ := strconv.Atoi(clientVersionStr)
+			if filteredRules.Filter.MinClientVersion != nil && clientVersion < *filteredRules.Filter.MinClientVersion {
+				continue
+			}
+			if filteredRules.Filter.MaxClientVersion != nil && clientVersion > *filteredRules.Filter.MaxClientVersion {
+				continue
+			}
+		}
+
+		// This is the first match. Override defaults using provided fields from selected rules, and return result.
+
+		if filteredRules.Rules.RateLimits.ReadUnthrottledBytes != nil {
+			trafficRules.RateLimits.ReadUnthrottledBytes = filteredRules.Rules.RateLimits.ReadUnthrottledBytes
+		}
+
+		if filteredRules.Rules.RateLimits.ReadBytesPerSecond != nil {
+			trafficRules.RateLimits.ReadBytesPerSecond = filteredRules.Rules.RateLimits.ReadBytesPerSecond
+		}
+
+		if filteredRules.Rules.RateLimits.WriteUnthrottledBytes != nil {
+			trafficRules.RateLimits.WriteUnthrottledBytes = filteredRules.Rules.RateLimits.WriteUnthrottledBytes
+		}
+
+		if filteredRules.Rules.RateLimits.WriteBytesPerSecond != nil {
+			trafficRules.RateLimits.WriteBytesPerSecond = filteredRules.Rules.RateLimits.WriteBytesPerSecond
+		}
+
+		if filteredRules.Rules.RateLimits.CloseAfterExhausted != nil {
+			trafficRules.RateLimits.CloseAfterExhausted = filteredRules.Rules.RateLimits.CloseAfterExhausted
+		}
+
+		if filteredRules.Rules.IdleTCPPortForwardTimeoutMilliseconds != nil {
+			trafficRules.IdleTCPPortForwardTimeoutMilliseconds = filteredRules.Rules.IdleTCPPortForwardTimeoutMilliseconds
+		}
+
+		if filteredRules.Rules.IdleUDPPortForwardTimeoutMilliseconds != nil {
+			trafficRules.IdleUDPPortForwardTimeoutMilliseconds = filteredRules.Rules.IdleUDPPortForwardTimeoutMilliseconds
+		}
+
+		if filteredRules.Rules.MaxTCPPortForwardCount != nil {
+			trafficRules.MaxTCPPortForwardCount = filteredRules.Rules.MaxTCPPortForwardCount
+		}
+
+		if filteredRules.Rules.MaxUDPPortForwardCount != nil {
+			trafficRules.MaxUDPPortForwardCount = filteredRules.Rules.MaxUDPPortForwardCount
+		}
+
+		if filteredRules.Rules.AllowTCPPorts != nil {
+			trafficRules.AllowTCPPorts = filteredRules.Rules.AllowTCPPorts
+		}
+
+		if filteredRules.Rules.AllowUDPPorts != nil {
+			trafficRules.AllowUDPPorts = filteredRules.Rules.AllowUDPPorts
+		}
+
+		if filteredRules.Rules.DenyTCPPorts != nil {
+			trafficRules.DenyTCPPorts = filteredRules.Rules.DenyTCPPorts
+		}
+
+		if filteredRules.Rules.DenyUDPPorts != nil {
+			trafficRules.DenyUDPPorts = filteredRules.Rules.DenyUDPPorts
+		}
+
+		break
 	}
-	return rules.DefaultLimits
+
+	return trafficRules
 }

+ 16 - 16
psiphon/server/tunnelServer.go

@@ -192,10 +192,10 @@ func (server *TunnelServer) GetLoadStats() map[string]map[string]int64 {
 	return server.sshServer.getLoadStats()
 }
 
-// SelectAllClientTrafficRules resets all established client traffic rules
+// ResetAllClientTrafficRules resets all established client traffic rules
 // to use the latest server config and client state.
-func (server *TunnelServer) SelectAllClientTrafficRules() {
-	server.sshServer.selectAllClientTrafficRules()
+func (server *TunnelServer) ResetAllClientTrafficRules() {
+	server.sshServer.resetAllClientTrafficRules()
 }
 
 // SetClientHandshakeState sets the handshake state -- that it completed and
@@ -424,7 +424,7 @@ func (sshServer *sshServer) getLoadStats() map[string]map[string]int64 {
 	return loadStats
 }
 
-func (sshServer *sshServer) selectAllClientTrafficRules() {
+func (sshServer *sshServer) resetAllClientTrafficRules() {
 
 	sshServer.clientsMutex.Lock()
 	clients := make(map[string]*sshClient)
@@ -434,7 +434,7 @@ func (sshServer *sshServer) selectAllClientTrafficRules() {
 	sshServer.clientsMutex.Unlock()
 
 	for _, client := range clients {
-		client.selectTrafficRules()
+		client.setTrafficRules()
 	}
 }
 
@@ -454,7 +454,7 @@ func (sshServer *sshServer) setClientHandshakeState(
 		return common.ContextError(err)
 	}
 
-	client.selectTrafficRules()
+	client.setTrafficRules()
 
 	return nil
 }
@@ -482,7 +482,8 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 
 	sshClient := newSshClient(sshServer, tunnelProtocol, geoIPData)
 
-	sshClient.selectTrafficRules()
+	// Set initial traffic rules, pre-handshake, based on currently known info.
+	sshClient.setTrafficRules()
 
 	// Wrap the base client connection with an ActivityMonitoredConn which will
 	// terminate the connection if no data is received before the deadline. This
@@ -901,30 +902,29 @@ func (sshClient *sshClient) setHandshakeState(state handshakeState) error {
 	return nil
 }
 
-// selectTrafficRules resets the client's traffic rules based on the latest server config
+// setTrafficRules resets the client's traffic rules based on the latest server config
 // and client state. As sshClient.trafficRules may be reset by a concurrent goroutine,
 // trafficRules must only be accessed within the sshClient mutex.
-func (sshClient *sshClient) selectTrafficRules() {
+func (sshClient *sshClient) setTrafficRules() {
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
 	sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
-		// TODO: sshClient.geoIPData, sshClient.handshakeState)
-		sshClient.geoIPData.Country)
+		sshClient.tunnelProtocol, sshClient.geoIPData, sshClient.handshakeState)
 }
 
 func (sshClient *sshClient) rateLimits() common.RateLimits {
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
-	return sshClient.trafficRules.GetRateLimits(sshClient.tunnelProtocol)
+	return sshClient.trafficRules.RateLimits.CommonRateLimits()
 }
 
 func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
-	return time.Duration(sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds) * time.Millisecond
+	return time.Duration(*sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds) * time.Millisecond
 }
 
 func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
@@ -932,7 +932,7 @@ func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
-	return time.Duration(sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
+	return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
 }
 
 const (
@@ -994,10 +994,10 @@ func (sshClient *sshClient) isPortForwardLimitExceeded(
 	var maxPortForwardCount int
 	var state *trafficState
 	if portForwardType == portForwardTypeTCP {
-		maxPortForwardCount = sshClient.trafficRules.MaxTCPPortForwardCount
+		maxPortForwardCount = *sshClient.trafficRules.MaxTCPPortForwardCount
 		state = &sshClient.tcpTrafficState
 	} else {
-		maxPortForwardCount = sshClient.trafficRules.MaxUDPPortForwardCount
+		maxPortForwardCount = *sshClient.trafficRules.MaxUDPPortForwardCount
 		state = &sshClient.udpTrafficState
 	}