|
|
@@ -423,15 +423,58 @@ func (tunnel *Tunnel) SendAPIRequest(
|
|
|
// Dial establishes a port forward connection through the tunnel
|
|
|
// This Dial doesn't support split tunnel, so alwaysTunnel is not referenced
|
|
|
func (tunnel *Tunnel) Dial(
|
|
|
- remoteAddr string, alwaysTunnel bool, downstreamConn net.Conn) (conn net.Conn, err error) {
|
|
|
+ remoteAddr string, alwaysTunnel bool, downstreamConn net.Conn) (net.Conn, error) {
|
|
|
|
|
|
- if !tunnel.IsActivated() {
|
|
|
- return nil, errors.TraceNew("tunnel is not activated")
|
|
|
+ channel, err := tunnel.dialChannel("tcp", remoteAddr)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ netConn, ok := channel.(net.Conn)
|
|
|
+ if !ok {
|
|
|
+ return nil, errors.Tracef("unexpected channel type: %T", channel)
|
|
|
+ }
|
|
|
+
|
|
|
+ conn := &TunneledConn{
|
|
|
+ Conn: netConn,
|
|
|
+ tunnel: tunnel,
|
|
|
+ downstreamConn: downstreamConn}
|
|
|
+
|
|
|
+ return tunnel.wrapWithTransferStats(conn), nil
|
|
|
+}
|
|
|
+
|
|
|
+func (tunnel *Tunnel) DialPacketTunnelChannel() (net.Conn, error) {
|
|
|
+
|
|
|
+ channel, err := tunnel.dialChannel(protocol.PACKET_TUNNEL_CHANNEL_TYPE, "")
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ sshChannel, ok := channel.(ssh.Channel)
|
|
|
+ if !ok {
|
|
|
+ return nil, errors.Tracef("unexpected channel type: %T", channel)
|
|
|
}
|
|
|
|
|
|
- type tunnelDialResult struct {
|
|
|
- sshPortForwardConn net.Conn
|
|
|
- err error
|
|
|
+ NoticeInfo("DialPacketTunnelChannel: established channel")
|
|
|
+
|
|
|
+ conn := newChannelConn(sshChannel)
|
|
|
+
|
|
|
+ // wrapWithTransferStats will track bytes transferred for the
|
|
|
+ // packet tunnel. It will count packet overhead (TCP/UDP/IP headers).
|
|
|
+ //
|
|
|
+ // Since the data in the channel is not HTTP or TLS, no domain bytes
|
|
|
+ // counting is expected.
|
|
|
+ //
|
|
|
+ // transferstats are also used to determine that there's been recent
|
|
|
+ // activity and skip periodic SSH keep alives; see Tunnel.operateTunnel.
|
|
|
+
|
|
|
+ return tunnel.wrapWithTransferStats(conn), nil
|
|
|
+}
|
|
|
+
|
|
|
+func (tunnel *Tunnel) dialChannel(channelType, remoteAddr string) (interface{}, error) {
|
|
|
+
|
|
|
+ if !tunnel.IsActivated() {
|
|
|
+ return nil, errors.TraceNew("tunnel is not activated")
|
|
|
}
|
|
|
|
|
|
// Note: there is no dial context since SSH port forward dials cannot
|
|
|
@@ -439,26 +482,44 @@ func (tunnel *Tunnel) Dial(
|
|
|
// A timeout is set to unblock this function, but the goroutine may
|
|
|
// not exit until the tunnel is closed.
|
|
|
|
|
|
- // Use a buffer of 1 as there are two senders and only one guaranteed receive.
|
|
|
+ type channelDialResult struct {
|
|
|
+ channel interface{}
|
|
|
+ err error
|
|
|
+ }
|
|
|
|
|
|
- resultChannel := make(chan *tunnelDialResult, 1)
|
|
|
+ // Use a buffer of 1 as there are two senders and only one guaranteed receive.
|
|
|
|
|
|
- timeout := tunnel.getCustomClientParameters().Duration(
|
|
|
- parameters.TunnelPortForwardDialTimeout)
|
|
|
+ results := make(chan *channelDialResult, 1)
|
|
|
|
|
|
afterFunc := time.AfterFunc(
|
|
|
- timeout,
|
|
|
+ tunnel.getCustomClientParameters().Duration(
|
|
|
+ parameters.TunnelPortForwardDialTimeout),
|
|
|
func() {
|
|
|
- resultChannel <- &tunnelDialResult{nil, errors.TraceNew("tunnel dial timeout")}
|
|
|
+ results <- &channelDialResult{
|
|
|
+ nil, errors.Tracef("channel dial timeout: %s", channelType)}
|
|
|
})
|
|
|
defer afterFunc.Stop()
|
|
|
|
|
|
go func() {
|
|
|
- sshPortForwardConn, err := tunnel.sshClient.Dial("tcp", remoteAddr)
|
|
|
- resultChannel <- &tunnelDialResult{sshPortForwardConn, err}
|
|
|
+ result := new(channelDialResult)
|
|
|
+ if channelType == "tcp" {
|
|
|
+ result.channel, result.err =
|
|
|
+ tunnel.sshClient.Dial("tcp", remoteAddr)
|
|
|
+ } else {
|
|
|
+ var sshRequests <-chan *ssh.Request
|
|
|
+ result.channel, sshRequests, result.err =
|
|
|
+ tunnel.sshClient.OpenChannel(channelType, nil)
|
|
|
+ if result.err == nil {
|
|
|
+ go ssh.DiscardRequests(sshRequests)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if result.err != nil {
|
|
|
+ result.err = errors.Trace(result.err)
|
|
|
+ }
|
|
|
+ results <- result
|
|
|
}()
|
|
|
|
|
|
- result := <-resultChannel
|
|
|
+ result := <-results
|
|
|
|
|
|
if result.err != nil {
|
|
|
// TODO: conditional on type of error or error message?
|
|
|
@@ -469,44 +530,7 @@ func (tunnel *Tunnel) Dial(
|
|
|
return nil, errors.Trace(result.err)
|
|
|
}
|
|
|
|
|
|
- conn = &TunneledConn{
|
|
|
- Conn: result.sshPortForwardConn,
|
|
|
- tunnel: tunnel,
|
|
|
- downstreamConn: downstreamConn}
|
|
|
-
|
|
|
- return tunnel.wrapWithTransferStats(conn), nil
|
|
|
-}
|
|
|
-
|
|
|
-func (tunnel *Tunnel) DialPacketTunnelChannel() (net.Conn, error) {
|
|
|
-
|
|
|
- if !tunnel.IsActivated() {
|
|
|
- return nil, errors.TraceNew("tunnel is not activated")
|
|
|
- }
|
|
|
- channel, requests, err := tunnel.sshClient.OpenChannel(
|
|
|
- protocol.PACKET_TUNNEL_CHANNEL_TYPE, nil)
|
|
|
- if err != nil {
|
|
|
- // TODO: conditional on type of error or error message?
|
|
|
- select {
|
|
|
- case tunnel.signalPortForwardFailure <- struct{}{}:
|
|
|
- default:
|
|
|
- }
|
|
|
-
|
|
|
- return nil, errors.Trace(err)
|
|
|
- }
|
|
|
- go ssh.DiscardRequests(requests)
|
|
|
-
|
|
|
- conn := newChannelConn(channel)
|
|
|
-
|
|
|
- // wrapWithTransferStats will track bytes transferred for the
|
|
|
- // packet tunnel. It will count packet overhead (TCP/UDP/IP headers).
|
|
|
- //
|
|
|
- // Since the data in the channel is not HTTP or TLS, no domain bytes
|
|
|
- // counting is expected.
|
|
|
- //
|
|
|
- // transferstats are also used to determine that there's been recent
|
|
|
- // activity and skip periodic SSH keep alives; see Tunnel.operateTunnel.
|
|
|
-
|
|
|
- return tunnel.wrapWithTransferStats(conn), nil
|
|
|
+ return result.channel, nil
|
|
|
}
|
|
|
|
|
|
func (tunnel *Tunnel) wrapWithTransferStats(conn net.Conn) net.Conn {
|
|
|
@@ -1152,7 +1176,7 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
|
|
|
bytesUp := atomic.LoadInt64(&totalSent)
|
|
|
bytesDown := atomic.LoadInt64(&totalReceived)
|
|
|
err := tunnel.sendSshKeepAlive(
|
|
|
- isFirstPeriodicKeepAlive, timeout, bytesUp, bytesDown)
|
|
|
+ isFirstPeriodicKeepAlive, false, timeout, bytesUp, bytesDown)
|
|
|
if err != nil {
|
|
|
select {
|
|
|
case sshKeepAliveError <- err:
|
|
|
@@ -1175,7 +1199,7 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
|
|
|
bytesUp := atomic.LoadInt64(&totalSent)
|
|
|
bytesDown := atomic.LoadInt64(&totalReceived)
|
|
|
err := tunnel.sendSshKeepAlive(
|
|
|
- false, timeout, bytesUp, bytesDown)
|
|
|
+ false, true, timeout, bytesUp, bytesDown)
|
|
|
if err != nil {
|
|
|
select {
|
|
|
case sshKeepAliveError <- err:
|
|
|
@@ -1361,6 +1385,7 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
|
|
|
// closed, which will terminate the associated tunnel.
|
|
|
func (tunnel *Tunnel) sendSshKeepAlive(
|
|
|
isFirstPeriodicKeepAlive bool,
|
|
|
+ isProbeKeepAlive bool,
|
|
|
timeout time.Duration,
|
|
|
bytesUp int64,
|
|
|
bytesDown int64) error {
|
|
|
@@ -1412,6 +1437,12 @@ func (tunnel *Tunnel) sendSshKeepAlive(
|
|
|
|
|
|
errChannel <- err
|
|
|
|
|
|
+ success := (err == nil && requestOk)
|
|
|
+
|
|
|
+ if success && isProbeKeepAlive {
|
|
|
+ NoticeInfo("Probe SSH keep-alive RTT: %s", elapsedTime)
|
|
|
+ }
|
|
|
+
|
|
|
// Record the keep alive round trip as a speed test sample. The first
|
|
|
// periodic keep alive is always recorded, as many tunnels are short-lived
|
|
|
// and we want to ensure that some data is gathered. Subsequent keep alives
|
|
|
@@ -1419,7 +1450,7 @@ func (tunnel *Tunnel) sendSshKeepAlive(
|
|
|
// only the last SpeedTestMaxSampleCount samples are retained, enables
|
|
|
// tuning the sampling frequency.
|
|
|
|
|
|
- if err == nil && requestOk && speedTestSample {
|
|
|
+ if success && speedTestSample {
|
|
|
|
|
|
err = tactics.AddSpeedTestSample(
|
|
|
tunnel.config.GetClientParameters(),
|