|
|
@@ -1232,6 +1232,22 @@ func (controller *Controller) Dial(
|
|
|
return nil, errors.TraceNew("no active tunnels")
|
|
|
}
|
|
|
|
|
|
+ if !controller.config.EnableSplitTunnel {
|
|
|
+
|
|
|
+ tunneledConn, splitTunnel, err := tunnel.DialTCPChannel(
|
|
|
+ remoteAddr, false, downstreamConn)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if splitTunnel {
|
|
|
+ return nil, errors.TraceNew(
|
|
|
+ "unexpected split tunnel classification")
|
|
|
+ }
|
|
|
+
|
|
|
+ return tunneledConn, nil
|
|
|
+ }
|
|
|
+
|
|
|
// In split tunnel mode, TCP port forwards to destinations in the same
|
|
|
// country as the client are untunneled.
|
|
|
//
|
|
|
@@ -1255,22 +1271,17 @@ func (controller *Controller) Dial(
|
|
|
// it does for all port forwards in non-split tunnel mode. There is no
|
|
|
// additional round trip for tunneled port forwards.
|
|
|
|
|
|
- untunneledCache := controller.untunneledSplitTunnelClassifications
|
|
|
- var splitTunnelHost string
|
|
|
- cachedUntunneled := false
|
|
|
+ splitTunnelHost, _, err := net.SplitHostPort(remoteAddr)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Trace(err)
|
|
|
+ }
|
|
|
|
|
|
- if controller.config.EnableSplitTunnel {
|
|
|
- var err error
|
|
|
- splitTunnelHost, _, err = net.SplitHostPort(remoteAddr)
|
|
|
- if err != nil {
|
|
|
- return nil, errors.Trace(err)
|
|
|
- }
|
|
|
+ untunneledCache := controller.untunneledSplitTunnelClassifications
|
|
|
|
|
|
- // If the destination hostname is in the untunneled split tunnel
|
|
|
- // classifications cache, skip the round trip to the server and do the
|
|
|
- // direct, untunneled dial immediately.
|
|
|
- _, cachedUntunneled = untunneledCache.Get(splitTunnelHost)
|
|
|
- }
|
|
|
+ // If the destination hostname is in the untunneled split tunnel
|
|
|
+ // classifications cache, skip the round trip to the server and do the
|
|
|
+ // direct, untunneled dial immediately.
|
|
|
+ _, cachedUntunneled := untunneledCache.Get(splitTunnelHost)
|
|
|
|
|
|
if !cachedUntunneled {
|
|
|
|
|
|
@@ -1282,21 +1293,13 @@ func (controller *Controller) Dial(
|
|
|
|
|
|
if !splitTunnel {
|
|
|
|
|
|
- if controller.config.EnableSplitTunnel {
|
|
|
-
|
|
|
- // Clear any cached untunneled classification entry for this destination
|
|
|
- // hostname, as the server is now classifying it as tunneled.
|
|
|
- untunneledCache.Delete(splitTunnelHost)
|
|
|
- }
|
|
|
+ // Clear any cached untunneled classification entry for this destination
|
|
|
+ // hostname, as the server is now classifying it as tunneled.
|
|
|
+ untunneledCache.Delete(splitTunnelHost)
|
|
|
|
|
|
return tunneledConn, nil
|
|
|
}
|
|
|
|
|
|
- if !controller.config.EnableSplitTunnel {
|
|
|
- return nil, errors.TraceNew(
|
|
|
- "unexpected split tunnel classification")
|
|
|
- }
|
|
|
-
|
|
|
// The server has indicated that the client should make a direct,
|
|
|
// untunneled dial. Cache the classification to avoid this round trip in
|
|
|
// the immediate future.
|
|
|
@@ -1309,6 +1312,7 @@ func (controller *Controller) Dial(
|
|
|
if err != nil {
|
|
|
return nil, errors.Trace(err)
|
|
|
}
|
|
|
+
|
|
|
return untunneledConn, nil
|
|
|
}
|
|
|
|