Procházet zdrojové kódy

Merge pull request #155 from rod-hynes/master

Impaired protocol test case and fixes
Rod Hynes před 10 roky
rodič
revize
9d4d3dcc8e
3 změnil soubory, kde provedl 260 přidání a 72 odebrání
  1. 69 43
      psiphon/controller.go
  2. 186 29
      psiphon/controller_test.go
  3. 5 0
      psiphon/notice.go

+ 69 - 43
psiphon/controller.go

@@ -509,50 +509,70 @@ loop:
 				controller.startEstablishing()
 			}
 
-		// !TODO! design issue: might not be enough server entries with region/caps to ever fill tunnel slots
-		// solution(?) target MIN(CountServerEntries(region, protocol), TunnelPoolSize)
 		case establishedTunnel := <-controller.establishedTunnels:
-			tunnelCount, registered := controller.registerTunnel(establishedTunnel)
-			if registered {
-				NoticeActiveTunnel(establishedTunnel.serverEntry.IpAddress, establishedTunnel.protocol)
-
-				if tunnelCount == 1 {
-
-					// The split tunnel classifier is started once the first tunnel is
-					// established. This first tunnel is passed in to be used to make
-					// the routes data request.
-					// A long-running controller may run while the host device is present
-					// in different regions. In this case, we want the split tunnel logic
-					// to switch to routes for new regions and not classify traffic based
-					// on routes installed for older regions.
-					// We assume that when regions change, the host network will also
-					// change, and so all tunnels will fail and be re-established. Under
-					// that assumption, the classifier will be re-Start()-ed here when
-					// the region has changed.
-					controller.splitTunnelClassifier.Start(establishedTunnel)
-
-					// Signal a connected request on each 1st tunnel establishment. For
-					// multi-tunnels, the session is connected as long as at least one
-					// tunnel is established.
-					controller.startOrSignalConnectedReporter()
-
-					// If the handshake indicated that a new client version is available,
-					// trigger an upgrade download.
-					// Note: serverContext is nil when DisableApi is set
-					if establishedTunnel.serverContext != nil &&
-						establishedTunnel.serverContext.clientUpgradeVersion != "" {
-
-						handshakeVersion := establishedTunnel.serverContext.clientUpgradeVersion
-						select {
-						case controller.signalDownloadUpgrade <- handshakeVersion:
-						default:
-						}
-					}
+
+			if controller.isImpairedProtocol(establishedTunnel.protocol) {
+
+				NoticeAlert("established tunnel with impaired protocol: %s", establishedTunnel.protocol)
+
+				// Protocol was classified as impaired while this tunnel
+				// established, so discard.
+				controller.discardTunnel(establishedTunnel)
+
+				// Reset establish generator to stop producing tunnels
+				// with impaired protocols.
+				if controller.isEstablishing {
+					controller.stopEstablishing()
+					controller.startEstablishing()
 				}
+				break
+			}
 
-			} else {
+			tunnelCount, registered := controller.registerTunnel(establishedTunnel)
+			if !registered {
+				// Already fully established, so discard.
 				controller.discardTunnel(establishedTunnel)
+				break
+			}
+
+			NoticeActiveTunnel(establishedTunnel.serverEntry.IpAddress, establishedTunnel.protocol)
+
+			if tunnelCount == 1 {
+
+				// The split tunnel classifier is started once the first tunnel is
+				// established. This first tunnel is passed in to be used to make
+				// the routes data request.
+				// A long-running controller may run while the host device is present
+				// in different regions. In this case, we want the split tunnel logic
+				// to switch to routes for new regions and not classify traffic based
+				// on routes installed for older regions.
+				// We assume that when regions change, the host network will also
+				// change, and so all tunnels will fail and be re-established. Under
+				// that assumption, the classifier will be re-Start()-ed here when
+				// the region has changed.
+				controller.splitTunnelClassifier.Start(establishedTunnel)
+
+				// Signal a connected request on each 1st tunnel establishment. For
+				// multi-tunnels, the session is connected as long as at least one
+				// tunnel is established.
+				controller.startOrSignalConnectedReporter()
+
+				// If the handshake indicated that a new client version is available,
+				// trigger an upgrade download.
+				// Note: serverContext is nil when DisableApi is set
+				if establishedTunnel.serverContext != nil &&
+					establishedTunnel.serverContext.clientUpgradeVersion != "" {
+
+					handshakeVersion := establishedTunnel.serverContext.clientUpgradeVersion
+					select {
+					case controller.signalDownloadUpgrade <- handshakeVersion:
+					default:
+					}
+				}
 			}
+
+			// TODO: design issue -- might not be enough server entries with region/caps to ever fill tunnel slots;
+			// possible solution is establish target MIN(CountServerEntries(region, protocol), TunnelPoolSize)
 			if controller.isFullyEstablished() {
 				controller.stopEstablishing()
 			}
@@ -612,9 +632,7 @@ func (controller *Controller) classifyImpairedProtocol(failedTunnel *Tunnel) {
 //
 // Concurrency note: only the runTunnels() goroutine may call getImpairedProtocols
 func (controller *Controller) getImpairedProtocols() []string {
-	if len(controller.impairedProtocolClassification) > 0 {
-		NoticeInfo("impaired protocols: %+v", controller.impairedProtocolClassification)
-	}
+	NoticeImpairedProtocolClassification(controller.impairedProtocolClassification)
 	impairedProtocols := make([]string, 0)
 	for protocol, count := range controller.impairedProtocolClassification {
 		if count >= IMPAIRED_PROTOCOL_CLASSIFICATION_THRESHOLD {
@@ -624,6 +642,14 @@ func (controller *Controller) getImpairedProtocols() []string {
 	return impairedProtocols
 }
 
+// isImpairedProtocol checks if the specified protocol is classified as impaired.
+//
+// Concurrency note: only the runTunnels() goroutine may call isImpairedProtocol
+func (controller *Controller) isImpairedProtocol(protocol string) bool {
+	count, ok := controller.impairedProtocolClassification[protocol]
+	return ok && count >= IMPAIRED_PROTOCOL_CLASSIFICATION_THRESHOLD
+}
+
 // SignalTunnelFailure implements the TunnelOwner interface. This function
 // is called by Tunnel.operateTunnel when the tunnel has detected that it
 // has failed. The Controller will signal runTunnels to create a new
@@ -798,7 +824,7 @@ func (controller *Controller) Dial(
 		// relative to the outbound network.
 
 		if controller.splitTunnelClassifier.IsUntunneled(host) {
-			// !TODO! track downstreamConn and close it when the DialTCP conn closes, as with tunnel.Dial conns?
+			// TODO: track downstreamConn and close it when the DialTCP conn closes, as with tunnel.Dial conns?
 			return DialTCP(remoteAddr, controller.untunneledDialConfig)
 		}
 	}

+ 186 - 29
psiphon/controller_test.go

@@ -58,8 +58,10 @@ func TestUntunneledUpgradeDownload(t *testing.T) {
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: false,
 			disableEstablishing:      true,
+			tunnelPoolSize:           1,
 			disruptNetwork:           false,
 			useHostNameTransformer:   false,
+			runDuration:              0,
 		})
 }
 
@@ -70,8 +72,10 @@ func TestUntunneledResumableUpgradeDownload(t *testing.T) {
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: false,
 			disableEstablishing:      true,
+			tunnelPoolSize:           1,
 			disruptNetwork:           true,
 			useHostNameTransformer:   false,
+			runDuration:              0,
 		})
 }
 
@@ -82,8 +86,10 @@ func TestUntunneledUpgradeClientIsLatestVersion(t *testing.T) {
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: false,
 			disableEstablishing:      true,
+			tunnelPoolSize:           1,
 			disruptNetwork:           false,
 			useHostNameTransformer:   false,
+			runDuration:              0,
 		})
 }
 
@@ -94,116 +100,158 @@ func TestTunneledUpgradeClientIsLatestVersion(t *testing.T) {
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
+			tunnelPoolSize:           1,
 			disruptNetwork:           false,
 			useHostNameTransformer:   false,
+			runDuration:              0,
 		})
 }
 
-func TestControllerRunSSH(t *testing.T) {
+func TestImpairedProtocols(t *testing.T) {
+
+	// This test sets a tunnelPoolSize of 40 and runs
+	// the session for 1 minute with network disruption
+	// on. All 40 tunnels being disrupted every 10
+	// seconds (followed by ssh keep alive probe timeout)
+	// should be sufficient to trigger at least one
+	// impaired protocol classification.
+
+	controllerRun(t,
+		&controllerRunConfig{
+			protocol:                 "",
+			clientIsLatestVersion:    true,
+			disableUntunneledUpgrade: true,
+			disableEstablishing:      false,
+			tunnelPoolSize:           40,
+			disruptNetwork:           true,
+			useHostNameTransformer:   false,
+			runDuration:              1 * time.Minute,
+		})
+}
+
+func TestSSH(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			protocol:                 TUNNEL_PROTOCOL_SSH,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
+			tunnelPoolSize:           1,
 			disruptNetwork:           false,
 			useHostNameTransformer:   false,
+			runDuration:              0,
 		})
 }
 
-func TestControllerRunObfuscatedSSH(t *testing.T) {
+func TestObfuscatedSSH(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			protocol:                 TUNNEL_PROTOCOL_OBFUSCATED_SSH,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
+			tunnelPoolSize:           1,
 			disruptNetwork:           false,
 			useHostNameTransformer:   false,
+			runDuration:              0,
 		})
 }
 
-func TestControllerRunUnfrontedMeek(t *testing.T) {
+func TestUnfrontedMeek(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
+			tunnelPoolSize:           1,
 			disruptNetwork:           false,
 			useHostNameTransformer:   false,
+			runDuration:              0,
 		})
 }
 
-func TestControllerRunUnfrontedMeekWithTransformer(t *testing.T) {
+func TestUnfrontedMeekWithTransformer(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK,
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
+			tunnelPoolSize:           1,
 			disruptNetwork:           false,
 			useHostNameTransformer:   true,
+			runDuration:              0,
 		})
 }
 
-func TestControllerRunFrontedMeek(t *testing.T) {
+func TestFrontedMeek(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			protocol:                 TUNNEL_PROTOCOL_FRONTED_MEEK,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
+			tunnelPoolSize:           1,
 			disruptNetwork:           false,
 			useHostNameTransformer:   false,
+			runDuration:              0,
 		})
 }
 
-func TestControllerRunFrontedMeekWithTransformer(t *testing.T) {
+func TestFrontedMeekWithTransformer(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			protocol:                 TUNNEL_PROTOCOL_FRONTED_MEEK,
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
+			tunnelPoolSize:           1,
 			disruptNetwork:           false,
 			useHostNameTransformer:   true,
+			runDuration:              0,
 		})
 }
 
-func TestControllerFrontedMeekHTTP(t *testing.T) {
+func TestFrontedMeekHTTP(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			protocol:                 TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP,
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
+			tunnelPoolSize:           1,
 			disruptNetwork:           false,
 			useHostNameTransformer:   false,
+			runDuration:              0,
 		})
 }
 
-func TestControllerRunUnfrontedMeekHTTPS(t *testing.T) {
+func TestUnfrontedMeekHTTPS(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
+			tunnelPoolSize:           1,
 			disruptNetwork:           false,
 			useHostNameTransformer:   false,
+			runDuration:              0,
 		})
 }
 
-func TestControllerRunUnfrontedMeekHTTPSWithTransformer(t *testing.T) {
+func TestUnfrontedMeekHTTPSWithTransformer(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
+			tunnelPoolSize:           1,
 			disruptNetwork:           false,
 			useHostNameTransformer:   true,
+			runDuration:              0,
 		})
 }
 
@@ -212,8 +260,10 @@ type controllerRunConfig struct {
 	clientIsLatestVersion    bool
 	disableUntunneledUpgrade bool
 	disableEstablishing      bool
+	tunnelPoolSize           int
 	disruptNetwork           bool
 	useHostNameTransformer   bool
+	runDuration              time.Duration
 }
 
 func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
@@ -238,6 +288,8 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 		config.RemoteServerListUrl = ""
 	}
 
+	config.TunnelPoolSize = runConfig.tunnelPoolSize
+
 	if runConfig.disableUntunneledUpgrade {
 		// Disable untunneled upgrade downloader to ensure tunneled case is tested
 		config.UpgradeDownloadClientVersionHeader = ""
@@ -277,6 +329,11 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 	confirmedLatestVersion := make(chan struct{}, 1)
 
 	var clientUpgradeDownloadedBytesCount int32
+	var impairedProtocolCount int32
+	var impairedProtocolClassification = struct {
+		sync.RWMutex
+		classification map[string]int
+	}{classification: make(map[string]int)}
 
 	SetNoticeOutput(NewNoticeReceiver(
 		func(notice []byte) {
@@ -287,7 +344,21 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 				return
 			}
 			switch noticeType {
+
+			case "ListeningHttpProxyPort":
+
+				httpProxyPort = int(payload["port"].(float64))
+
+			case "ConnectingServer":
+
+				serverProtocol := payload["protocol"].(string)
+				if runConfig.protocol != "" && serverProtocol != runConfig.protocol {
+					// TODO: wrong goroutine for t.FatalNow()
+					t.Fatalf("wrong protocol selected: %s", serverProtocol)
+				}
+
 			case "Tunnels":
+
 				count := int(payload["count"].(float64))
 				if count > 0 {
 					if runConfig.disableEstablishing {
@@ -300,27 +371,59 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 						}
 					}
 				}
+
 			case "ClientUpgradeDownloadedBytes":
+
 				atomic.AddInt32(&clientUpgradeDownloadedBytesCount, 1)
 				t.Logf("ClientUpgradeDownloadedBytes: %d", int(payload["bytes"].(float64)))
+
 			case "ClientUpgradeDownloaded":
+
 				select {
 				case upgradeDownloaded <- *new(struct{}):
 				default:
 				}
+
 			case "ClientIsLatestVersion":
+
 				select {
 				case confirmedLatestVersion <- *new(struct{}):
 				default:
 				}
-			case "ListeningHttpProxyPort":
-				httpProxyPort = int(payload["port"].(float64))
-			case "ConnectingServer":
-				serverProtocol := payload["protocol"]
-				if runConfig.protocol != "" && serverProtocol != runConfig.protocol {
+
+			case "ImpairedProtocolClassification":
+
+				classification := payload["classification"].(map[string]interface{})
+
+				impairedProtocolClassification.Lock()
+				impairedProtocolClassification.classification = make(map[string]int)
+				for k, v := range classification {
+					count := int(v.(float64))
+					if count >= IMPAIRED_PROTOCOL_CLASSIFICATION_THRESHOLD {
+						atomic.AddInt32(&impairedProtocolCount, 1)
+					}
+					impairedProtocolClassification.classification[k] = count
+				}
+				impairedProtocolClassification.Unlock()
+
+			case "ActiveTunnel":
+
+				serverProtocol := payload["protocol"].(string)
+
+				classification := make(map[string]int)
+				impairedProtocolClassification.RLock()
+				for k, v := range impairedProtocolClassification.classification {
+					classification[k] = v
+				}
+				impairedProtocolClassification.RUnlock()
+
+				count, ok := classification[serverProtocol]
+				if ok && count >= IMPAIRED_PROTOCOL_CLASSIFICATION_THRESHOLD {
 					// TODO: wrong goroutine for t.FatalNow()
-					t.Fatalf("wrong protocol selected: %s", serverProtocol)
+					t.Fatalf("unexpected tunnel using impaired protocol: %s, %+v",
+						serverProtocol, classification)
 				}
+
 			}
 		}))
 
@@ -335,11 +438,11 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 	}()
 
 	defer func() {
-		// Test: shutdown must complete within 10 seconds
+		// Test: shutdown must complete within 20 seconds
 
 		close(shutdownBroadcast)
 
-		shutdownTimeout := time.NewTimer(10 * time.Second)
+		shutdownTimeout := time.NewTimer(20 * time.Second)
 
 		shutdownOk := make(chan struct{}, 1)
 		go func() {
@@ -371,7 +474,38 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 
 		// Allow for known race condition described in NewHttpProxy():
 		time.Sleep(1 * time.Second)
-		fetchWebsite(t, httpProxyPort)
+
+		fetchAndVerifyWebsite(t, httpProxyPort)
+
+		// Test: run for duration, periodically using the tunnel to
+		// ensure failed tunnel detection, and ultimately hitting
+		// impaired protocol checks.
+
+		startTime := time.Now()
+
+		for {
+
+			time.Sleep(1 * time.Second)
+			useTunnel(t, httpProxyPort)
+
+			if startTime.Add(runConfig.runDuration).Before(time.Now()) {
+				break
+			}
+		}
+
+		// Test: with disruptNetwork, impaired protocols should be exercised
+
+		if runConfig.runDuration > 0 && runConfig.disruptNetwork {
+			count := atomic.LoadInt32(&impairedProtocolCount)
+			if count <= 0 {
+				t.Fatalf("unexpected impaired protocol count: %d", count)
+			} else {
+				impairedProtocolClassification.RLock()
+				t.Logf("impaired protocol classification: %+v",
+					impairedProtocolClassification.classification)
+				impairedProtocolClassification.RUnlock()
+			}
+		}
 	}
 
 	// Test: upgrade check/download must be downloaded within 120 seconds
@@ -385,6 +519,15 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 			t.Fatalf("upgrade downloaded unexpectedly")
 		}
 
+		// Test: with disruptNetwork, must be multiple download progress notices
+
+		if runConfig.disruptNetwork {
+			count := atomic.LoadInt32(&clientUpgradeDownloadedBytesCount)
+			if count <= 1 {
+				t.Fatalf("unexpected upgrade download progress: %d", count)
+			}
+		}
+
 	case <-confirmedLatestVersion:
 		if !runConfig.clientIsLatestVersion {
 			t.Fatalf("confirmed latest version unexpectedly")
@@ -393,15 +536,6 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 	case <-upgradeTimeout.C:
 		t.Fatalf("upgrade download timeout exceeded")
 	}
-
-	// Test: with disruptNetwork, must be multiple download progress notices
-
-	if runConfig.disruptNetwork && !runConfig.clientIsLatestVersion {
-		count := atomic.LoadInt32(&clientUpgradeDownloadedBytesCount)
-		if count <= 1 {
-			t.Fatalf("unexpected upgrade download progress: %d", count)
-		}
-	}
 }
 
 type TestHostNameTransformer struct {
@@ -411,7 +545,7 @@ func (TestHostNameTransformer) TransformHostName(string) (string, bool) {
 	return "example.com", true
 }
 
-func fetchWebsite(t *testing.T, httpProxyPort int) {
+func fetchAndVerifyWebsite(t *testing.T, httpProxyPort int) {
 
 	testUrl := "https://raw.githubusercontent.com/Psiphon-Labs/psiphon-tunnel-core/master/LICENSE"
 	roundTripTimeout := 10 * time.Second
@@ -494,10 +628,33 @@ func fetchWebsite(t *testing.T, httpProxyPort int) {
 	}
 }
 
+func useTunnel(t *testing.T, httpProxyPort int) {
+
+	// No action on errors as the tunnel is expected to fail sometimes
+
+	testUrl := "https://psiphon3.com"
+	roundTripTimeout := 1 * time.Second
+	proxyUrl, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", httpProxyPort))
+	if err != nil {
+		return
+	}
+	httpClient := &http.Client{
+		Transport: &http.Transport{
+			Proxy: http.ProxyURL(proxyUrl),
+		},
+		Timeout: roundTripTimeout,
+	}
+	response, err := httpClient.Get(testUrl)
+	if err != nil {
+		return
+	}
+	response.Body.Close()
+}
+
 const disruptorProxyAddress = "127.0.0.1:2160"
 const disruptorProxyURL = "socks4a://" + disruptorProxyAddress
 const disruptorMaxConnectionBytes = 2000000
-const disruptorMaxConnectionTime = 15 * time.Second
+const disruptorMaxConnectionTime = 10 * time.Second
 
 func initDisruptor() {
 

+ 5 - 0
psiphon/notice.go

@@ -214,6 +214,11 @@ func NoticeTunnels(count int) {
 	outputNotice("Tunnels", false, false, "count", count)
 }
 
+func NoticeImpairedProtocolClassification(impairedProtocolClassification map[string]int) {
+	outputNotice("ImpairedProtocolClassification", false, false,
+		"classification", impairedProtocolClassification)
+}
+
 // NoticeUntunneled indicates than an address has been classified as untunneled and is being
 // accessed directly.
 //