Prechádzať zdrojové kódy

Merge branch 'master' into staging-client

Rod Hynes 3 rokov pred
rodič
commit
b3279cf727

+ 1 - 1
MobileLibrary/iOS/PsiphonTunnel/PsiphonTunnel.xcodeproj/project.pbxproj

@@ -57,7 +57,7 @@
 		CEDBA51325B7737C007685E2 /* NetworkInterface.m in Sources */ = {isa = PBXBuildFile; fileRef = CEDBA51125B7737C007685E2 /* NetworkInterface.m */; };
 		CEDE547924EBF5980053566E /* PsiphonProviderFeedbackHandlerShim.h in Headers */ = {isa = PBXBuildFile; fileRef = CEDE547724EBF5980053566E /* PsiphonProviderFeedbackHandlerShim.h */; };
 		CEDE547A24EBF5980053566E /* PsiphonProviderFeedbackHandlerShim.m in Sources */ = {isa = PBXBuildFile; fileRef = CEDE547824EBF5980053566E /* PsiphonProviderFeedbackHandlerShim.m */; };
-		CEFC764225B1F358003A2A52 /* Network.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = CEFC764125B1F358003A2A52 /* Network.framework */; };
+		CEFC764225B1F358003A2A52 /* Network.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = CEFC764125B1F358003A2A52 /* Network.framework */; settings = {ATTRIBUTES = (Weak, ); }; };
 		EFED7EBF1F587F6E0078980F /* libresolv.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = EFED7EBE1F587F6E0078980F /* libresolv.tbd */; };
 /* End PBXBuildFile section */
 

+ 26 - 62
MobileLibrary/iOS/PsiphonTunnel/PsiphonTunnel/PsiphonTunnel.m

@@ -114,7 +114,6 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
     id<ReachabilityProtocol> reachability;
     _Atomic NetworkReachability currentNetworkStatus;
 
-    BOOL tunnelWholeDevice;
     _Atomic BOOL usingNoticeFiles;
 
     // DNS
@@ -167,7 +166,6 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
         self->reachability = [Reachability reachabilityForInternetConnection];
     }
     atomic_init(&self->currentNetworkStatus, NetworkReachabilityNotReachable);
-    self->tunnelWholeDevice = FALSE;
     atomic_init(&self->usingNoticeFiles, FALSE);
 
     // Use the workaround, comma-delimited format required for gobind.
@@ -337,7 +335,7 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
                 embeddedServerEntries,
                 embeddedServerEntriesPath,
                 self,
-                self->tunnelWholeDevice, // useDeviceBinder
+                FALSE, // useDeviceBinder
                 UseIPv6Synthesizer,
                 UseHasIPv6RouteGetter,
                 &e);
@@ -539,7 +537,6 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
     NSError *err;
     NSString *psiphonConfig = [PsiphonTunnel buildPsiphonConfig:configObject
                                                usingNoticeFiles:usingNoticeFiles
-                                              tunnelWholeDevice:&self->tunnelWholeDevice
                                                       sessionID:self.sessionID
                                                      logMessage:logMessage
                                                           error:&err];
@@ -553,7 +550,6 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
 
 + (NSString * _Nullable)buildPsiphonConfig:(id _Nonnull)configObject
                           usingNoticeFiles:(BOOL * _Nonnull)usingNoticeFiles
-                         tunnelWholeDevice:(BOOL * _Nonnull)tunnelWholeDevice
                                  sessionID:(NSString * _Nonnull)sessionID
                                 logMessage:(void (^)(NSString * _Nonnull))logMessage
                                      error:(NSError *_Nullable *_Nonnull)outError {
@@ -794,9 +790,9 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
     //
 
     // We'll record our state about what mode we're in.
-    *tunnelWholeDevice = (config[@"PacketTunnelTunFileDescriptor"] != nil);
+    BOOL tunnelWholeDevice = (config[@"PacketTunnelTunFileDescriptor"] != nil);
 
-    // Other optional fields not being altered. If not set, their defaults will be used:
+    // Optional fields not being altered. If not set, their defaults will be used:
     // * LocalSocksProxyPort
     // * LocalHttpProxyPort
     // * UpstreamProxyUrl
@@ -823,10 +819,27 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
     // Indicate whether UseNoticeFiles is set
     *usingNoticeFiles = (config[@"UseNoticeFiles"] != nil);
 
-    // For iOS VPN, the standard library system resolver will automatically be
-    // routed outside the VPN.
-    if (*tunnelWholeDevice) {
-        config[@"AllowDefaultDNSResolverWithBindToDevice"] = @YES;
+    // For iOS VPN, set VPN client feature while preserving any present feature names
+    if (tunnelWholeDevice == TRUE) {
+        id oldClientFeatures = config[@"ClientFeatures"];
+        NSString *vpnClientFeature = @"VPN";
+        NSMutableArray<NSString*> *clientFeatures;
+
+        if (oldClientFeatures != nil) {
+            if (![oldClientFeatures isKindOfClass:[NSArray<NSString*> class]]) {
+                *outError = [NSError errorWithDomain:PsiphonTunnelErrorDomain
+                                                code:PsiphonTunnelErrorCodeConfigError
+                                            userInfo:@{NSLocalizedDescriptionKey:@"ClientFeatures not NSArray<String*>"}];
+                return nil;
+            }
+            clientFeatures = [NSMutableArray arrayWithArray:oldClientFeatures];
+            if (![clientFeatures containsObject:vpnClientFeature]) {
+                [clientFeatures addObject:vpnClientFeature];
+            }
+        } else {
+            clientFeatures = [NSMutableArray arrayWithObject:vpnClientFeature];
+        }
+        config[@"ClientFeatures"] = clientFeatures;
     }
 
     NSString *finalConfigStr = [[[SBJson4Writer alloc] init] stringWithObject:config];
@@ -1218,55 +1231,8 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
 
 - (NSString *)bindToDevice:(long)fileDescriptor error:(NSError **)error {
 
-    if (!self->tunnelWholeDevice) {
-        *error = [[NSError alloc] initWithDomain:@"iOSLibrary" code:1 userInfo:@{NSLocalizedDescriptionKey: @"bindToDevice: invalid mode"}];
-        return @"";
-    }
-
-    NSError *err;
-    NSString *activeInterface = [NetworkInterface getActiveInterfaceWithReachability:self->reachability
-                                                             andCurrentNetworkStatus:atomic_load(&self->currentNetworkStatus)
-                                                                               error:&err];
-    if (err != nil) {
-        NSString *localizedDescription = [NSString stringWithFormat:@"bindToDevice: error getting active interface %@", err.localizedDescription];
-        *error = [[NSError alloc] initWithDomain:@"iOSLibrary" code:1 userInfo:@{NSLocalizedDescriptionKey:localizedDescription}];
-        return @"";
-    }
-
-    unsigned int interfaceIndex = if_nametoindex([activeInterface UTF8String]);
-    if (interfaceIndex == 0) {
-        *error = [[NSError alloc] initWithDomain:NSPOSIXErrorDomain code:errno userInfo:@{NSLocalizedDescriptionKey: [NSString stringWithFormat:@"bindToDevice: if_nametoindex failed: %d", errno]}];
-        return @"";
-    }
-
-    struct sockaddr sa;
-    socklen_t len = sizeof(sa);
-    int ret = getsockname((int)fileDescriptor, &sa, &len);
-    if (ret != 0) {
-        *error = [[NSError alloc] initWithDomain:NSPOSIXErrorDomain code:errno userInfo:@{NSLocalizedDescriptionKey: [NSString stringWithFormat:@"bindToDevice: getsockname failed: %d", errno]}];
-        return @"";
-    }
-
-    int level = 0;
-    int optname = 0;
-    if (sa.sa_family == PF_INET) {
-        level = IPPROTO_IP;
-        optname = IP_BOUND_IF;
-    } else if (sa.sa_family == PF_INET6) {
-        level = IPPROTO_IPV6;
-        optname = IPV6_BOUND_IF;
-    } else {
-        *error = [[NSError alloc] initWithDomain:@"iOSLibrary" code:1 userInfo:@{NSLocalizedDescriptionKey: [NSString stringWithFormat:@"bindToDevice: unsupported domain: %d", (int)sa.sa_family]}];
-        return @"";
-    }
-
-    ret = setsockopt((int)fileDescriptor, level, optname, &interfaceIndex, sizeof(interfaceIndex));
-    if (ret != 0) {
-        *error = [[NSError alloc] initWithDomain:NSPOSIXErrorDomain code:errno userInfo:@{NSLocalizedDescriptionKey: [NSString stringWithFormat:@"bindToDevice: setsockopt failed: %d", errno]}];
-        return @"";
-    }
-    
-    return [NSString stringWithFormat:@"active interface: %@", activeInterface];
+    *error = [[NSError alloc] initWithDomain:@"iOSLibrary" code:1 userInfo:@{NSLocalizedDescriptionKey: @"bindToDevice: not supported"}];
+    return @"";
 }
 
 - (NSString *)getDNSServersAsString {
@@ -1679,11 +1645,9 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
         }
 
         BOOL usingNoticeFiles = FALSE;
-        BOOL tunnelWholeDevice = FALSE;
 
         NSString *psiphonConfig = [PsiphonTunnel buildPsiphonConfig:feedbackConfigJson
                                                    usingNoticeFiles:&usingNoticeFiles
-                                                  tunnelWholeDevice:&tunnelWholeDevice
                                                           sessionID:sessionID
                                                          logMessage:logMessage
                                                               error:&err];

+ 15 - 10
psiphon/common/activity.go

@@ -144,23 +144,28 @@ func (conn *ActivityMonitoredConn) Read(buffer []byte) (int, error) {
 
 func (conn *ActivityMonitoredConn) Write(buffer []byte) (int, error) {
 	n, err := conn.Conn.Write(buffer)
-	if n > 0 && conn.activeOnWrite {
+	if n > 0 {
 
-		if conn.inactivityTimeout > 0 {
-			err = conn.Conn.SetDeadline(time.Now().Add(conn.inactivityTimeout))
-			if err != nil {
-				return n, errors.Trace(err)
-			}
-		}
+		// Bytes written are reported regardless of activeOnWrite. Inactivity
+		// deadline extension and LRU updates are conditional on activeOnWrite.
 
 		for _, activityUpdater := range conn.activityUpdaters {
 			activityUpdater.UpdateProgress(0, int64(n), 0)
 		}
 
-		if conn.lruEntry != nil {
-			conn.lruEntry.Touch()
-		}
+		if conn.activeOnWrite {
 
+			if conn.inactivityTimeout > 0 {
+				err = conn.Conn.SetDeadline(time.Now().Add(conn.inactivityTimeout))
+				if err != nil {
+					return n, errors.Trace(err)
+				}
+			}
+
+			if conn.lruEntry != nil {
+				conn.lruEntry.Touch()
+			}
+		}
 	}
 	// Note: no trace error to preserve error type
 	return n, err

+ 1 - 3
psiphon/common/prng/prng.go

@@ -18,7 +18,6 @@
  */
 
 /*
-
 Package prng implements a seeded, unbiased PRNG that is suitable for use
 cases including obfuscation, network jitter, load balancing.
 
@@ -42,7 +41,6 @@ required for replay.
 
 PRNG conforms to io.Reader and math/rand.Source, with additional helper
 functions.
-
 */
 package prng
 
@@ -266,7 +264,7 @@ func (p *PRNG) ExpFloat64Range(min, max, lambda float64) float64 {
 	return value
 }
 
-// Intn is equivilent to math/rand.Perm.
+// Perm is equivilent to math/rand.Perm.
 func (p *PRNG) Perm(n int) []int {
 	return p.rand.Perm(n)
 }

+ 18 - 11
psiphon/common/protocol/serverEntry.go

@@ -487,7 +487,7 @@ func (serverEntry *ServerEntry) SupportsProtocol(protocol string) bool {
 // ProtocolUsesLegacyPassthrough indicates whether the ServerEntry supports
 // the specified protocol using legacy passthrough messages.
 //
-// There is no correspondong check for v2 passthrough, as clients send v2
+// There is no corresponding check for v2 passthrough, as clients send v2
 // passthrough messages unconditionally, by default, for passthrough
 // protocols.
 func (serverEntry *ServerEntry) ProtocolUsesLegacyPassthrough(protocol string) bool {
@@ -725,23 +725,31 @@ func TagToDiagnosticID(tag string) string {
 // EncodeServerEntry returns a string containing the encoding of
 // a ServerEntry following Psiphon conventions.
 func EncodeServerEntry(serverEntry *ServerEntry) (string, error) {
-	return encodeServerEntry(
+	encodedServerEntry, err := encodeServerEntry(
 		serverEntry.IpAddress,
 		serverEntry.WebServerPort,
 		serverEntry.WebServerSecret,
 		serverEntry.WebServerCertificate,
 		serverEntry)
+	if err != nil {
+		return "", errors.Trace(err)
+	}
+	return encodedServerEntry, nil
 }
 
 // EncodeServerEntryFields returns a string containing the encoding of
 // ServerEntryFields following Psiphon conventions.
 func EncodeServerEntryFields(serverEntryFields ServerEntryFields) (string, error) {
-	return encodeServerEntry(
+	encodedServerEntry, err := encodeServerEntry(
 		serverEntryFields.GetIPAddress(),
 		serverEntryFields.GetWebServerPort(),
 		serverEntryFields.GetWebServerSecret(),
 		serverEntryFields.GetWebServerCertificate(),
 		serverEntryFields)
+	if err != nil {
+		return "", errors.Trace(err)
+	}
+	return encodedServerEntry, nil
 }
 
 func encodeServerEntry(
@@ -940,14 +948,13 @@ func NewStreamingServerEntryDecoder(
 // input stream, returning a nil server entry when the stream is complete.
 //
 // Limitations:
-// - Each encoded server entry line cannot exceed bufio.MaxScanTokenSize,
-//   the default buffer size which this decoder uses. This is 64K.
-// - DecodeServerEntry is called on each encoded server entry line, which
-//   will allocate memory to hex decode and JSON deserialze the server
-//   entry. As this is not presently reusing a fixed buffer, each call
-//   will allocate additional memory; garbage collection is necessary to
-//   reclaim that memory for reuse for the next server entry.
-//
+//   - Each encoded server entry line cannot exceed bufio.MaxScanTokenSize,
+//     the default buffer size which this decoder uses. This is 64K.
+//   - DecodeServerEntry is called on each encoded server entry line, which
+//     will allocate memory to hex decode and JSON deserialze the server
+//     entry. As this is not presently reusing a fixed buffer, each call
+//     will allocate additional memory; garbage collection is necessary to
+//     reclaim that memory for reuse for the next server entry.
 func (decoder *StreamingServerEntryDecoder) Next() (ServerEntryFields, error) {
 
 	for {

+ 59 - 53
psiphon/common/tun/tun.go

@@ -22,57 +22,55 @@
 // license that can be found in the LICENSE file.
 
 /*
-
 Package tun is an IP packet tunnel server and client. It supports tunneling
 both IPv4 and IPv6.
 
- .........................................................       .-,(  ),-.
- . [server]                                     .-----.  .    .-(          )-.
- .                                              | NIC |<---->(    Internet    )
- . .......................................      '-----'  .    '-(          ).-'
- . . [packet tunnel daemon]              .         ^     .        '-.( ).-'
- . .                                     .         |     .
- . . ...........................         .         |     .
- . . . [session]               .         .        NAT    .
- . . .                         .         .         |     .
- . . .                         .         .         v     .
- . . .                         .         .       .---.   .
- . . .                         .         .       | t |   .
- . . .                         .         .       | u |   .
- . . .                 .---.   .  .---.  .       | n |   .
- . . .                 | q |   .  | d |  .       |   |   .
- . . .                 | u |   .  | e |  .       | d |   .
- . . .          .------| e |<-----| m |<---------| e |   .
- . . .          |      | u |   .  | u |  .       | v |   .
- . . .          |      | e |   .  | x |  .       | i |   .
- . . .       rewrite   '---'   .  '---'  .       | c |   .
- . . .          |              .         .       | e |   .
- . . .          v              .         .       '---'   .
- . . .     .---------.         .         .         ^     .
- . . .     | channel |--rewrite--------------------'     .
- . . .     '---------'         .         .               .
- . . ...........^...............         .               .
- . .............|.........................               .
- ...............|.........................................
-                |
-                | (typically via Internet)
-                |
- ...............|.................
- . [client]     |                .
- .              |                .
- . .............|............... .
- . .            v              . .
- . .       .---------.         . .
- . .       | channel |         . .
- . .       '---------'         . .
- . .            ^              . .
- . .............|............... .
- .              v                .
- .        .------------.         .
- .        | tun device |         .
- .        '------------'         .
- .................................
-
+	.........................................................       .-,(  ),-.
+	. [server]                                     .-----.  .    .-(          )-.
+	.                                              | NIC |<---->(    Internet    )
+	. .......................................      '-----'  .    '-(          ).-'
+	. . [packet tunnel daemon]              .         ^     .        '-.( ).-'
+	. .                                     .         |     .
+	. . ...........................         .         |     .
+	. . . [session]               .         .        NAT    .
+	. . .                         .         .         |     .
+	. . .                         .         .         v     .
+	. . .                         .         .       .---.   .
+	. . .                         .         .       | t |   .
+	. . .                         .         .       | u |   .
+	. . .                 .---.   .  .---.  .       | n |   .
+	. . .                 | q |   .  | d |  .       |   |   .
+	. . .                 | u |   .  | e |  .       | d |   .
+	. . .          .------| e |<-----| m |<---------| e |   .
+	. . .          |      | u |   .  | u |  .       | v |   .
+	. . .          |      | e |   .  | x |  .       | i |   .
+	. . .       rewrite   '---'   .  '---'  .       | c |   .
+	. . .          |              .         .       | e |   .
+	. . .          v              .         .       '---'   .
+	. . .     .---------.         .         .         ^     .
+	. . .     | channel |--rewrite--------------------'     .
+	. . .     '---------'         .         .               .
+	. . ...........^...............         .               .
+	. .............|.........................               .
+	...............|.........................................
+	               |
+	               | (typically via Internet)
+	               |
+	...............|.................
+	. [client]     |                .
+	.              |                .
+	. .............|............... .
+	. .            v              . .
+	. .       .---------.         . .
+	. .       | channel |         . .
+	. .       '---------'         . .
+	. .            ^              . .
+	. .............|............... .
+	.              v                .
+	.        .------------.         .
+	.        | tun device |         .
+	.        '------------'         .
+	.................................
 
 The client relays IP packets between a local tun device and a channel, which
 is a transport to the server. In Psiphon, the channel will be an SSH channel
@@ -120,7 +118,6 @@ channel and negotiating the correct MTU and DNS settings. The Psiphon
 server will call Server.ClientConnected when a client connects and establishes
 a packet tunnel channel; and Server.ClientDisconnected when the client closes
 the channel and/or disconnects.
-
 */
 package tun
 
@@ -576,6 +573,11 @@ func (server *Server) resumeSession(
 	// Set new access control, flow monitoring, and metrics
 	// callbacks; all associated with the new client connection.
 
+	// IMPORTANT: any new callbacks or references to the outer client added
+	// here must be cleared in interruptSession to ensure that a paused
+	// session does not retain references to old client connection objects
+	// after the client disconnects.
+
 	session.setCheckAllowedTCPPortFunc(&checkAllowedTCPPortFunc)
 
 	session.setCheckAllowedUDPPortFunc(&checkAllowedUDPPortFunc)
@@ -665,6 +667,8 @@ func (server *Server) interruptSession(session *session) {
 	session.setFlowActivityUpdaterMaker(nil)
 
 	session.setMetricsUpdater(nil)
+
+	session.setDNSQualityReporter(nil)
 }
 
 func (server *Server) runSessionReaper() {
@@ -1506,10 +1510,12 @@ func (session *session) deleteFlow(ID flowID, flowState *flowState) {
 
 			resolveElapsedTime := dnsEndTime.Sub(dnsStartTime)
 
-			flowState.dnsQualityReporter(
-				dnsSuccess,
-				resolveElapsedTime,
-				net.IP(ID.upstreamIPAddress[:]))
+			if flowState.dnsQualityReporter != nil {
+				flowState.dnsQualityReporter(
+					dnsSuccess,
+					resolveElapsedTime,
+					net.IP(ID.upstreamIPAddress[:]))
+			}
 		}
 	}
 

+ 3 - 0
psiphon/server/meek.go

@@ -674,6 +674,9 @@ func (server *MeekServer) getSessionOrEndpoint(
 				IPs := strings.Split(value, ",")
 				IP := IPs[len(IPs)-1]
 
+				// Remove optional whitespace surrounding the commas.
+				IP = strings.TrimSpace(IP)
+
 				if net.ParseIP(IP) != nil {
 					clientIP = IP
 					break

+ 105 - 70
psiphon/server/trafficRules.go

@@ -59,9 +59,13 @@ type TrafficRulesSet struct {
 	// For each client, the first matching Filter in FilteredTrafficRules
 	// determines the additional Rules that are selected and applied
 	// on top of DefaultRules.
+	//
+	// When ExceptFilter is present, a client must match Filter and not match
+	// ExceptFilter.
 	FilteredRules []struct {
-		Filter TrafficRulesFilter
-		Rules  TrafficRules
+		Filter       TrafficRulesFilter
+		ExceptFilter *TrafficRulesFilter
+		Rules        TrafficRules
 	}
 
 	// MeekRateLimiterHistorySize enables the late-stage meek rate limiter and
@@ -418,14 +422,8 @@ func (set *TrafficRulesSet) Validate() error {
 		return nil
 	}
 
-	err := validateTrafficRules(&set.DefaultRules)
-	if err != nil {
-		return errors.Trace(err)
-	}
-
-	for _, filteredRule := range set.FilteredRules {
-
-		for paramName := range filteredRule.Filter.HandshakeParameters {
+	validateFilter := func(filter *TrafficRulesFilter) error {
+		for paramName := range filter.HandshakeParameters {
 			validParamName := false
 			for _, paramSpec := range handshakeRequestParams {
 				if paramSpec.name == paramName {
@@ -437,8 +435,29 @@ func (set *TrafficRulesSet) Validate() error {
 				return errors.Tracef("invalid parameter name: %s", paramName)
 			}
 		}
+		return nil
+	}
 
-		err := validateTrafficRules(&filteredRule.Rules)
+	err := validateTrafficRules(&set.DefaultRules)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	for _, filteredRule := range set.FilteredRules {
+
+		err := validateFilter(&filteredRule.Filter)
+		if err != nil {
+			return errors.Trace(err)
+		}
+
+		if filteredRule.ExceptFilter != nil {
+			err := validateFilter(filteredRule.ExceptFilter)
+			if err != nil {
+				return errors.Trace(err)
+			}
+		}
+
+		err = validateTrafficRules(&filteredRule.Rules)
 		if err != nil {
 			return errors.Trace(err)
 		}
@@ -506,6 +525,9 @@ func (set *TrafficRulesSet) initLookups() {
 
 	for i := range set.FilteredRules {
 		initTrafficRulesFilterLookups(&set.FilteredRules[i].Filter)
+		if set.FilteredRules[i].ExceptFilter != nil {
+			initTrafficRulesFilterLookups(set.FilteredRules[i].ExceptFilter)
+		}
 		initTrafficRulesLookups(&set.FilteredRules[i].Rules)
 	}
 
@@ -616,144 +638,157 @@ func (set *TrafficRulesSet) GetTrafficRules(
 		trafficRules.DisableDiscovery = new(bool)
 	}
 
-	// TODO: faster lookup?
-	for _, filteredRules := range set.FilteredRules {
+	// matchFilter is used to check both Filter and any ExceptFilter
 
-		log.WithTraceFields(LogFields{"filter": filteredRules.Filter}).Debug("filter check")
+	matchFilter := func(filter *TrafficRulesFilter) bool {
 
-		if len(filteredRules.Filter.TunnelProtocols) > 0 {
-			if !common.Contains(filteredRules.Filter.TunnelProtocols, tunnelProtocol) {
-				continue
+		if len(filter.TunnelProtocols) > 0 {
+			if !common.Contains(filter.TunnelProtocols, tunnelProtocol) {
+				return false
 			}
 		}
 
-		if len(filteredRules.Filter.Regions) > 0 {
-			if filteredRules.Filter.regionLookup != nil {
-				if !filteredRules.Filter.regionLookup[geoIPData.Country] {
-					continue
+		if len(filter.Regions) > 0 {
+			if filter.regionLookup != nil {
+				if !filter.regionLookup[geoIPData.Country] {
+					return false
 				}
 			} else {
-				if !common.Contains(filteredRules.Filter.Regions, geoIPData.Country) {
-					continue
+				if !common.Contains(filter.Regions, geoIPData.Country) {
+					return false
 				}
 			}
 		}
 
-		if len(filteredRules.Filter.ISPs) > 0 {
-			if filteredRules.Filter.ispLookup != nil {
-				if !filteredRules.Filter.ispLookup[geoIPData.ISP] {
-					continue
+		if len(filter.ISPs) > 0 {
+			if filter.ispLookup != nil {
+				if !filter.ispLookup[geoIPData.ISP] {
+					return false
 				}
 			} else {
-				if !common.Contains(filteredRules.Filter.ISPs, geoIPData.ISP) {
-					continue
+				if !common.Contains(filter.ISPs, geoIPData.ISP) {
+					return false
 				}
 			}
 		}
 
-		if len(filteredRules.Filter.ASNs) > 0 {
-			if filteredRules.Filter.asnLookup != nil {
-				if !filteredRules.Filter.asnLookup[geoIPData.ASN] {
-					continue
+		if len(filter.ASNs) > 0 {
+			if filter.asnLookup != nil {
+				if !filter.asnLookup[geoIPData.ASN] {
+					return false
 				}
 			} else {
-				if !common.Contains(filteredRules.Filter.ASNs, geoIPData.ASN) {
-					continue
+				if !common.Contains(filter.ASNs, geoIPData.ASN) {
+					return false
 				}
 			}
 		}
 
-		if len(filteredRules.Filter.Cities) > 0 {
-			if filteredRules.Filter.cityLookup != nil {
-				if !filteredRules.Filter.cityLookup[geoIPData.City] {
-					continue
+		if len(filter.Cities) > 0 {
+			if filter.cityLookup != nil {
+				if !filter.cityLookup[geoIPData.City] {
+					return false
 				}
 			} else {
-				if !common.Contains(filteredRules.Filter.Cities, geoIPData.City) {
-					continue
+				if !common.Contains(filter.Cities, geoIPData.City) {
+					return false
 				}
 			}
 		}
 
-		if filteredRules.Filter.APIProtocol != "" {
+		if filter.APIProtocol != "" {
 			if !state.completed {
-				continue
+				return false
 			}
-			if state.apiProtocol != filteredRules.Filter.APIProtocol {
-				continue
+			if state.apiProtocol != filter.APIProtocol {
+				return false
 			}
 		}
 
-		if filteredRules.Filter.HandshakeParameters != nil {
+		if filter.HandshakeParameters != nil {
 			if !state.completed {
-				continue
+				return false
 			}
 
-			mismatch := false
-			for name, values := range filteredRules.Filter.HandshakeParameters {
+			for name, values := range filter.HandshakeParameters {
 				clientValue, err := getStringRequestParam(state.apiParams, name)
 				if err != nil || !common.ContainsWildcard(values, clientValue) {
-					mismatch = true
-					break
+					return false
 				}
 			}
-			if mismatch {
-				continue
-			}
 		}
 
-		if filteredRules.Filter.AuthorizationsRevoked {
+		if filter.AuthorizationsRevoked {
 			if !state.completed {
-				continue
+				return false
 			}
 
 			if !state.authorizationsRevoked {
-				continue
+				return false
 			}
 
 		} else {
-			if len(filteredRules.Filter.ActiveAuthorizationIDs) > 0 {
+			if len(filter.ActiveAuthorizationIDs) > 0 {
 				if !state.completed {
-					continue
+					return false
 				}
 
 				if state.authorizationsRevoked {
-					continue
+					return false
 				}
 
-				if filteredRules.Filter.activeAuthorizationIDLookup != nil {
+				if filter.activeAuthorizationIDLookup != nil {
 					found := false
 					for _, ID := range state.activeAuthorizationIDs {
-						if filteredRules.Filter.activeAuthorizationIDLookup[ID] {
+						if filter.activeAuthorizationIDLookup[ID] {
 							found = true
 							break
 						}
 					}
 					if !found {
-						continue
+						return false
 					}
 				} else {
-					if !common.ContainsAny(filteredRules.Filter.ActiveAuthorizationIDs, state.activeAuthorizationIDs) {
-						continue
+					if !common.ContainsAny(filter.ActiveAuthorizationIDs, state.activeAuthorizationIDs) {
+						return false
 					}
 				}
 
 			}
-			if len(filteredRules.Filter.AuthorizedAccessTypes) > 0 {
+			if len(filter.AuthorizedAccessTypes) > 0 {
 				if !state.completed {
-					continue
+					return false
 				}
 
 				if state.authorizationsRevoked {
-					continue
+					return false
 				}
 
-				if !common.ContainsAny(filteredRules.Filter.AuthorizedAccessTypes, state.authorizedAccessTypes) {
-					continue
+				if !common.ContainsAny(filter.AuthorizedAccessTypes, state.authorizedAccessTypes) {
+					return false
 				}
 			}
 		}
 
+		return true
+	}
+
+	// Match filtered rules
+	//
+	// TODO: faster lookup?
+
+	for _, filteredRules := range set.FilteredRules {
+
+		log.WithTraceFields(LogFields{"filter": filteredRules.Filter}).Debug("filter check")
+
+		match := matchFilter(&filteredRules.Filter)
+		if match && filteredRules.ExceptFilter != nil {
+			match = !matchFilter(filteredRules.ExceptFilter)
+		}
+		if !match {
+			continue
+		}
+
 		log.WithTraceFields(LogFields{"filter": filteredRules.Filter}).Debug("filter match")
 
 		// This is the first match. Override defaults using provided fields from selected rules, and return result.

+ 274 - 0
psiphon/server/trafficRules_test.go

@@ -0,0 +1,274 @@
+/*
+ * Copyright (c) 2022, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package server
+
+import (
+	"encoding/json"
+	"io/ioutil"
+	"os"
+	"reflect"
+	"testing"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+)
+
+func TestTrafficRulesFilters(t *testing.T) {
+
+	trafficRulesJSON := `
+	{
+      "DefaultRules" :  {
+        "RateLimits" : {
+          "WriteUnthrottledBytes": 1,
+          "WriteBytesPerSecond": 2,
+          "ReadUnthrottledBytes": 3,
+          "ReadBytesPerSecond": 4,
+          "UnthrottleFirstTunnelOnly": true
+        },
+        "AllowTCPPorts" : [5],
+        "AllowUDPPorts" : [6]
+      },
+  
+      "FilteredRules" : [
+  
+        {
+          "Filter" : {
+            "Regions" : ["R2"],
+            "HandshakeParameters" : {
+                "client_version" : ["1"]
+            }
+          },
+          "Rules" : {
+            "RateLimits" : {
+              "WriteBytesPerSecond": 7,
+              "ReadBytesPerSecond": 8
+            },
+            "AllowTCPPorts" : [5,9],
+            "AllowUDPPorts" : [6,10]
+          }
+        },
+
+        {
+          "Filter" : {
+            "TunnelProtocols" : ["P2"],
+            "Regions" : ["R3", "R4"],
+            "HandshakeParameters" : {
+                "client_version" : ["1", "2"]
+            }
+          },
+          "ExceptFilter" : {
+            "ISPs" : ["I2", "I3"],
+            "HandshakeParameters" : {
+                "client_version" : ["1"]
+            }
+          },
+          "Rules" : {
+            "RateLimits" : {
+              "WriteBytesPerSecond": 11,
+              "ReadBytesPerSecond": 12
+            },
+            "AllowTCPPorts" : [5,13],
+            "AllowUDPPorts" : [6,14]
+          }
+        },
+
+        {
+          "Filter" : {
+            "Regions" : ["R3", "R4"],
+            "HandshakeParameters" : {
+                "client_version" : ["1", "2"]
+            }
+          },
+          "ExceptFilter" : {
+            "ISPs" : ["I2", "I3"],
+            "HandshakeParameters" : {
+                "client_version" : ["1"]
+            }
+          },
+          "Rules" : {
+            "RateLimits" : {
+              "WriteBytesPerSecond": 15,
+              "ReadBytesPerSecond": 16
+            },
+            "AllowTCPPorts" : [5,17],
+            "AllowUDPPorts" : [6,18]
+          }
+        }
+      ]
+    }
+	`
+
+	file, err := ioutil.TempFile("", "trafficRules.config")
+	if err != nil {
+		t.Fatalf("TempFile create failed: %s", err)
+	}
+	_, err = file.Write([]byte(trafficRulesJSON))
+	if err != nil {
+		t.Fatalf("TempFile write failed: %s", err)
+	}
+	file.Close()
+	configFileName := file.Name()
+	defer os.Remove(configFileName)
+
+	trafficRules, err := NewTrafficRulesSet(configFileName)
+	if err != nil {
+		t.Fatalf("NewTrafficRulesSet failed: %s", err)
+	}
+
+	err = trafficRules.Validate()
+	if err != nil {
+		t.Fatalf("TrafficRulesSet.Validate failed: %s", err)
+	}
+
+	makePortList := func(portsJSON string) common.PortList {
+		var p common.PortList
+		_ = json.Unmarshal([]byte(portsJSON), &p)
+		return p
+	}
+
+	testCases := []struct {
+		description                   string
+		isFirstTunnelInSession        bool
+		tunnelProtocol                string
+		geoIPData                     GeoIPData
+		state                         handshakeState
+		expectedWriteUnthrottledBytes int64
+		expectedWriteBytesPerSecond   int64
+		expectedReadUnthrottledBytes  int64
+		expectedReadBytesPerSecond    int64
+		expectedAllowTCPPorts         common.PortList
+		expectedAllowUDPPorts         common.PortList
+	}{
+		{
+			"get defaults",
+			true,
+			"P1",
+			GeoIPData{Country: "R1", ISP: "I1"},
+			handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
+			1, 2, 3, 4, makePortList("[5]"), makePortList("[6]"),
+		},
+
+		{
+			"get defaults for not first tunnel in session",
+			false,
+			"P1",
+			GeoIPData{Country: "R1", ISP: "I1"},
+			handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
+			0, 2, 0, 4, makePortList("[5]"), makePortList("[6]"),
+		},
+
+		{
+			"get first filtered rule",
+			true,
+			"P1",
+			GeoIPData{Country: "R2", ISP: "I1"},
+			handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
+			1, 7, 3, 8, makePortList("[5,9]"), makePortList("[6,10]"),
+		},
+
+		{
+			"don't get first filtered rule with incomplete match",
+			true,
+			"P1",
+			GeoIPData{Country: "R2", ISP: "I1"},
+			handshakeState{apiParams: map[string]interface{}{"client_version": "2"}, completed: true},
+			1, 2, 3, 4, makePortList("[5]"), makePortList("[6]"),
+		},
+
+		{
+			"get second filtered rule",
+			true,
+			"P2",
+			GeoIPData{Country: "R3", ISP: "I1"},
+			handshakeState{apiParams: map[string]interface{}{"client_version": "2"}, completed: true},
+			1, 11, 3, 12, makePortList("[5,13]"), makePortList("[6,14]"),
+		},
+
+		{
+			"get second filtered rule with incomplete exception",
+			true,
+			"P2",
+			GeoIPData{Country: "R3", ISP: "I2"},
+			handshakeState{apiParams: map[string]interface{}{"client_version": "2"}, completed: true},
+			1, 11, 3, 12, makePortList("[5,13]"), makePortList("[6,14]"),
+		},
+
+		{
+			"don't get second filtered rule due to exception",
+			true,
+			"P2",
+			GeoIPData{Country: "R3", ISP: "I2"},
+			handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
+			1, 2, 3, 4, makePortList("[5]"), makePortList("[6]"),
+		},
+
+		{
+			"get third filtered rule",
+			true,
+			"P1",
+			GeoIPData{Country: "R3", ISP: "I1"},
+			handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
+			1, 15, 3, 16, makePortList("[5,17]"), makePortList("[6,18]"),
+		},
+
+		{
+			"don't get third filtered rule due to exception",
+			true,
+			"P1",
+			GeoIPData{Country: "R3", ISP: "I2"},
+			handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
+			1, 2, 3, 4, makePortList("[5]"), makePortList("[6]"),
+		},
+	}
+	for _, testCase := range testCases {
+		t.Run(testCase.description, func(t *testing.T) {
+
+			rules := trafficRules.GetTrafficRules(
+				testCase.isFirstTunnelInSession,
+				testCase.tunnelProtocol,
+				testCase.geoIPData,
+				testCase.state)
+
+			if *rules.RateLimits.WriteUnthrottledBytes != testCase.expectedWriteUnthrottledBytes {
+				t.Errorf("unexpected rules.RateLimits.WriteUnthrottledBytes: %v != %v",
+					*rules.RateLimits.WriteUnthrottledBytes, testCase.expectedWriteUnthrottledBytes)
+			}
+			if *rules.RateLimits.WriteBytesPerSecond != testCase.expectedWriteBytesPerSecond {
+				t.Errorf("unexpected rules.RateLimits.WriteBytesPerSecond: %v != %v",
+					*rules.RateLimits.WriteBytesPerSecond, testCase.expectedWriteBytesPerSecond)
+			}
+			if *rules.RateLimits.ReadUnthrottledBytes != testCase.expectedReadUnthrottledBytes {
+				t.Errorf("unexpected rules.RateLimits.ReadUnthrottledBytes: %v != %v",
+					*rules.RateLimits.ReadUnthrottledBytes, testCase.expectedReadUnthrottledBytes)
+			}
+			if *rules.RateLimits.ReadBytesPerSecond != testCase.expectedReadBytesPerSecond {
+				t.Errorf("unexpected rules.RateLimits.ReadBytesPerSecond: %v != %v",
+					*rules.RateLimits.ReadBytesPerSecond, testCase.expectedReadBytesPerSecond)
+			}
+			if !reflect.DeepEqual(*rules.AllowTCPPorts, testCase.expectedAllowTCPPorts) {
+				t.Errorf("unexpected rules.RateLimits.AllowTCPPorts: %v != %v",
+					*rules.AllowTCPPorts, testCase.expectedAllowTCPPorts)
+			}
+			if !reflect.DeepEqual(*rules.AllowUDPPorts, testCase.expectedAllowUDPPorts) {
+				t.Errorf("unexpected rules.RateLimits.AllowUDPPorts: %v != %v",
+					*rules.AllowUDPPorts, testCase.expectedAllowUDPPorts)
+			}
+		})
+	}
+}

+ 1 - 2
psiphon/server/udp.go

@@ -42,7 +42,6 @@ import (
 // The udpgw protocol and original server implementation:
 // Copyright (c) 2009, Ambroz Bizjak <ambrop7@gmail.com>
 // https://github.com/ambrop72/badvpn
-//
 func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) {
 
 	// Accept this channel immediately. This channel will replace any
@@ -264,7 +263,7 @@ func (mux *udpgwPortForwardMultiplexer) run() {
 			// Can't defer lruEntry.Remove() here;
 			// relayDownstream will call lruEntry.Remove()
 
-			// ActivityMonitoredConn monitors the TCP port forward I/O and updates
+			// ActivityMonitoredConn monitors the UDP port forward I/O and updates
 			// its LRU status. ActivityMonitoredConn also times out I/O on the port
 			// forward if both reads and writes have been idle for the specified
 			// duration.