Procházet zdrojové kódy

Readability: disentangle normal and split tunnel mode code paths

Rod Hynes před 5 roky
rodič
revize
9216f7a446
1 změnil soubory, kde provedl 29 přidání a 25 odebrání
  1. 29 25
      psiphon/controller.go

+ 29 - 25
psiphon/controller.go

@@ -1232,6 +1232,22 @@ func (controller *Controller) Dial(
 		return nil, errors.TraceNew("no active tunnels")
 	}
 
+	if !controller.config.EnableSplitTunnel {
+
+		tunneledConn, splitTunnel, err := tunnel.DialTCPChannel(
+			remoteAddr, false, downstreamConn)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		if splitTunnel {
+			return nil, errors.TraceNew(
+				"unexpected split tunnel classification")
+		}
+
+		return tunneledConn, nil
+	}
+
 	// In split tunnel mode, TCP port forwards to destinations in the same
 	// country as the client are untunneled.
 	//
@@ -1255,22 +1271,17 @@ func (controller *Controller) Dial(
 	// it does for all port forwards in non-split tunnel mode. There is no
 	// additional round trip for tunneled port forwards.
 
-	untunneledCache := controller.untunneledSplitTunnelClassifications
-	var splitTunnelHost string
-	cachedUntunneled := false
+	splitTunnelHost, _, err := net.SplitHostPort(remoteAddr)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
 
-	if controller.config.EnableSplitTunnel {
-		var err error
-		splitTunnelHost, _, err = net.SplitHostPort(remoteAddr)
-		if err != nil {
-			return nil, errors.Trace(err)
-		}
+	untunneledCache := controller.untunneledSplitTunnelClassifications
 
-		// If the destination hostname is in the untunneled split tunnel
-		// classifications cache, skip the round trip to the server and do the
-		// direct, untunneled dial immediately.
-		_, cachedUntunneled = untunneledCache.Get(splitTunnelHost)
-	}
+	// If the destination hostname is in the untunneled split tunnel
+	// classifications cache, skip the round trip to the server and do the
+	// direct, untunneled dial immediately.
+	_, cachedUntunneled := untunneledCache.Get(splitTunnelHost)
 
 	if !cachedUntunneled {
 
@@ -1282,21 +1293,13 @@ func (controller *Controller) Dial(
 
 		if !splitTunnel {
 
-			if controller.config.EnableSplitTunnel {
-
-				// Clear any cached untunneled classification entry for this destination
-				// hostname, as the server is now classifying it as tunneled.
-				untunneledCache.Delete(splitTunnelHost)
-			}
+			// Clear any cached untunneled classification entry for this destination
+			// hostname, as the server is now classifying it as tunneled.
+			untunneledCache.Delete(splitTunnelHost)
 
 			return tunneledConn, nil
 		}
 
-		if !controller.config.EnableSplitTunnel {
-			return nil, errors.TraceNew(
-				"unexpected split tunnel classification")
-		}
-
 		// The server has indicated that the client should make a direct,
 		// untunneled dial. Cache the classification to avoid this round trip in
 		// the immediate future.
@@ -1309,6 +1312,7 @@ func (controller *Controller) Dial(
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
+
 	return untunneledConn, nil
 }