Просмотр исходного кода

Monitor packet tunnel channel operations

- Explicitly log packet tunnel establishment success and
  probe SSH keep alive success to assist with diagnostic
  tracing.

- Add timeout to packet tunnel channel dial.
Rod Hynes 5 лет назад
Родитель
Сommit
7bc5f04050
2 измененных файлов с 88 добавлено и 57 удалено
  1. 1 1
      psiphon/packetTunnelTransport.go
  2. 87 56
      psiphon/tunnel.go

+ 1 - 1
psiphon/packetTunnelTransport.go

@@ -223,7 +223,7 @@ func (p *PacketTunnelTransport) UseTunnel(tunnel *Tunnel) {
 			// Note: DialPacketTunnelChannel will signal a probe on failure,
 			// so it's not necessary to do so here.
 
-			NoticeWarning("dial packet tunnel channel failed : %s", err)
+			NoticeWarning("dial packet tunnel channel failed: %s", err)
 			// TODO: retry?
 			return
 		}

+ 87 - 56
psiphon/tunnel.go

@@ -423,15 +423,58 @@ func (tunnel *Tunnel) SendAPIRequest(
 // Dial establishes a port forward connection through the tunnel
 // This Dial doesn't support split tunnel, so alwaysTunnel is not referenced
 func (tunnel *Tunnel) Dial(
-	remoteAddr string, alwaysTunnel bool, downstreamConn net.Conn) (conn net.Conn, err error) {
+	remoteAddr string, alwaysTunnel bool, downstreamConn net.Conn) (net.Conn, error) {
 
-	if !tunnel.IsActivated() {
-		return nil, errors.TraceNew("tunnel is not activated")
+	channel, err := tunnel.dialChannel("tcp", remoteAddr)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	netConn, ok := channel.(net.Conn)
+	if !ok {
+		return nil, errors.Tracef("unexpected channel type: %T", channel)
+	}
+
+	conn := &TunneledConn{
+		Conn:           netConn,
+		tunnel:         tunnel,
+		downstreamConn: downstreamConn}
+
+	return tunnel.wrapWithTransferStats(conn), nil
+}
+
+func (tunnel *Tunnel) DialPacketTunnelChannel() (net.Conn, error) {
+
+	channel, err := tunnel.dialChannel(protocol.PACKET_TUNNEL_CHANNEL_TYPE, "")
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	sshChannel, ok := channel.(ssh.Channel)
+	if !ok {
+		return nil, errors.Tracef("unexpected channel type: %T", channel)
 	}
 
-	type tunnelDialResult struct {
-		sshPortForwardConn net.Conn
-		err                error
+	NoticeInfo("DialPacketTunnelChannel: established channel")
+
+	conn := newChannelConn(sshChannel)
+
+	// wrapWithTransferStats will track bytes transferred for the
+	// packet tunnel. It will count packet overhead (TCP/UDP/IP headers).
+	//
+	// Since the data in the channel is not HTTP or TLS, no domain bytes
+	// counting is expected.
+	//
+	// transferstats are also used to determine that there's been recent
+	// activity and skip periodic SSH keep alives; see Tunnel.operateTunnel.
+
+	return tunnel.wrapWithTransferStats(conn), nil
+}
+
+func (tunnel *Tunnel) dialChannel(channelType, remoteAddr string) (interface{}, error) {
+
+	if !tunnel.IsActivated() {
+		return nil, errors.TraceNew("tunnel is not activated")
 	}
 
 	// Note: there is no dial context since SSH port forward dials cannot
@@ -439,26 +482,44 @@ func (tunnel *Tunnel) Dial(
 	// A timeout is set to unblock this function, but the goroutine may
 	// not exit until the tunnel is closed.
 
-	// Use a buffer of 1 as there are two senders and only one guaranteed receive.
+	type channelDialResult struct {
+		channel interface{}
+		err     error
+	}
 
-	resultChannel := make(chan *tunnelDialResult, 1)
+	// Use a buffer of 1 as there are two senders and only one guaranteed receive.
 
-	timeout := tunnel.getCustomClientParameters().Duration(
-		parameters.TunnelPortForwardDialTimeout)
+	results := make(chan *channelDialResult, 1)
 
 	afterFunc := time.AfterFunc(
-		timeout,
+		tunnel.getCustomClientParameters().Duration(
+			parameters.TunnelPortForwardDialTimeout),
 		func() {
-			resultChannel <- &tunnelDialResult{nil, errors.TraceNew("tunnel dial timeout")}
+			results <- &channelDialResult{
+				nil, errors.Tracef("channel dial timeout: %s", channelType)}
 		})
 	defer afterFunc.Stop()
 
 	go func() {
-		sshPortForwardConn, err := tunnel.sshClient.Dial("tcp", remoteAddr)
-		resultChannel <- &tunnelDialResult{sshPortForwardConn, err}
+		result := new(channelDialResult)
+		if channelType == "tcp" {
+			result.channel, result.err =
+				tunnel.sshClient.Dial("tcp", remoteAddr)
+		} else {
+			var sshRequests <-chan *ssh.Request
+			result.channel, sshRequests, result.err =
+				tunnel.sshClient.OpenChannel(channelType, nil)
+			if result.err == nil {
+				go ssh.DiscardRequests(sshRequests)
+			}
+		}
+		if result.err != nil {
+			result.err = errors.Trace(result.err)
+		}
+		results <- result
 	}()
 
-	result := <-resultChannel
+	result := <-results
 
 	if result.err != nil {
 		// TODO: conditional on type of error or error message?
@@ -469,44 +530,7 @@ func (tunnel *Tunnel) Dial(
 		return nil, errors.Trace(result.err)
 	}
 
-	conn = &TunneledConn{
-		Conn:           result.sshPortForwardConn,
-		tunnel:         tunnel,
-		downstreamConn: downstreamConn}
-
-	return tunnel.wrapWithTransferStats(conn), nil
-}
-
-func (tunnel *Tunnel) DialPacketTunnelChannel() (net.Conn, error) {
-
-	if !tunnel.IsActivated() {
-		return nil, errors.TraceNew("tunnel is not activated")
-	}
-	channel, requests, err := tunnel.sshClient.OpenChannel(
-		protocol.PACKET_TUNNEL_CHANNEL_TYPE, nil)
-	if err != nil {
-		// TODO: conditional on type of error or error message?
-		select {
-		case tunnel.signalPortForwardFailure <- struct{}{}:
-		default:
-		}
-
-		return nil, errors.Trace(err)
-	}
-	go ssh.DiscardRequests(requests)
-
-	conn := newChannelConn(channel)
-
-	// wrapWithTransferStats will track bytes transferred for the
-	// packet tunnel. It will count packet overhead (TCP/UDP/IP headers).
-	//
-	// Since the data in the channel is not HTTP or TLS, no domain bytes
-	// counting is expected.
-	//
-	// transferstats are also used to determine that there's been recent
-	// activity and skip periodic SSH keep alives; see Tunnel.operateTunnel.
-
-	return tunnel.wrapWithTransferStats(conn), nil
+	return result.channel, nil
 }
 
 func (tunnel *Tunnel) wrapWithTransferStats(conn net.Conn) net.Conn {
@@ -1152,7 +1176,7 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 			bytesUp := atomic.LoadInt64(&totalSent)
 			bytesDown := atomic.LoadInt64(&totalReceived)
 			err := tunnel.sendSshKeepAlive(
-				isFirstPeriodicKeepAlive, timeout, bytesUp, bytesDown)
+				isFirstPeriodicKeepAlive, false, timeout, bytesUp, bytesDown)
 			if err != nil {
 				select {
 				case sshKeepAliveError <- err:
@@ -1175,7 +1199,7 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 			bytesUp := atomic.LoadInt64(&totalSent)
 			bytesDown := atomic.LoadInt64(&totalReceived)
 			err := tunnel.sendSshKeepAlive(
-				false, timeout, bytesUp, bytesDown)
+				false, true, timeout, bytesUp, bytesDown)
 			if err != nil {
 				select {
 				case sshKeepAliveError <- err:
@@ -1361,6 +1385,7 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 // closed, which will terminate the associated tunnel.
 func (tunnel *Tunnel) sendSshKeepAlive(
 	isFirstPeriodicKeepAlive bool,
+	isProbeKeepAlive bool,
 	timeout time.Duration,
 	bytesUp int64,
 	bytesDown int64) error {
@@ -1412,6 +1437,12 @@ func (tunnel *Tunnel) sendSshKeepAlive(
 
 		errChannel <- err
 
+		success := (err == nil && requestOk)
+
+		if success && isProbeKeepAlive {
+			NoticeInfo("Probe SSH keep-alive RTT: %s", elapsedTime)
+		}
+
 		// Record the keep alive round trip as a speed test sample. The first
 		// periodic keep alive is always recorded, as many tunnels are short-lived
 		// and we want to ensure that some data is gathered. Subsequent keep alives
@@ -1419,7 +1450,7 @@ func (tunnel *Tunnel) sendSshKeepAlive(
 		// only the last SpeedTestMaxSampleCount samples are retained, enables
 		// tuning the sampling frequency.
 
-		if err == nil && requestOk && speedTestSample {
+		if success && speedTestSample {
 
 			err = tactics.AddSpeedTestSample(
 				tunnel.config.GetClientParameters(),