|
|
@@ -1376,6 +1376,13 @@ func (q *qualityMetrics) reset() {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+type handshakeStateInfo struct {
|
|
|
+ activeAuthorizationIDs []string
|
|
|
+ authorizedAccessTypes []string
|
|
|
+ upstreamBytesPerSecond int64
|
|
|
+ downstreamBytesPerSecond int64
|
|
|
+}
|
|
|
+
|
|
|
type handshakeState struct {
|
|
|
completed bool
|
|
|
apiProtocol string
|
|
|
@@ -1385,14 +1392,64 @@ type handshakeState struct {
|
|
|
authorizationsRevoked bool
|
|
|
expectDomainBytes bool
|
|
|
establishedTunnelsCount int
|
|
|
- splitTunnel bool
|
|
|
+ splitTunnelLookup *splitTunnelLookup
|
|
|
}
|
|
|
|
|
|
-type handshakeStateInfo struct {
|
|
|
- activeAuthorizationIDs []string
|
|
|
- authorizedAccessTypes []string
|
|
|
- upstreamBytesPerSecond int64
|
|
|
- downstreamBytesPerSecond int64
|
|
|
+type splitTunnelLookup struct {
|
|
|
+ regions []string
|
|
|
+ regionsLookup map[string]bool
|
|
|
+}
|
|
|
+
|
|
|
+func newSplitTunnelLookup(
|
|
|
+ ownRegion string,
|
|
|
+ otherRegions []string) (*splitTunnelLookup, error) {
|
|
|
+
|
|
|
+ length := len(otherRegions)
|
|
|
+ if ownRegion != "" {
|
|
|
+ length += 1
|
|
|
+ }
|
|
|
+
|
|
|
+ // This length check is a sanity check and prevents clients shipping
|
|
|
+ // excessively long lists which could impact performance.
|
|
|
+ if length > 250 {
|
|
|
+ return nil, errors.Tracef("too many regions: %d", length)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Create map lookups for lists where the number of values to compare
|
|
|
+ // against exceeds a threshold where benchmarks show maps are faster than
|
|
|
+ // looping through a slice. Otherwise use a slice for lookups. In both
|
|
|
+ // cases, the input slice is no longer referenced.
|
|
|
+
|
|
|
+ if length >= stringLookupThreshold {
|
|
|
+ regionsLookup := make(map[string]bool)
|
|
|
+ if ownRegion != "" {
|
|
|
+ regionsLookup[ownRegion] = true
|
|
|
+ }
|
|
|
+ for _, region := range otherRegions {
|
|
|
+ regionsLookup[region] = true
|
|
|
+ }
|
|
|
+ return &splitTunnelLookup{
|
|
|
+ regionsLookup: regionsLookup,
|
|
|
+ }, nil
|
|
|
+ } else {
|
|
|
+ regions := []string{}
|
|
|
+ if ownRegion != "" && !common.Contains(otherRegions, ownRegion) {
|
|
|
+ regions = append(regions, ownRegion)
|
|
|
+ }
|
|
|
+ // TODO: check for other duplicate regions?
|
|
|
+ regions = append(regions, otherRegions...)
|
|
|
+ return &splitTunnelLookup{
|
|
|
+ regions: regions,
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (lookup *splitTunnelLookup) lookup(region string) bool {
|
|
|
+ if lookup.regionsLookup != nil {
|
|
|
+ return lookup.regionsLookup[region]
|
|
|
+ } else {
|
|
|
+ return common.Contains(lookup.regions, region)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func newSshClient(
|
|
|
@@ -2520,11 +2577,13 @@ func (sshClient *sshClient) handleNewTCPPortForwardChannel(
|
|
|
// Split tunnel logic is enabled for this TCP port forward when the client
|
|
|
// has enabled split tunnel mode and the channel type allows it.
|
|
|
|
|
|
+ doSplitTunnel := sshClient.handshakeState.splitTunnelLookup != nil && allowSplitTunnel
|
|
|
+
|
|
|
tcpPortForward := &newTCPPortForward{
|
|
|
enqueueTime: time.Now(),
|
|
|
hostToConnect: directTcpipExtraData.HostToConnect,
|
|
|
portToConnect: int(directTcpipExtraData.PortToConnect),
|
|
|
- doSplitTunnel: sshClient.handshakeState.splitTunnel && allowSplitTunnel,
|
|
|
+ doSplitTunnel: doSplitTunnel,
|
|
|
newChannel: newChannel,
|
|
|
}
|
|
|
|
|
|
@@ -3816,8 +3875,9 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
|
|
|
destinationGeoIPData := sshClient.sshServer.support.GeoIPService.LookupIP(IP)
|
|
|
|
|
|
- if destinationGeoIPData.Country == sshClient.geoIPData.Country &&
|
|
|
- sshClient.geoIPData.Country != GEOIP_UNKNOWN_VALUE {
|
|
|
+ if sshClient.geoIPData.Country != GEOIP_UNKNOWN_VALUE &&
|
|
|
+ sshClient.handshakeState.splitTunnelLookup.lookup(
|
|
|
+ destinationGeoIPData.Country) {
|
|
|
|
|
|
// Since isPortForwardPermitted is not called in this case, explicitly call
|
|
|
// ipBlocklistCheck. The domain blocklist case is handled above.
|