Browse Source

Process random streams concurrently

Rod Hynes 4 years ago
parent
commit
20fb618d2f
2 changed files with 54 additions and 20 deletions
  1. 18 8
      psiphon/server/tunnelServer.go
  2. 36 12
      psiphon/tunnel.go

+ 18 - 8
psiphon/server/tunnelServer.go

@@ -2161,7 +2161,7 @@ func (sshClient *sshClient) handleNewRandomStreamChannel(
 	// is available pre-handshake, albeit with additional restrictions.
 	//
 	// The random stream is subject to throttling in traffic rules; for
-	// unthrottled liveness tests, set initial   Read/WriteUnthrottledBytes as
+	// unthrottled liveness tests, set initial Read/WriteUnthrottledBytes as
 	// required. The random stream maximum count and response size cap
 	// mitigate clients abusing the facility to waste server resources.
 	//
@@ -2228,18 +2228,26 @@ func (sshClient *sshClient) handleNewRandomStreamChannel(
 	go func() {
 		defer waitGroup.Done()
 
+		upstream := new(sync.WaitGroup)
 		received := 0
 		sent := 0
 
 		if request.UpstreamBytes > 0 {
-			n, err := io.CopyN(ioutil.Discard, channel, int64(request.UpstreamBytes))
-			received = int(n)
-			if err != nil {
-				if !isExpectedTunnelIOError(err) {
-					log.WithTraceFields(LogFields{"error": err}).Warning("receive failed")
+
+			// Process streams concurrently to minimize elapsed time. This also
+			// avoids a unidirectional flow burst early in the tunnel lifecycle.
+
+			upstream.Add(1)
+			go func() {
+				defer upstream.Done()
+				n, err := io.CopyN(ioutil.Discard, channel, int64(request.UpstreamBytes))
+				received = int(n)
+				if err != nil {
+					if !isExpectedTunnelIOError(err) {
+						log.WithTraceFields(LogFields{"error": err}).Warning("receive failed")
+					}
 				}
-				// Fall through and record any bytes received...
-			}
+			}()
 		}
 
 		if request.DownstreamBytes > 0 {
@@ -2252,6 +2260,8 @@ func (sshClient *sshClient) handleNewRandomStreamChannel(
 			}
 		}
 
+		upstream.Wait()
+
 		sshClient.Lock()
 		metrics.upstreamBytes += request.UpstreamBytes
 		metrics.receivedUpstreamBytes += received

+ 36 - 12
psiphon/tunnel.go

@@ -1183,28 +1183,52 @@ func performLivenessTest(
 
 	go ssh.DiscardRequests(requests)
 
-	// In consideration of memory-constrained environments, use a modest-sized
-	// copy buffer since many tunnel establishment workers may run the
-	// liveness test concurrently.
-
-	var buffer [8192]byte
+	sent := 0
+	received := 0
+	upstream := new(sync.WaitGroup)
+	var errUpstream, errDownstream error
 
 	if metrics.UpstreamBytes > 0 {
-		n, err := common.CopyNBuffer(channel, rand.Reader, int64(metrics.UpstreamBytes), buffer[:])
-		metrics.SentUpstreamBytes = int(n)
-		if err != nil {
-			return metrics, errors.Trace(err)
-		}
+
+		// Process streams concurrently to minimize elapsed time. This also
+		// avoids a unidirectional flow burst early in the tunnel lifecycle.
+
+		upstream.Add(1)
+		go func() {
+			defer upstream.Done()
+
+			// In consideration of memory-constrained environments, use modest-sized copy
+			// buffers since many tunnel establishment workers may run the liveness test
+			// concurrently.
+			var buffer [4096]byte
+
+			n, err := common.CopyNBuffer(channel, rand.Reader, int64(metrics.UpstreamBytes), buffer[:])
+			sent = int(n)
+			if err != nil {
+				errUpstream = errors.Trace(err)
+			}
+		}()
 	}
 
 	if metrics.DownstreamBytes > 0 {
+		var buffer [4096]byte
 		n, err := common.CopyNBuffer(ioutil.Discard, channel, int64(metrics.DownstreamBytes), buffer[:])
-		metrics.ReceivedDownstreamBytes = int(n)
+		received = int(n)
 		if err != nil {
-			return metrics, errors.Trace(err)
+			errDownstream = errors.Trace(err)
 		}
 	}
 
+	upstream.Wait()
+	metrics.SentUpstreamBytes = sent
+	metrics.ReceivedDownstreamBytes = received
+
+	if errUpstream != nil {
+		return metrics, errUpstream
+	} else if errDownstream != nil {
+		return metrics, errDownstream
+	}
+
 	return metrics, nil
 }