Просмотр исходного кода

Fix: apply throttling net.Conn layer early enough to take effect

Rod Hynes 9 лет назад
Родитель
Сommit
b3d6ea7adc
1 измененных файлов с 17 добавлено и 13 удалено
  1. 17 13
      psiphon/tunnel.go

+ 17 - 13
psiphon/tunnel.go

@@ -124,20 +124,17 @@ func EstablishTunnel(
 	}
 
 	// Build transport layers and establish SSH connection
-	dialConn, sshClient, dialStats, err := dialSsh(
+	conn, sshClient, dialStats, err := dialSsh(
 		config, pendingConns, serverEntry, selectedProtocol, sessionId)
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
 
-	// Apply throttling (if configured)
-	conn := common.NewThrottledConn(dialConn, config.RateLimits)
-
 	// Cleanup on error
 	defer func() {
 		if err != nil {
 			sshClient.Close()
-			dialConn.Close()
+			conn.Close()
 		}
 	}()
 
@@ -178,7 +175,7 @@ func EstablishTunnel(
 	tunnel.startTime = time.Now()
 
 	// Now that network operations are complete, cancel interruptibility
-	pendingConns.Remove(dialConn)
+	pendingConns.Remove(conn)
 
 	// Spawn the operateTunnel goroutine, which monitors the tunnel and handles periodic stats updates.
 	tunnel.operateWaitGroup.Add(1)
@@ -589,20 +586,20 @@ func dialSsh(
 		DeviceRegion:                  config.DeviceRegion,
 		ResolvedIPCallback:            setResolvedIPAddress,
 	}
-	var conn net.Conn
+	var dialConn net.Conn
 	if meekConfig != nil {
-		conn, err = DialMeek(meekConfig, dialConfig)
+		dialConn, err = DialMeek(meekConfig, dialConfig)
 		if err != nil {
 			return nil, nil, nil, common.ContextError(err)
 		}
 	} else {
-		conn, err = DialTCP(directTCPDialAddress, dialConfig)
+		dialConn, err = DialTCP(directTCPDialAddress, dialConfig)
 		if err != nil {
 			return nil, nil, nil, common.ContextError(err)
 		}
 	}
 
-	cleanupConn := conn
+	cleanupConn := dialConn
 	defer func() {
 		// Cleanup on error
 		if cleanupConn != nil {
@@ -610,11 +607,14 @@ func dialSsh(
 		}
 	}()
 
+	// Apply throttling (if configured)
+	throttledConn := common.NewThrottledConn(dialConn, config.RateLimits)
+
 	// Add obfuscated SSH layer
-	sshConn := conn
+	var sshConn net.Conn = throttledConn
 	if useObfuscatedSsh {
 		sshConn, err = NewObfuscatedSshConn(
-			OBFUSCATION_CONN_MODE_CLIENT, conn, serverEntry.SshObfuscatedKey)
+			OBFUSCATION_CONN_MODE_CLIENT, throttledConn, serverEntry.SshObfuscatedKey)
 		if err != nil {
 			return nil, nil, nil, common.ContextError(err)
 		}
@@ -720,7 +720,11 @@ func dialSsh(
 
 	cleanupConn = nil
 
-	return conn, result.sshClient, dialStats, nil
+	// Note: dialConn may be used to close the underlying network connection
+	// but should not be used to perform I/O as that would interfere with SSH
+	// (and also bypasses throttling).
+
+	return dialConn, result.sshClient, dialStats, nil
 }
 
 func makeRandomPeriod(min, max time.Duration) time.Duration {