Преглед изворни кода

Merge remote-tracking branch 'upstream/master'

Adam Pritchard пре 8 година
родитељ
комит
c1d15c6411

+ 8 - 7
psiphon/LookupIP.go

@@ -38,6 +38,13 @@ import (
 // socket, binds it to the device, and makes an explicit DNS request
 // to the specified DNS resolver.
 func LookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
+
+	// When the input host is an IP address, echo it back
+	ipAddr := net.ParseIP(host)
+	if ipAddr != nil {
+		return []net.IP{ipAddr}, nil
+	}
+
 	if config.DeviceBinder != nil {
 		addrs, err = bindLookupIP(host, config.DnsServerGetter.GetPrimaryDnsServer(), config)
 		if err == nil {
@@ -63,14 +70,8 @@ func LookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
 // https://code.google.com/p/go/issues/detail?id=6966
 func bindLookupIP(host, dnsServer string, config *DialConfig) (addrs []net.IP, err error) {
 
-	// When the input host is an IP address, echo it back
-	ipAddr := net.ParseIP(host)
-	if ipAddr != nil {
-		return []net.IP{ipAddr}, nil
-	}
-
 	// config.DnsServerGetter.GetDnsServers() must return IP addresses
-	ipAddr = net.ParseIP(dnsServer)
+	ipAddr := net.ParseIP(dnsServer)
 	if ipAddr == nil {
 		return nil, common.ContextError(errors.New("invalid IP address"))
 	}

+ 28 - 0
psiphon/common/protocol/protocol.go

@@ -20,6 +20,7 @@
 package protocol
 
 import (
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
 )
 
@@ -96,6 +97,26 @@ func TunnelProtocolUsesObfuscatedSessionTickets(protocol string) bool {
 	return protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET
 }
 
+func UseClientTunnelProtocol(
+	clientProtocol string,
+	serverProtocols []string) bool {
+
+	// When the server is running _both_ fronted HTTP and
+	// fronted HTTPS, use the client's reported tunnel
+	// protocol since some CDNs forward both to the same
+	// server port; in this case the server port is not
+	// sufficient to distinguish these protocols.
+	if (clientProtocol == TUNNEL_PROTOCOL_FRONTED_MEEK ||
+		clientProtocol == TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP) &&
+		common.Contains(serverProtocols, TUNNEL_PROTOCOL_FRONTED_MEEK) &&
+		common.Contains(serverProtocols, TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP) {
+
+		return true
+	}
+
+	return false
+}
+
 type HandshakeResponse struct {
 	SSHSessionID         string              `json:"ssh_session_id"`
 	Homepages            []string            `json:"homepages"`
@@ -121,3 +142,10 @@ type SSHPasswordPayload struct {
 	SshPassword        string   `json:"SshPassword"`
 	ClientCapabilities []string `json:"ClientCapabilities"`
 }
+
+type MeekCookieData struct {
+	ServerAddress        string `json:"p"`
+	SessionID            string `json:"s"`
+	MeekProtocolVersion  int    `json:"v"`
+	ClientTunnelProtocol string `json:"t"`
+}

+ 3 - 3
psiphon/common/utils_test.go

@@ -34,7 +34,7 @@ func TestMakeRandomPeriod(t *testing.T) {
 	res1, err := MakeRandomPeriod(min, max)
 
 	if err != nil {
-		t.Error("MakeRandomPeriod failed: %s", err)
+		t.Errorf("MakeRandomPeriod failed: %s", err)
 	}
 
 	if res1 < min {
@@ -48,7 +48,7 @@ func TestMakeRandomPeriod(t *testing.T) {
 	res2, err := MakeRandomPeriod(min, max)
 
 	if err != nil {
-		t.Error("MakeRandomPeriod failed: %s", err)
+		t.Errorf("MakeRandomPeriod failed: %s", err)
 	}
 
 	if res1 == res2 {
@@ -104,7 +104,7 @@ func TestCompress(t *testing.T) {
 
 	decompressedData, err := Decompress(compressedData)
 	if err != nil {
-		t.Error("Uncompress failed: %s", err)
+		t.Errorf("Uncompress failed: %s", err)
 	}
 
 	if bytes.Compare(originalData, decompressedData) != 0 {

+ 2 - 1
psiphon/config.go

@@ -340,7 +340,8 @@ type Config struct {
 	UpgradeDownloadClientVersionHeader string
 
 	// UpgradeDownloadFilename is the local target filename for an upgrade download.
-	// This parameter is required when UpgradeDownloadURLs is specified.
+	// This parameter is required when UpgradeDownloadURLs (or UpgradeDownloadUrl)
+	// is specified.
 	// Data is stored in co-located files (UpgradeDownloadFilename.part*) to allow
 	// for resumable downloading.
 	UpgradeDownloadFilename string

+ 202 - 125
psiphon/meekConn.go

@@ -21,6 +21,7 @@ package psiphon
 
 import (
 	"bytes"
+	"context"
 	"crypto/rand"
 	"encoding/base64"
 	"encoding/json"
@@ -37,6 +38,7 @@ import (
 	"github.com/Psiphon-Inc/crypto/nacl/box"
 	"github.com/Psiphon-Inc/goarista/monotime"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/upstreamproxy"
 )
 
@@ -48,7 +50,7 @@ import (
 // https://bitbucket.org/psiphon/psiphon-circumvention-system/src/default/go/meek-client/meek-client.go
 
 const (
-	MEEK_PROTOCOL_VERSION          = 2
+	MEEK_PROTOCOL_VERSION          = 3
 	MEEK_COOKIE_MAX_PADDING        = 32
 	MAX_SEND_PAYLOAD_LENGTH        = 65536
 	FULL_RECEIVE_BUFFER_LENGTH     = 4194304
@@ -59,7 +61,7 @@ const (
 	MAX_POLL_INTERVAL_JITTER       = 0.1
 	POLL_INTERVAL_MULTIPLIER       = 1.5
 	POLL_INTERVAL_JITTER           = 0.1
-	MEEK_ROUND_TRIP_RETRY_DEADLINE = 1 * time.Second
+	MEEK_ROUND_TRIP_RETRY_DEADLINE = 5 * time.Second
 	MEEK_ROUND_TRIP_RETRY_DELAY    = 50 * time.Millisecond
 	MEEK_ROUND_TRIP_TIMEOUT        = 20 * time.Second
 )
@@ -98,6 +100,12 @@ type MeekConfig struct {
 	// in effect. This value is used for stats reporting.
 	TransformedHostName bool
 
+	// ClientTunnelProtocol is the protocol the client is using. It's
+	// included in the meek cookie for optional use by the server, in
+	// cases where the server cannot unambiguously determine the
+	// tunnel protocol.
+	ClientTunnelProtocol string
+
 	// The following values are used to create the obfuscated meek cookie.
 
 	PsiphonServerAddress          string
@@ -125,7 +133,8 @@ type MeekConn struct {
 	transport            transporter
 	mutex                sync.Mutex
 	isClosed             bool
-	broadcastClosed      chan struct{}
+	runContext           context.Context
+	stopRunning          context.CancelFunc
 	relayWaitGroup       *sync.WaitGroup
 	emptyReceiveBuffer   chan *bytes.Buffer
 	partialReceiveBuffer chan *bytes.Buffer
@@ -298,6 +307,8 @@ func DialMeek(
 		return nil, common.ContextError(err)
 	}
 
+	runContext, stopRunning := context.WithCancel(context.Background())
+
 	// The main loop of a MeekConn is run in the relay() goroutine.
 	// A MeekConn implements net.Conn concurrency semantics:
 	// "Multiple goroutines may invoke methods on a Conn simultaneously."
@@ -321,7 +332,8 @@ func DialMeek(
 		pendingConns:         pendingConns,
 		transport:            transport,
 		isClosed:             false,
-		broadcastClosed:      make(chan struct{}),
+		runContext:           runContext,
+		stopRunning:          stopRunning,
 		relayWaitGroup:       new(sync.WaitGroup),
 		emptyReceiveBuffer:   make(chan *bytes.Buffer, 1),
 		partialReceiveBuffer: make(chan *bytes.Buffer, 1),
@@ -356,7 +368,7 @@ func (meek *MeekConn) Close() (err error) {
 	meek.mutex.Unlock()
 
 	if !isClosed {
-		close(meek.broadcastClosed)
+		meek.stopRunning()
 		meek.pendingConns.CloseAll()
 		meek.relayWaitGroup.Wait()
 		meek.transport.CloseIdleConnections()
@@ -386,7 +398,7 @@ func (meek *MeekConn) Read(buffer []byte) (n int, err error) {
 	select {
 	case receiveBuffer = <-meek.partialReceiveBuffer:
 	case receiveBuffer = <-meek.fullReceiveBuffer:
-	case <-meek.broadcastClosed:
+	case <-meek.runContext.Done():
 		return 0, common.ContextError(errors.New("meek connection has closed"))
 	}
 	n, err = receiveBuffer.Read(buffer)
@@ -408,7 +420,7 @@ func (meek *MeekConn) Write(buffer []byte) (n int, err error) {
 		select {
 		case sendBuffer = <-meek.emptySendBuffer:
 		case sendBuffer = <-meek.partialSendBuffer:
-		case <-meek.broadcastClosed:
+		case <-meek.runContext.Done():
 			return 0, common.ContextError(errors.New("meek connection has closed"))
 		}
 		writeLen := MAX_SEND_PAYLOAD_LENGTH - sendBuffer.Len()
@@ -490,6 +502,7 @@ func (meek *MeekConn) relay() {
 
 	for {
 		timeout.Reset(interval)
+
 		// Block until there is payload to send or it is time to poll
 		var sendBuffer *bytes.Buffer
 		select {
@@ -497,10 +510,17 @@ func (meek *MeekConn) relay() {
 		case sendBuffer = <-meek.fullSendBuffer:
 		case <-timeout.C:
 			// In the polling case, send an empty payload
-		case <-meek.broadcastClosed:
-			// TODO: timeout case may be selected when broadcastClosed is set?
+		case <-meek.runContext.Done():
+			// Drop through to second Done() check
+		}
+
+		// Check Done() again, to ensure it takes precedence
+		select {
+		case <-meek.runContext.Done():
 			return
+		default:
 		}
+
 		sendPayloadSize := 0
 		if sendBuffer != nil {
 			var err error
@@ -512,18 +532,16 @@ func (meek *MeekConn) relay() {
 				return
 			}
 		}
-		receivedPayload, err := meek.roundTrip(sendPayload[:sendPayloadSize])
-		if err != nil {
-			NoticeAlert("%s", common.ContextError(err))
-			go meek.Close()
-			return
-		}
-		if receivedPayload == nil {
-			// In this case, meek.roundTrip encountered broadcastClosed. Exit without error.
-			return
-		}
-		receivedPayloadSize, err := meek.readPayload(receivedPayload)
+
+		receivedPayloadSize, err := meek.roundTrip(sendPayload[:sendPayloadSize])
+
 		if err != nil {
+			select {
+			case <-meek.runContext.Done():
+				// In this case, meek.roundTrip encountered Done(). Exit without logging error.
+				return
+			default:
+			}
 			NoticeAlert("%s", common.ContextError(err))
 			go meek.Close()
 			return
@@ -566,143 +584,202 @@ func (meek *MeekConn) relay() {
 	}
 }
 
-// readPayload reads the HTTP response  in chunks, making the read buffer available
-// to MeekConn.Read() calls after each chunk; the intention is to allow bytes to
-// flow back to the reader as soon as possible instead of buffering the entire payload.
-func (meek *MeekConn) readPayload(receivedPayload io.ReadCloser) (totalSize int64, err error) {
-	defer receivedPayload.Close()
-	totalSize = 0
-	for {
-		reader := io.LimitReader(receivedPayload, READ_PAYLOAD_CHUNK_LENGTH)
-		// Block until there is capacity in the receive buffer
-		var receiveBuffer *bytes.Buffer
-		select {
-		case receiveBuffer = <-meek.emptyReceiveBuffer:
-		case receiveBuffer = <-meek.partialReceiveBuffer:
-		case <-meek.broadcastClosed:
-			return 0, nil
-		}
-		// Note: receiveBuffer size may exceed FULL_RECEIVE_BUFFER_LENGTH by up to the size
-		// of one received payload. The FULL_RECEIVE_BUFFER_LENGTH value is just a threshold.
-		n, err := receiveBuffer.ReadFrom(reader)
-		meek.replaceReceiveBuffer(receiveBuffer)
-		if err != nil {
-			return 0, common.ContextError(err)
-		}
-		totalSize += n
-		if n == 0 {
-			break
-		}
-	}
-	return totalSize, nil
-}
-
 // roundTrip configures and makes the actual HTTP POST request
-func (meek *MeekConn) roundTrip(sendPayload []byte) (io.ReadCloser, error) {
+func (meek *MeekConn) roundTrip(sendPayload []byte) (int64, error) {
 
-	// The retry mitigates intermittent failures between the client and front/server.
+	// Retries are made when the round trip fails. This adds resiliency
+	// to connection interruption and intermittent failures.
+	//
+	// At least one retry is always attempted, and retries continue
+	// while still within a brief deadline -- 5 seconds, currently the
+	// deadline for an actively probed SSH connection to timeout. There
+	// is a brief delay between retries, allowing for intermittent
+	// failure states to resolve.
 	//
-	// Note: Retry will only be effective if entire request failed (underlying transport protocol
-	// such as SSH will fail if extra bytes are replayed in either direction due to partial relay
-	// success followed by retry).
-	// At least one retry is always attempted. We retry when still within a brief deadline and wait
-	// for a short time before re-dialing.
+	// Failure may occur at various stages of the HTTP request:
 	//
-	// TODO: in principle, we could retry for min(TUNNEL_WRITE_TIMEOUT, meek-server.MAX_SESSION_STALENESS),
-	// i.e., as long as the underlying tunnel has not timed out and as long as the server has not
-	// expired the current meek session. Presently not doing this to avoid excessive connection attempts
-	// through the first hop. In addition, this will require additional support for timely shutdown.
+	// 1. Before the request begins. In this case, the entire request
+	//    may be rerun.
+	//
+	// 2. While sending the request payload. In this case, the client
+	//    must resend its request payload. The server will not have
+	//    relayed its partially received request payload.
+	//
+	// 3. After sending the request payload but before receiving
+	//    a response. The client cannot distinguish between case 2 and
+	//    this case, case 3. The client resends its payload and the
+	//    server detects this and skips relaying the request payload.
+	//
+	// 4. While reading the response payload. The client will omit its
+	//    request payload when retrying, as the server has already
+	//    acknowledged it. The client will also indicate to the server
+	//    the amount of response payload already received, and the
+	//    server will skip resending the indicated amount of response
+	//    payload.
+	//
+	// Retries are indicated to the server by adding a Range header,
+	// which includes the response payload resend position.
+
 	retries := uint(0)
 	retryDeadline := monotime.Now().Add(MEEK_ROUND_TRIP_RETRY_DEADLINE)
+	serverAcknowlegedRequestPayload := false
+	receivedPayloadSize := int64(0)
 
-	var err error
-	var response *http.Response
-	for {
+	for try := 0; ; try++ {
+
+		// Omit the request payload when retrying after receiving a
+		// partial server response.
+
+		var sendPayloadReader io.Reader
+		if !serverAcknowlegedRequestPayload {
+			sendPayloadReader = bytes.NewReader(sendPayload)
+		}
 
 		var request *http.Request
-		request, err = http.NewRequest("POST", meek.url.String(), bytes.NewReader(sendPayload))
+		request, err := http.NewRequest("POST", meek.url.String(), sendPayloadReader)
 		if err != nil {
 			// Don't retry when can't initialize a Request
-			break
+			return 0, common.ContextError(err)
 		}
 
-		request.Header.Set("Content-Type", "application/octet-stream")
+		// Note: meek.stopRunning() will abort a round trip in flight
+		request = request.WithContext(meek.runContext)
 
-		// Set additional headers to the HTTP request using the same method we use for adding
-		// custom headers to HTTP proxy requests
-		for name, value := range meek.additionalHeaders {
-			// hack around special case of "Host" header
-			// https://golang.org/src/net/http/request.go#L474
-			// using URL.Opaque, see URL.RequestURI() https://golang.org/src/net/url/url.go#L915
-			if name == "Host" {
-				if len(value) > 0 {
-					if request.URL.Opaque == "" {
-						request.URL.Opaque = request.URL.Scheme + "://" + request.Host + request.URL.RequestURI()
-					}
-					request.Host = value[0]
-				}
-			} else {
-				request.Header[name] = value
-			}
-		}
+		meek.addAdditionalHeaders(request)
 
+		request.Header.Set("Content-Type", "application/octet-stream")
 		request.AddCookie(meek.cookie)
 
-		// The http.Transport.RoundTrip is run in a goroutine to enable cancelling a request in-flight.
-		type roundTripResponse struct {
-			response *http.Response
-			err      error
+		expectedStatusCode := http.StatusOK
+
+		// When retrying, add a Range header to indicate how much
+		// of the response was already received.
+
+		if try > 0 {
+			expectedStatusCode = http.StatusPartialContent
+			request.Header.Set("Range", fmt.Sprintf("bytes=%d-", receivedPayloadSize))
 		}
-		roundTripResponseChannel := make(chan *roundTripResponse, 1)
-		roundTripWaitGroup := new(sync.WaitGroup)
-		roundTripWaitGroup.Add(1)
-		go func() {
-			defer roundTripWaitGroup.Done()
-			r, err := meek.transport.RoundTrip(request)
-			roundTripResponseChannel <- &roundTripResponse{r, err}
-		}()
-		select {
-		case roundTripResponse := <-roundTripResponseChannel:
-			response = roundTripResponse.response
-			err = roundTripResponse.err
-		case <-meek.broadcastClosed:
-			meek.transport.CancelRequest(request)
-			return nil, nil
+
+		response, err := meek.transport.RoundTrip(request)
+		if err != nil {
+			select {
+			case <-meek.runContext.Done():
+				// Exit without retrying and without logging error.
+				return 0, common.ContextError(err)
+			default:
+			}
+			NoticeAlert("meek round trip failed: %s", err)
+			// ...continue to retry
 		}
-		roundTripWaitGroup.Wait()
 
 		if err == nil {
-			break
+
+			if response.StatusCode != expectedStatusCode {
+				// Don't retry when the status code is incorrect
+				response.Body.Close()
+				return 0, common.ContextError(
+					fmt.Errorf(
+						"unexpected status code: %d instead of %d",
+						response.StatusCode, expectedStatusCode))
+			}
+
+			// Update meek session cookie
+			for _, c := range response.Cookies() {
+				if meek.cookie.Name == c.Name {
+					meek.cookie.Value = c.Value
+					break
+				}
+			}
+
+			// Received the response status code, so the server
+			// must have received the request payload.
+			serverAcknowlegedRequestPayload = true
+
+			readPayloadSize, err := meek.readPayload(response.Body)
+			response.Body.Close()
+
+			// receivedPayloadSize is the number of response
+			// payload bytes received and relayed. A retry can
+			// resume after this position.
+			receivedPayloadSize += readPayloadSize
+
+			if err != nil {
+				NoticeAlert("meek read payload failed: %s", err)
+				// ...continue to retry
+			} else {
+				// Round trip completed successfully
+				break
+			}
 		}
 
+		// Either the request failed entirely, or there was a failure
+		// streaming the response payload. Retry, if time remains.
+
 		if retries >= 1 && monotime.Now().After(retryDeadline) {
-			break
+			return 0, common.ContextError(err)
 		}
 		retries += 1
 
 		time.Sleep(MEEK_ROUND_TRIP_RETRY_DELAY)
 	}
-	if err != nil {
-		return nil, common.ContextError(err)
-	}
-	if response.StatusCode != http.StatusOK {
-		return nil, common.ContextError(fmt.Errorf("http request failed %d", response.StatusCode))
-	}
-	// observe response cookies for meek session key token.
-	// Once found it must be used for all consecutive requests made to the server
-	for _, c := range response.Cookies() {
-		if meek.cookie.Name == c.Name {
-			meek.cookie.Value = c.Value
-			break
+
+	return receivedPayloadSize, nil
+}
+
+// Add additional headers to the HTTP request using the same method we use for adding
+// custom headers to HTTP proxy requests.
+func (meek *MeekConn) addAdditionalHeaders(request *http.Request) {
+	for name, value := range meek.additionalHeaders {
+		// hack around special case of "Host" header
+		// https://golang.org/src/net/http/request.go#L474
+		// using URL.Opaque, see URL.RequestURI() https://golang.org/src/net/url/url.go#L915
+		if name == "Host" {
+			if len(value) > 0 {
+				if request.URL.Opaque == "" {
+					request.URL.Opaque = request.URL.Scheme + "://" + request.Host + request.URL.RequestURI()
+				}
+				request.Host = value[0]
+			}
+		} else {
+			request.Header[name] = value
 		}
 	}
-	return response.Body, nil
 }
 
-type meekCookieData struct {
-	ServerAddress       string `json:"p"`
-	SessionID           string `json:"s"`
-	MeekProtocolVersion int    `json:"v"`
+// readPayload reads the HTTP response in chunks, making the read buffer available
+// to MeekConn.Read() calls after each chunk; the intention is to allow bytes to
+// flow back to the reader as soon as possible instead of buffering the entire payload.
+//
+// When readPayload returns an error, the totalSize output remains valid -- it's the
+// number of payload bytes successfully read and relayed.
+func (meek *MeekConn) readPayload(
+	receivedPayload io.ReadCloser) (totalSize int64, err error) {
+
+	defer receivedPayload.Close()
+	totalSize = 0
+	for {
+		reader := io.LimitReader(receivedPayload, READ_PAYLOAD_CHUNK_LENGTH)
+		// Block until there is capacity in the receive buffer
+		var receiveBuffer *bytes.Buffer
+		select {
+		case receiveBuffer = <-meek.emptyReceiveBuffer:
+		case receiveBuffer = <-meek.partialReceiveBuffer:
+		case <-meek.runContext.Done():
+			return 0, nil
+		}
+		// Note: receiveBuffer size may exceed FULL_RECEIVE_BUFFER_LENGTH by up to the size
+		// of one received payload. The FULL_RECEIVE_BUFFER_LENGTH value is just a guideline.
+		n, err := receiveBuffer.ReadFrom(reader)
+		meek.replaceReceiveBuffer(receiveBuffer)
+		totalSize += n
+		if err != nil {
+			return totalSize, common.ContextError(err)
+		}
+		if n == 0 {
+			break
+		}
+	}
+	return totalSize, nil
 }
 
 // makeCookie creates the cookie to be sent with initial meek HTTP request.
@@ -721,7 +798,7 @@ func makeMeekCookie(meekConfig *MeekConfig) (cookie *http.Cookie, err error) {
 
 	// Make the JSON data
 	serverAddress := meekConfig.PsiphonServerAddress
-	cookieData := &meekCookieData{
+	cookieData := &protocol.MeekCookieData{
 		ServerAddress:       serverAddress,
 		SessionID:           meekConfig.SessionID,
 		MeekProtocolVersion: MEEK_PROTOCOL_VERSION,

+ 17 - 16
psiphon/pluginProtocol.go

@@ -29,26 +29,28 @@ import (
 
 var registeredPluginProtocolDialer atomic.Value
 
-// PluginProtocolNetDialer is a base network dialer that's used
-// by PluginProtocolDialer to make its IP network connections. This
-// is used, for example, to create TCPConns as the base TCP
-// connections used by the plugin protocol.
-type PluginProtocolNetDialer func(network, addr string) (net.Conn, error)
-
 // PluginProtocolDialer creates a connection to addr over a
-// plugin protocol. It uses netDialer to create its base network
+// plugin protocol. It uses dialConfig to create its base network
 // connection(s) and sends its log messages to loggerOutput.
+//
+// To ensure timely interruption and shutdown, each
+// PluginProtocolDialer implementation must:
+//
+// - Place its outer net.Conn in pendingConns and leave it
+//   there unless an error occurs
+// - Replace the dialConfig.pendingConns with its own
+//   PendingConns and use that to ensure base network
+//   connections are interrupted when Close() is invoked on
+//   the returned net.Conn.
+//
 // PluginProtocolDialer returns true if it attempts to create
 // a connection, or false if it decides not to attempt a connection.
-// PluginProtocolDialer must add its connection to pendingConns
-// before the initial dial to allow for interruption.
 type PluginProtocolDialer func(
 	config *Config,
 	loggerOutput io.Writer,
 	pendingConns *common.Conns,
-	netDialer PluginProtocolNetDialer,
-	addr string) (
-	bool, net.Conn, error)
+	addr string,
+	dialConfig *DialConfig) (bool, net.Conn, error)
 
 // RegisterPluginProtocol sets the current plugin protocol
 // dialer.
@@ -62,14 +64,13 @@ func DialPluginProtocol(
 	config *Config,
 	loggerOutput io.Writer,
 	pendingConns *common.Conns,
-	netDialer PluginProtocolNetDialer,
-	addr string) (
-	bool, net.Conn, error) {
+	addr string,
+	dialConfig *DialConfig) (bool, net.Conn, error) {
 
 	dialer := registeredPluginProtocolDialer.Load()
 	if dialer != nil {
 		return dialer.(PluginProtocolDialer)(
-			config, loggerOutput, pendingConns, netDialer, addr)
+			config, loggerOutput, pendingConns, addr, dialConfig)
 	}
 	return false, nil, nil
 }

+ 30 - 0
psiphon/server/config.go

@@ -179,6 +179,36 @@ type Config struct {
 	// used as the client IP.
 	MeekProxyForwardedForHeaders []string
 
+	// MeekCachedResponseBufferSize is the size of a private,
+	// fixed-size buffer allocated for every meek client. The buffer
+	// is used to cache response payload, allowing the client to retry
+	// fetching when a network connection is interrupted. This retry
+	// makes the OSSH tunnel within meek resilient to interruptions
+	// at the HTTP TCP layer.
+	// Larger buffers increase resiliency to interruption, but consume
+	// more memory as buffers as never freed. The maximum size of a
+	// response payload is a function of client activity, network
+	// throughput and throttling.
+	// A default of 64K is used when MeekCachedResponseBufferSize is 0.
+	MeekCachedResponseBufferSize int
+
+	// MeekCachedResponsePoolBufferSize is the size of a fixed-size,
+	// shared buffer used to temporarily extend a private buffer when
+	// MeekCachedResponseBufferSize is insufficient. Shared buffers
+	// allow some clients to successfully retry longer response payloads
+	// without allocating large buffers for all clients.
+	// A default of 64K is used when MeekCachedResponsePoolBufferSize
+	// is 0.
+	MeekCachedResponsePoolBufferSize int
+
+	// MeekCachedResponsePoolBufferCount is the number of shared
+	// buffers. Shared buffers are allocated on first use and remain
+	// allocated, so shared buffer count * size is roughly the memory
+	// overhead of this facility.
+	// A default of 2048 is used when MeekCachedResponsePoolBufferCount
+	// is 0.
+	MeekCachedResponsePoolBufferCount int
+
 	// UDPInterceptUdpgwServerAddress specifies the network address of
 	// a udpgw server which clients may be port forwarding to. When
 	// specified, these TCP port forwards are intercepted and handled

+ 8 - 0
psiphon/server/log.go

@@ -33,6 +33,14 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
+// MetricsSource is an object that provides metrics to be logged
+type MetricsSource interface {
+
+	// GetMetrics returns a LogFields populated with
+	// metrics from the MetricsSource
+	GetMetrics() LogFields
+}
+
 // ContextLogger adds context logging functionality to the
 // underlying logging packages.
 type ContextLogger struct {

+ 369 - 78
psiphon/server/meek.go

@@ -27,9 +27,11 @@ import (
 	"encoding/hex"
 	"encoding/json"
 	"errors"
+	"hash/crc64"
 	"io"
 	"net"
 	"net/http"
+	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -38,6 +40,7 @@ import (
 	"github.com/Psiphon-Inc/crypto/nacl/box"
 	"github.com/Psiphon-Inc/goarista/monotime"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tls"
 )
 
@@ -54,20 +57,27 @@ const (
 	// report no version number and expect at most 64K response bodies.
 	MEEK_PROTOCOL_VERSION_1 = 1
 
-	// Protocol version 2 clients initiate a session by sending a encrypted and obfuscated meek
+	// Protocol version 2 clients initiate a session by sending an encrypted and obfuscated meek
 	// cookie with their initial HTTP request. Connection information is contained within the
 	// encrypted cookie payload. The server inspects the cookie and establishes a new session and
 	// returns a new random session ID back to client via Set-Cookie header. The client uses this
 	// session ID on all subsequent requests for the remainder of the session.
 	MEEK_PROTOCOL_VERSION_2 = 2
 
-	MEEK_MAX_PAYLOAD_LENGTH           = 0x10000
-	MEEK_TURN_AROUND_TIMEOUT          = 20 * time.Millisecond
-	MEEK_EXTENDED_TURN_AROUND_TIMEOUT = 100 * time.Millisecond
-	MEEK_MAX_SESSION_STALENESS        = 45 * time.Second
-	MEEK_HTTP_CLIENT_IO_TIMEOUT       = 45 * time.Second
-	MEEK_MIN_SESSION_ID_LENGTH        = 8
-	MEEK_MAX_SESSION_ID_LENGTH        = 20
+	// Protocol version 3 clients include resiliency enhancements and will add a Range header
+	// when retrying a request for a partially downloaded response payload.
+	MEEK_PROTOCOL_VERSION_3 = 3
+
+	MEEK_MAX_REQUEST_PAYLOAD_LENGTH     = 65536
+	MEEK_TURN_AROUND_TIMEOUT            = 20 * time.Millisecond
+	MEEK_EXTENDED_TURN_AROUND_TIMEOUT   = 100 * time.Millisecond
+	MEEK_MAX_SESSION_STALENESS          = 45 * time.Second
+	MEEK_HTTP_CLIENT_IO_TIMEOUT         = 45 * time.Second
+	MEEK_MIN_SESSION_ID_LENGTH          = 8
+	MEEK_MAX_SESSION_ID_LENGTH          = 20
+	MEEK_DEFAULT_RESPONSE_BUFFER_LENGTH = 65536
+	MEEK_DEFAULT_POOL_BUFFER_LENGTH     = 65536
+	MEEK_DEFAULT_POOL_BUFFER_COUNT      = 2048
 )
 
 // MeekServer implements the meek protocol, which tunnels TCP traffic (in the case of Psiphon,
@@ -85,11 +95,13 @@ type MeekServer struct {
 	support       *SupportServices
 	listener      net.Listener
 	tlsConfig     *tls.Config
-	clientHandler func(clientConn net.Conn)
+	clientHandler func(clientTunnelProtocol string, clientConn net.Conn)
 	openConns     *common.Conns
 	stopBroadcast <-chan struct{}
 	sessionsLock  sync.RWMutex
 	sessions      map[string]*meekSession
+	checksumTable *crc64.Table
+	bufferPool    *CachedResponseBufferPool
 }
 
 // NewMeekServer initializes a new meek server.
@@ -97,9 +109,23 @@ func NewMeekServer(
 	support *SupportServices,
 	listener net.Listener,
 	useTLS, useObfuscatedSessionTickets bool,
-	clientHandler func(clientConn net.Conn),
+	clientHandler func(clientTunnelProtocol string, clientConn net.Conn),
 	stopBroadcast <-chan struct{}) (*MeekServer, error) {
 
+	checksumTable := crc64.MakeTable(crc64.ECMA)
+
+	bufferLength := MEEK_DEFAULT_POOL_BUFFER_LENGTH
+	if support.Config.MeekCachedResponsePoolBufferSize != 0 {
+		bufferLength = support.Config.MeekCachedResponsePoolBufferSize
+	}
+
+	bufferCount := MEEK_DEFAULT_POOL_BUFFER_COUNT
+	if support.Config.MeekCachedResponsePoolBufferCount != 0 {
+		bufferCount = support.Config.MeekCachedResponsePoolBufferCount
+	}
+
+	bufferPool := NewCachedResponseBufferPool(bufferLength, bufferCount)
+
 	meekServer := &MeekServer{
 		support:       support,
 		listener:      listener,
@@ -107,6 +133,8 @@ func NewMeekServer(
 		openConns:     new(common.Conns),
 		stopBroadcast: stopBroadcast,
 		sessions:      make(map[string]*meekSession),
+		checksumTable: checksumTable,
+		bufferPool:    bufferPool,
 	}
 
 	if useTLS {
@@ -172,7 +200,7 @@ func (server *MeekServer) Run() error {
 	// Note: Serve() will be interrupted by listener.Close() call
 	var err error
 	if server.tlsConfig != nil {
-		httpsServer := HTTPSServer{Server: *httpServer}
+		httpsServer := HTTPSServer{Server: httpServer}
 		err = httpsServer.ServeTLS(server.listener, server.tlsConfig)
 	} else {
 		err = httpServer.Serve(server.listener)
@@ -236,19 +264,58 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 		return
 	}
 
+	// Ensure that there's only one concurrent request handler per client
+	// session. Depending on the nature of a network disruption, it can
+	// happen that a client detects a failure and retries while the server
+	// is still streaming response in the handler for the _previous_ client
+	// request.
+	//
+	// Even if the session.cachedResponse were safe for concurrent
+	// use (it is not), concurrent handling could lead to loss of session
+	// since upstream data read by the first request may not reach the
+	// cached response before the second request reads the cached data.
+	//
+	// The existing handler will stream response data, holding the lock,
+	// for no more than MEEK_EXTENDED_TURN_AROUND_TIMEOUT.
+	//
+	// TODO: interrupt an existing handler? The existing handler will be
+	// sending data to the cached response, but if that buffer fills, the
+	// session will be lost.
+
+	requestNumber := atomic.AddInt64(&session.requestCount, 1)
+
+	// Wait for the existing request to complete.
+	session.lock.Lock()
+	defer session.lock.Unlock()
+
+	// If a newer request has arrived while waiting, discard this one.
+	// Do not delay processing the newest request.
+	if atomic.LoadInt64(&session.requestCount) > requestNumber {
+		server.terminateConnection(responseWriter, request)
+		return
+	}
+
 	// pumpReads causes a TunnelServer/SSH goroutine blocking on a Read to
 	// read the request body as upstream traffic.
 	// TODO: run pumpReads and pumpWrites concurrently?
 
+	// pumpReads checksums the request payload and skips relaying it when
+	// it matches the immediately previous request payload. This allows
+	// clients to resend request payloads, when retrying due to connection
+	// interruption, without knowing whether the server has received or
+	// relayed the data.
+
 	err = session.clientConn.pumpReads(request.Body)
 	if err != nil {
 		if err != io.EOF {
 			// Debug since errors such as "i/o timeout" occur during normal operation;
 			// also, golang network error messages may contain client IP.
-			log.WithContextFields(LogFields{"error": err}).Debug("pump reads failed")
+			log.WithContextFields(LogFields{"error": err}).Debug("read request failed")
 		}
 		server.terminateConnection(responseWriter, request)
-		server.closeSession(sessionID)
+
+		// Note: keep session open to allow client to retry
+
 		return
 	}
 
@@ -262,22 +329,134 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 		session.sessionIDSent = true
 	}
 
-	// pumpWrites causes a TunnelServer/SSH goroutine blocking on a Write to
-	// write its downstream traffic through to the response body.
+	// When streaming data into the response body, a copy is
+	// retained in the cachedResponse buffer. This allows the
+	// client to retry and request that the response be resent
+	// when the HTTP connection is interrupted.
+	//
+	// If a Range header is present, the client is retrying,
+	// possibly after having received a partial response. In
+	// this case, use any cached response to attempt to resend
+	// the response, starting from the resend position the client
+	// indicates.
+	//
+	// When the resend position is not available -- because the
+	// cachedResponse buffer could not hold it -- the client session
+	// is closed, as there's no way to resume streaming the payload
+	// uninterrupted.
+	//
+	// The client may retry before a cached response is prepared,
+	// so a cached response is not always used when a Range header
+	// is present.
+	//
+	// TODO: invalid Range header is ignored; should it be otherwise?
+
+	position, isRetry := checkRangeHeader(request)
+	if isRetry {
+		atomic.AddInt64(&session.metricClientRetries, 1)
+	}
 
-	err = session.clientConn.pumpWrites(responseWriter)
-	if err != nil {
-		if err != io.EOF {
+	hasCompleteCachedResponse := session.cachedResponse.HasPosition(0)
+
+	// The client is not expected to send position > 0 when there is
+	// no cached response; let that case fall through to the next
+	// HasPosition check which will fail and close the session.
+
+	var responseSize int
+	var responseError error
+
+	if isRetry && (hasCompleteCachedResponse || position > 0) {
+
+		if !session.cachedResponse.HasPosition(position) {
+			greaterThanSwapInt64(&session.metricCachedResponseMissPosition, int64(position))
+			server.terminateConnection(responseWriter, request)
+			server.closeSession(sessionID)
+			return
+		}
+
+		responseWriter.WriteHeader(http.StatusPartialContent)
+
+		// TODO:
+		// - enforce a max extended buffer count per client, for
+		//   fairness? Throttling may make this unnecessary.
+		// - cachedResponse can now start releasing extended buffers,
+		//   as response bytes before "position" will never be requested
+		//   again?
+
+		responseSize, responseError = session.cachedResponse.CopyFromPosition(position, responseWriter)
+		greaterThanSwapInt64(&session.metricPeakCachedResponseHitSize, int64(responseSize))
+
+		// The client may again fail to receive the payload and may again
+		// retry, so not yet releasing cachedResponse buffers.
+
+	} else {
+
+		// _Now_ we release buffers holding data from the previous
+		// response. And then immediately stream the new response into
+		// newly acquired buffers.
+		session.cachedResponse.Reset()
+
+		// Note: this code depends on an implementation detail of
+		// io.MultiWriter: a Write() to the MultiWriter writes first
+		// to the cache, and then to the response writer. So if the
+		// write to the response writer fails, the payload is cached.
+		multiWriter := io.MultiWriter(session.cachedResponse, responseWriter)
+
+		// The client expects 206, not 200, whenever it sets a Range header,
+		// which it may do even when no cached response is prepared.
+		if isRetry {
+			responseWriter.WriteHeader(http.StatusPartialContent)
+		}
+
+		// pumpWrites causes a TunnelServer/SSH goroutine blocking on a Write to
+		// write its downstream traffic through to the response body.
+
+		responseSize, responseError = session.clientConn.pumpWrites(multiWriter)
+		greaterThanSwapInt64(&session.metricPeakResponseSize, int64(responseSize))
+		greaterThanSwapInt64(&session.metricPeakCachedResponseSize, int64(session.cachedResponse.Available()))
+	}
+
+	// responseError is the result of writing the body either from CopyFromPosition or pumpWrites
+	if responseError != nil {
+		if responseError != io.EOF {
 			// Debug since errors such as "i/o timeout" occur during normal operation;
 			// also, golang network error messages may contain client IP.
-			log.WithContextFields(LogFields{"error": err}).Debug("pump writes failed")
+			log.WithContextFields(LogFields{"error": responseError}).Debug("write response failed")
 		}
 		server.terminateConnection(responseWriter, request)
-		server.closeSession(sessionID)
+
+		// Note: keep session open to allow client to retry
+
 		return
 	}
 }
 
+func checkRangeHeader(request *http.Request) (int, bool) {
+	rangeHeader := request.Header.Get("Range")
+	if rangeHeader == "" {
+		return 0, false
+	}
+
+	prefix := "bytes="
+	suffix := "-"
+
+	if !strings.HasPrefix(rangeHeader, prefix) ||
+		!strings.HasSuffix(rangeHeader, suffix) {
+
+		return 0, false
+	}
+
+	rangeHeader = strings.TrimPrefix(rangeHeader, prefix)
+	rangeHeader = strings.TrimSuffix(rangeHeader, suffix)
+	position, err := strconv.Atoi(rangeHeader)
+
+	if err != nil {
+		return 0, false
+	}
+
+	return position, true
+}
+
 // getSession returns the meek client session corresponding the
 // meek cookie/session ID. If no session is found, the cookie is
 // treated as a meek cookie for a new session and its payload is
@@ -307,13 +486,9 @@ func (server *MeekServer) getSession(
 		return "", nil, common.ContextError(err)
 	}
 
-	// Note: this meek server ignores all but Version MeekProtocolVersion;
-	// the other values are legacy or currently unused.
-	var clientSessionData struct {
-		MeekProtocolVersion    int    `json:"v"`
-		PsiphonClientSessionId string `json:"s"`
-		PsiphonServerAddress   string `json:"p"`
-	}
+	// Note: this meek server ignores legacy values PsiphonClientSessionId
+	// and PsiphonServerAddress.
+	var clientSessionData protocol.MeekCookieData
 
 	err = json.Unmarshal(payloadJSON, &clientSessionData)
 	if err != nil {
@@ -345,6 +520,22 @@ func (server *MeekServer) getSession(
 		}
 	}
 
+	// Create a new session
+
+	bufferLength := MEEK_DEFAULT_RESPONSE_BUFFER_LENGTH
+	if server.support.Config.MeekCachedResponseBufferSize != 0 {
+		bufferLength = server.support.Config.MeekCachedResponseBufferSize
+	}
+	cachedResponse := NewCachedResponse(bufferLength, server.bufferPool)
+
+	session = &meekSession{
+		meekProtocolVersion: clientSessionData.MeekProtocolVersion,
+		sessionIDSent:       false,
+		cachedResponse:      cachedResponse,
+	}
+
+	session.touch()
+
 	// Create a new meek conn that will relay the payload
 	// between meek request/responses and the tunnel server client
 	// handler. The client IP is also used to initialize the
@@ -354,18 +545,15 @@ func (server *MeekServer) getSession(
 	// Assumes clientIP is a valid IP address; the port value is a stub
 	// and is expected to be ignored.
 	clientConn := newMeekConn(
+		server,
+		session,
 		&net.TCPAddr{
 			IP:   net.ParseIP(clientIP),
 			Port: 0,
 		},
 		clientSessionData.MeekProtocolVersion)
 
-	session = &meekSession{
-		clientConn:          clientConn,
-		meekProtocolVersion: clientSessionData.MeekProtocolVersion,
-		sessionIDSent:       false,
-	}
-	session.touch()
+	session.clientConn = clientConn
 
 	// Note: MEEK_PROTOCOL_VERSION_1 doesn't support changing the
 	// meek cookie to a session ID; v1 clients always send the
@@ -388,7 +576,7 @@ func (server *MeekServer) getSession(
 
 	// Note: from the tunnel server's perspective, this client connection
 	// will close when closeSessionHelper calls Close() on the meekConn.
-	server.clientHandler(session.clientConn)
+	server.clientHandler(clientSessionData.ClientTunnelProtocol, session.clientConn)
 
 	return sessionID, session, nil
 }
@@ -398,6 +586,10 @@ func (server *MeekServer) closeSessionHelper(
 
 	// TODO: close the persistent HTTP client connection, if one exists
 	session.clientConn.Close()
+
+	// Release all extended buffers back to the pool
+	session.cachedResponse.Reset()
+
 	// Note: assumes caller holds lock on sessionsLock
 	delete(server.sessions, sessionID)
 }
@@ -455,10 +647,18 @@ type meekSession struct {
 	// Note: 64-bit ints used with atomic operations are at placed
 	// at the start of struct to ensure 64-bit alignment.
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
-	lastActivity        int64
-	clientConn          *meekConn
-	meekProtocolVersion int
-	sessionIDSent       bool
+	lastActivity                     int64
+	requestCount                     int64
+	metricClientRetries              int64
+	metricPeakResponseSize           int64
+	metricPeakCachedResponseSize     int64
+	metricPeakCachedResponseHitSize  int64
+	metricCachedResponseMissPosition int64
+	lock                             sync.Mutex
+	clientConn                       *meekConn
+	meekProtocolVersion              int
+	sessionIDSent                    bool
+	cachedResponse                   *CachedResponse
 }
 
 func (session *meekSession) touch() {
@@ -470,6 +670,17 @@ func (session *meekSession) expired() bool {
 	return monotime.Since(lastActivity) > MEEK_MAX_SESSION_STALENESS
 }
 
+// GetMetrics implements the MetricsSource interface
+func (session *meekSession) GetMetrics() LogFields {
+	logFields := make(LogFields)
+	logFields["meek_client_retries"] = atomic.LoadInt64(&session.metricClientRetries)
+	logFields["meek_peak_response_size"] = atomic.LoadInt64(&session.metricPeakResponseSize)
+	logFields["meek_peak_cached_response_size"] = atomic.LoadInt64(&session.metricPeakCachedResponseSize)
+	logFields["meek_peak_cached_response_hit_size"] = atomic.LoadInt64(&session.metricPeakCachedResponseHitSize)
+	logFields["meek_cached_response_miss_position"] = atomic.LoadInt64(&session.metricCachedResponseMissPosition)
+	return logFields
+}
+
 // makeMeekTLSConfig creates a TLS config for a meek HTTPS listener.
 // Currently, this config is optimized for fronted meek where the nature
 // of the connection is non-circumvention; it's optimized for performance
@@ -629,10 +840,13 @@ func makeMeekSessionID() (string, error) {
 // meekConn bridges net/http request/response payload readers and writers
 // and goroutines calling Read()s and Write()s.
 type meekConn struct {
+	meekServer        *MeekServer
+	meekSession       *meekSession
 	remoteAddr        net.Addr
 	protocolVersion   int
 	closeBroadcast    chan struct{}
 	closed            int32
+	lastReadChecksum  *uint64
 	readLock          sync.Mutex
 	emptyReadBuffer   chan *bytes.Buffer
 	partialReadBuffer chan *bytes.Buffer
@@ -642,8 +856,15 @@ type meekConn struct {
 	writeResult       chan error
 }
 
-func newMeekConn(remoteAddr net.Addr, protocolVersion int) *meekConn {
+func newMeekConn(
+	meekServer *MeekServer,
+	meekSession *meekSession,
+	remoteAddr net.Addr,
+	protocolVersion int) *meekConn {
+
 	conn := &meekConn{
+		meekServer:        meekServer,
+		meekSession:       meekSession,
 		remoteAddr:        remoteAddr,
 		protocolVersion:   protocolVersion,
 		closeBroadcast:    make(chan struct{}),
@@ -664,33 +885,84 @@ func newMeekConn(remoteAddr net.Addr, protocolVersion int) *meekConn {
 // pumpReads causes goroutines blocking on meekConn.Read() to read
 // from the specified reader. This function blocks until the reader
 // is fully consumed or the meekConn is closed. A read buffer allows
-// up to MEEK_MAX_PAYLOAD_LENGTH bytes to be read and buffered without
-// a Read() immediately consuming the bytes, but there's still a
-// possibility of a stall if no Read() calls are made after this
+// up to MEEK_MAX_REQUEST_PAYLOAD_LENGTH bytes to be read and buffered
+// without a Read() immediately consuming the bytes, but there's still
+// a possibility of a stall if no Read() calls are made after this
 // read buffer is full.
 // Note: assumes only one concurrent call to pumpReads
 func (conn *meekConn) pumpReads(reader io.Reader) error {
-	for {
 
-		var readBuffer *bytes.Buffer
-		select {
-		case readBuffer = <-conn.emptyReadBuffer:
-		case readBuffer = <-conn.partialReadBuffer:
-		case <-conn.closeBroadcast:
-			return io.EOF
-		}
+	// Wait for a full capacity empty buffer. This ensures we can read
+	// the maximum MEEK_MAX_REQUEST_PAYLOAD_LENGTH request payload and
+	// checksum before relaying.
+	//
+	// Note: previously, this code would select conn.partialReadBuffer
+	// and write to that, looping until the entire request payload was
+	// read. Now, the consumer, the Read() caller, must fully drain the
+	// read buffer first.
 
-		limitReader := io.LimitReader(reader, int64(MEEK_MAX_PAYLOAD_LENGTH-readBuffer.Len()))
-		n, err := readBuffer.ReadFrom(limitReader)
+	// Use either an empty or partial buffer. By using a partial
+	// buffer, pumpReads will not block if the Read() caller has
+	// not fully drained the read buffer.
 
+	var readBuffer *bytes.Buffer
+	select {
+	case readBuffer = <-conn.emptyReadBuffer:
+	case readBuffer = <-conn.partialReadBuffer:
+	case <-conn.closeBroadcast:
+		return io.EOF
+	}
+
+	newDataOffset := readBuffer.Len()
+
+	// Since we need to read the full request payload in order to
+	// take its checksum before relaying it, the read buffer can
+	// grow to up to 2 x MEEK_MAX_REQUEST_PAYLOAD_LENGTH + 1.
+
+	// +1 allows for an explicit check for request payloads that
+	// exceed the maximum permitted length.
+	limitReader := io.LimitReader(reader, MEEK_MAX_REQUEST_PAYLOAD_LENGTH+1)
+	n, err := readBuffer.ReadFrom(limitReader)
+
+	if err == nil && n == MEEK_MAX_REQUEST_PAYLOAD_LENGTH+1 {
+		err = errors.New("invalid request payload length")
+	}
+
+	// If the request read fails, don't relay the new data. This allows
+	// the client to retry and resend its request payload without
+	// interrupting/duplicating the payload flow.
+	if err != nil {
+		readBuffer.Truncate(newDataOffset)
 		conn.replaceReadBuffer(readBuffer)
+		return common.ContextError(err)
+	}
 
-		if n == 0 || err != nil {
-			return err
-		}
+	// Check if request payload checksum matches immediately
+	// previous payload. On match, assume this is a client retry
+	// sending payload that was already relayed and skip this
+	// payload. Payload is OSSH ciphertext and almost surely
+	// will not repeat. In the highly unlikely case that it does,
+	// the underlying SSH connection will fail and the client
+	// must reconnect.
+
+	checksum := crc64.Checksum(
+		readBuffer.Bytes()[newDataOffset:], conn.meekServer.checksumTable)
+
+	if conn.lastReadChecksum == nil {
+		conn.lastReadChecksum = new(uint64)
+	} else if *conn.lastReadChecksum == checksum {
+		readBuffer.Truncate(newDataOffset)
 	}
+
+	*conn.lastReadChecksum = checksum
+
+	conn.replaceReadBuffer(readBuffer)
+
+	return nil
 }
 
+var errMeekConnectionHasClosed = errors.New("meek connection has closed")
+
 // Read reads from the meekConn into buffer. Read blocks until
 // some data is read or the meekConn closes. Under the hood, it
 // waits for pumpReads to submit a reader to read from.
@@ -704,7 +976,7 @@ func (conn *meekConn) Read(buffer []byte) (int, error) {
 	case readBuffer = <-conn.partialReadBuffer:
 	case readBuffer = <-conn.fullReadBuffer:
 	case <-conn.closeBroadcast:
-		return 0, io.EOF
+		return 0, common.ContextError(errMeekConnectionHasClosed)
 	}
 
 	n, err := readBuffer.Read(buffer)
@@ -715,12 +987,12 @@ func (conn *meekConn) Read(buffer []byte) (int, error) {
 }
 
 func (conn *meekConn) replaceReadBuffer(readBuffer *bytes.Buffer) {
-	switch readBuffer.Len() {
-	case MEEK_MAX_PAYLOAD_LENGTH:
+	length := readBuffer.Len()
+	if length >= MEEK_MAX_REQUEST_PAYLOAD_LENGTH {
 		conn.fullReadBuffer <- readBuffer
-	case 0:
+	} else if length == 0 {
 		conn.emptyReadBuffer <- readBuffer
-	default:
+	} else {
 		conn.partialReadBuffer <- readBuffer
 	}
 }
@@ -730,40 +1002,41 @@ func (conn *meekConn) replaceReadBuffer(readBuffer *bytes.Buffer) {
 // body limits (size for protocol v1, turn around time for protocol v2+)
 // are met, or the meekConn is closed.
 // Note: channel scheme assumes only one concurrent call to pumpWrites
-func (conn *meekConn) pumpWrites(writer io.Writer) error {
+func (conn *meekConn) pumpWrites(writer io.Writer) (int, error) {
 
 	startTime := monotime.Now()
 	timeout := time.NewTimer(MEEK_TURN_AROUND_TIMEOUT)
 	defer timeout.Stop()
 
+	n := 0
 	for {
 		select {
 		case buffer := <-conn.nextWriteBuffer:
-			_, err := writer.Write(buffer)
-
+			written, err := writer.Write(buffer)
+			n += written
 			// Assumes that writeResult won't block.
 			// Note: always send the err to writeResult,
 			// as the Write() caller is blocking on this.
 			conn.writeResult <- err
 
 			if err != nil {
-				return err
+				return n, err
 			}
 
 			if conn.protocolVersion < MEEK_PROTOCOL_VERSION_1 {
 				// Pre-protocol version 1 clients expect at most
-				// MEEK_MAX_PAYLOAD_LENGTH response bodies
-				return nil
+				// MEEK_MAX_REQUEST_PAYLOAD_LENGTH response bodies
+				return n, nil
 			}
 			totalElapsedTime := monotime.Since(startTime) / time.Millisecond
 			if totalElapsedTime >= MEEK_EXTENDED_TURN_AROUND_TIMEOUT {
-				return nil
+				return n, nil
 			}
 			timeout.Reset(MEEK_TURN_AROUND_TIMEOUT)
 		case <-timeout.C:
-			return nil
+			return n, nil
 		case <-conn.closeBroadcast:
-			return io.EOF
+			return n, common.ContextError(errMeekConnectionHasClosed)
 		}
 	}
 }
@@ -783,29 +1056,40 @@ func (conn *meekConn) Write(buffer []byte) (int, error) {
 
 	n := 0
 	for n < len(buffer) {
-		end := n + MEEK_MAX_PAYLOAD_LENGTH
+		end := n + MEEK_MAX_REQUEST_PAYLOAD_LENGTH
 		if end > len(buffer) {
 			end = len(buffer)
 		}
 
-		// Only write MEEK_MAX_PAYLOAD_LENGTH at a time,
+		// Only write MEEK_MAX_REQUEST_PAYLOAD_LENGTH at a time,
 		// to ensure compatibility with v1 protocol.
 		chunk := buffer[n:end]
 
 		select {
 		case conn.nextWriteBuffer <- chunk:
 		case <-conn.closeBroadcast:
-			return n, io.EOF
+			return n, common.ContextError(errMeekConnectionHasClosed)
 		}
 
 		// Wait for the buffer to be processed.
 		select {
-		case err := <-conn.writeResult:
-			if err != nil {
-				return n, err
-			}
+		case _ = <-conn.writeResult:
+			// The err from conn.writeResult comes from the
+			// io.MultiWriter used in pumpWrites, which writes
+			// to both the cached response and the HTTP response.
+			//
+			// Don't stop on error here, since only writing
+			// to the HTTP response will fail, and the client
+			// may retry and use the cached response.
+			//
+			// It's possible that the cached response buffer
+			// is too small for the client to successfully
+			// retry, but that cannot be determined. In this
+			// case, the meek connection will eventually fail.
+			//
+			// err is already logged in ServeHTTP.
 		case <-conn.closeBroadcast:
-			return n, io.EOF
+			return n, common.ContextError(errMeekConnectionHasClosed)
 		}
 		n += len(chunk)
 	}
@@ -838,7 +1122,7 @@ func (conn *meekConn) RemoteAddr() net.Addr {
 // SetDeadline is not a true implementation of net.Conn.SetDeadline. It
 // merely checks that the requested timeout exceeds the MEEK_MAX_SESSION_STALENESS
 // period. When it does, and the session is idle, the meekConn Read/Write will
-// be interrupted and return io.EOF (not a timeout error) before the deadline.
+// be interrupted and return an error (not a timeout error) before the deadline.
 // In other words, this conn will approximate the desired functionality of
 // timing out on idle on or before the requested deadline.
 func (conn *meekConn) SetDeadline(t time.Time) error {
@@ -858,3 +1142,10 @@ func (conn *meekConn) SetReadDeadline(t time.Time) error {
 func (conn *meekConn) SetWriteDeadline(t time.Time) error {
 	return common.ContextError(errors.New("not supported"))
 }
+
+// GetMetrics implements the MetricsSource interface. The metrics are maintained
+// in the meek session type; but logTunnel, which calls MetricsSource.GetMetrics,
+// has a pointer only to this conn, so it calls through to the session.
+func (conn *meekConn) GetMetrics() LogFields {
+	return conn.meekSession.GetMetrics()
+}

+ 304 - 0
psiphon/server/meekBuffer.go

@@ -0,0 +1,304 @@
+/*
+ * Copyright (c) 2017, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package server
+
+import (
+	"errors"
+	"io"
+)
+
+// CachedResponse is a data structure that supports meek
+// protocol connection interruption resiliency: it stores
+// payload data from the most recent response so that it
+// may be resent if the client fails to receive it.
+//
+// The meek server maintains one CachedResponse for each
+// meek client. Psiphon's variant of meek streams response
+// data, so responses are not fixed size. To limit the memory
+// overhead of response caching, each CachedResponse has a
+// fixed-size buffer that operates as a ring buffer,
+// discarding older response bytes when the buffer fills.
+// A CachedResponse that has discarded data may still satisfy
+// a client retry where the client has already received part
+// of the response payload.
+//
+// A CachedResponse will also extend its capacity by
+// borrowing buffers from a CachedResponseBufferPool, if
+// available. When Reset is called, borrowed buffers are
+// released back to the pool.
+type CachedResponse struct {
+	buffers            [][]byte
+	readPosition       int
+	readAvailable      int
+	writeIndex         int
+	writeBufferIndex   int
+	overwriting        bool
+	extendedBufferPool *CachedResponseBufferPool
+}
+
+// NewCachedResponse creates a CachedResponse with a fixed buffer
+// of size bufferSize and borrowing buffers from extendedBufferPool.
+func NewCachedResponse(
+	bufferSize int,
+	extendedBufferPool *CachedResponseBufferPool) *CachedResponse {
+
+	return &CachedResponse{
+		buffers:            [][]byte{make([]byte, bufferSize)},
+		extendedBufferPool: extendedBufferPool,
+	}
+}
+
+// Reset reinitializes the CachedResponse state to have
+// no buffered response and releases all extended buffers
+// back to the pool.
+// Reset _must_ be called before discarding a CachedResponse
+// or extended buffers will not be released.
+func (response *CachedResponse) Reset() {
+	for i, buffer := range response.buffers {
+		if i > 0 {
+			response.extendedBufferPool.Put(buffer)
+		}
+	}
+	response.buffers = response.buffers[0:1]
+	response.readPosition = 0
+	response.readAvailable = 0
+	response.writeIndex = 0
+	response.writeBufferIndex = 0
+	response.overwriting = false
+}
+
+// Available returns the size of the buffered response data.
+func (response *CachedResponse) Available() int {
+	return response.readAvailable
+}
+
+// HasPosition checks if the CachedResponse has buffered
+// response data starting at or before the specified
+// position.
+func (response *CachedResponse) HasPosition(position int) bool {
+	return response.readAvailable > 0 && response.readPosition <= position
+}
+
+// CopyFromPosition writes the response data, starting at
+// the specified position, to writer. Any data before the
+// position is skipped. CopyFromPosition will return an error
+// if the specified position is not available.
+// CopyFromPosition will copy no data and return no error if
+// the position is at the end of its available data.
+// CopyFromPosition can be called repeatedly to read the
+// same data -- it does not advance or modify the CachedResponse.
+func (response *CachedResponse) CopyFromPosition(
+	position int, writer io.Writer) (int, error) {
+
+	if response.readAvailable > 0 && response.readPosition > position {
+		return 0, errors.New("position unavailable")
+	}
+
+	// Special case: position is end of available data
+	if position == response.readPosition+response.readAvailable {
+		return 0, nil
+	}
+
+	// Begin at the start of the response data, which may
+	// be midway through the buffer(s).
+
+	index := 0
+	bufferIndex := 0
+	if response.overwriting {
+		index = response.writeIndex
+		bufferIndex = response.writeBufferIndex
+		if index >= len(response.buffers[bufferIndex]) {
+			index = 0
+			bufferIndex = (bufferIndex + 1) % len(response.buffers)
+		}
+	}
+
+	// Iterate over all available data, skipping until at the
+	// requested position.
+
+	n := 0
+
+	skip := position - response.readPosition
+	available := response.readAvailable
+
+	for available > 0 {
+
+		buffer := response.buffers[bufferIndex]
+
+		toCopy := min(len(buffer)-index, available)
+
+		available -= toCopy
+
+		if skip > 0 {
+			if toCopy >= skip {
+				index += skip
+				toCopy -= skip
+				skip = 0
+			} else {
+				skip -= toCopy
+			}
+		}
+
+		if skip == 0 {
+			written, err := writer.Write(buffer[index : index+toCopy])
+			n += written
+			if err != nil {
+				return n, err
+			}
+		}
+
+		index = 0
+		bufferIndex = (bufferIndex + 1) % len(response.buffers)
+	}
+
+	return n, nil
+}
+
+// Write appends data to the CachedResponse. All writes will
+// succeed, but only the most recent bytes will be retained
+// once the fixed buffer is full and no extended buffers are
+// available.
+//
+// Write may be called multiple times to record a single
+// response; Reset should be called between responses.
+//
+// Write conforms to the io.Writer interface.
+func (response *CachedResponse) Write(data []byte) (int, error) {
+
+	dataIndex := 0
+
+	for dataIndex < len(data) {
+
+		// Write into available space in the current buffer
+
+		buffer := response.buffers[response.writeBufferIndex]
+		canWriteLen := len(buffer) - response.writeIndex
+		needWriteLen := len(data) - dataIndex
+		writeLen := min(canWriteLen, needWriteLen)
+
+		if writeLen > 0 {
+			copy(
+				buffer[response.writeIndex:response.writeIndex+writeLen],
+				data[dataIndex:dataIndex+writeLen])
+
+			response.writeIndex += writeLen
+
+			// readPosition tracks the earliest position in
+			// the response that remains in the cached response.
+			// Once the buffer is full (and cannot be extended),
+			// older data is overwritten and readPosition advances.
+			//
+			// readAvailable is the amount of data in the cached
+			// response, which may be less than the buffer capacity.
+
+			if response.overwriting {
+				response.readPosition += writeLen
+			} else {
+				response.readAvailable += writeLen
+			}
+
+			dataIndex += writeLen
+		}
+
+		if needWriteLen > canWriteLen {
+
+			// Add an extended buffer to increase capacity
+
+			// TODO: can extend whenever response.readIndex and response.readBufferIndex are 0?
+			if response.writeBufferIndex == len(response.buffers)-1 &&
+				!response.overwriting {
+
+				extendedBuffer := response.extendedBufferPool.Get()
+				if extendedBuffer != nil {
+					response.buffers = append(response.buffers, extendedBuffer)
+				}
+			}
+
+			// Move to the next buffer, which may wrap around
+
+			// This isn't a general ring buffer: Reset is called at
+			// start of each response, so the initial data is always
+			// at the beginning of the first buffer. It follows that
+			// data is overwritten once the buffer wraps around back
+			// to the beginning.
+
+			response.writeBufferIndex++
+			if response.writeBufferIndex >= len(response.buffers) {
+				response.writeBufferIndex = 0
+				response.overwriting = true
+			}
+			response.writeIndex = 0
+		}
+	}
+
+	return len(data), nil
+}
+
+// CachedResponseBufferPool is a fixed-size pool of
+// fixed-size buffers that are used to temporarily extend
+// the capacity of CachedResponses.
+type CachedResponseBufferPool struct {
+	bufferSize int
+	buffers    chan []byte
+}
+
+// NewCachedResponseBufferPool creates a new CachedResponseBufferPool
+// with the specified number of buffers. Buffers are allocated on
+// demand and once allocated remain allocated.
+func NewCachedResponseBufferPool(
+	bufferSize, bufferCount int) *CachedResponseBufferPool {
+
+	buffers := make(chan []byte, bufferCount)
+	for i := 0; i < bufferCount; i++ {
+		buffers <- make([]byte, 0)
+	}
+
+	return &CachedResponseBufferPool{
+		bufferSize: bufferSize,
+		buffers:    buffers,
+	}
+}
+
+// Get returns a buffer, if one is available, or returns nil
+// when no buffer is available. Get does not block. Call Put
+// to release the buffer back to the pool.
+//
+// Note: currently, Buffers are not zeroed between use by
+// different CachedResponses owned by different clients.
+// A bug resulting in cross-client data transfer exposes
+// only OSSH ciphertext in the case of meek's use of
+// CachedResponses.
+func (pool *CachedResponseBufferPool) Get() []byte {
+	select {
+	case buffer := <-pool.buffers:
+		if len(buffer) == 0 {
+			buffer = make([]byte, pool.bufferSize)
+		}
+		return buffer
+	default:
+		return nil
+	}
+}
+
+// Put releases a buffer back to the pool. The buffer must
+// have been obtained from Get.
+func (pool *CachedResponseBufferPool) Put(buffer []byte) {
+	pool.buffers <- buffer
+}

+ 362 - 0
psiphon/server/meek_test.go

@@ -0,0 +1,362 @@
+/*
+ * Copyright (c) 2017, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package server
+
+import (
+	"bytes"
+	crypto_rand "crypto/rand"
+	"encoding/base64"
+	"fmt"
+	"math/rand"
+	"net"
+	"sync"
+	"syscall"
+	"testing"
+	"time"
+
+	"github.com/Psiphon-Inc/crypto/nacl/box"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+)
+
+var KB = 1024
+var MB = KB * KB
+
+func TestCachedResponse(t *testing.T) {
+
+	rand.Seed(time.Now().Unix())
+
+	testCases := []struct {
+		concurrentResponses int
+		responseSize        int
+		bufferSize          int
+		extendedBufferSize  int
+		extendedBufferCount int
+		minBytesPerWrite    int
+		maxBytesPerWrite    int
+		copyPosition        int
+		expectedSuccess     bool
+	}{
+		{1, 16, 16, 0, 0, 1, 1, 0, true},
+
+		{1, 31, 16, 0, 0, 1, 1, 15, true},
+
+		{1, 16, 2, 2, 7, 1, 1, 0, true},
+
+		{1, 31, 15, 3, 5, 1, 1, 1, true},
+
+		{1, 16, 16, 0, 0, 1, 1, 16, true},
+
+		{1, 64*KB + 1, 64 * KB, 64 * KB, 1, 1, 1 * KB, 64 * KB, true},
+
+		{1, 10 * MB, 64 * KB, 64 * KB, 158, 1, 32 * KB, 0, false},
+
+		{1, 10 * MB, 64 * KB, 64 * KB, 159, 1, 32 * KB, 0, true},
+
+		{1, 10 * MB, 64 * KB, 64 * KB, 160, 1, 32 * KB, 0, true},
+
+		{1, 128 * KB, 64 * KB, 0, 0, 1, 1 * KB, 64 * KB, true},
+
+		{1, 128 * KB, 64 * KB, 0, 0, 1, 1 * KB, 63 * KB, false},
+
+		{1, 200 * KB, 64 * KB, 0, 0, 1, 1 * KB, 136 * KB, true},
+
+		{10, 10 * MB, 64 * KB, 64 * KB, 1589, 1, 32 * KB, 0, false},
+
+		{10, 10 * MB, 64 * KB, 64 * KB, 1590, 1, 32 * KB, 0, true},
+	}
+
+	for _, testCase := range testCases {
+		description := fmt.Sprintf("test case: %+v", testCase)
+		t.Run(description, func(t *testing.T) {
+
+			pool := NewCachedResponseBufferPool(testCase.extendedBufferSize, testCase.extendedBufferCount)
+
+			responses := make([]*CachedResponse, testCase.concurrentResponses)
+			for i := 0; i < testCase.concurrentResponses; i++ {
+				responses[i] = NewCachedResponse(testCase.bufferSize, pool)
+			}
+
+			// Repeats exercise CachedResponse.Reset() and CachedResponseBufferPool replacement
+			for repeat := 0; repeat < 2; repeat++ {
+
+				t.Logf("repeat %d", repeat)
+
+				responseData := make([]byte, testCase.responseSize)
+				_, _ = rand.Read(responseData)
+
+				waitGroup := new(sync.WaitGroup)
+
+				// Goroutines exercise concurrent access to CachedResponseBufferPool
+				for _, response := range responses {
+					waitGroup.Add(1)
+					go func(response *CachedResponse) {
+						defer waitGroup.Done()
+
+						remainingSize := testCase.responseSize
+						for remainingSize > 0 {
+
+							writeSize := testCase.minBytesPerWrite
+							writeSize += rand.Intn(testCase.maxBytesPerWrite - testCase.minBytesPerWrite + 1)
+							if writeSize > remainingSize {
+								writeSize = remainingSize
+							}
+
+							offset := len(responseData) - remainingSize
+							response.Write(responseData[offset : offset+writeSize])
+							remainingSize -= writeSize
+						}
+					}(response)
+				}
+
+				waitGroup.Wait()
+
+				atLeastOneFailure := false
+
+				for i, response := range responses {
+
+					cachedResponseData := new(bytes.Buffer)
+
+					n, err := response.CopyFromPosition(testCase.copyPosition, cachedResponseData)
+
+					if testCase.expectedSuccess {
+						if err != nil {
+							t.Fatalf("CopyFromPosition unexpectedly failed for response %d: %s", i, err)
+						}
+						if n != cachedResponseData.Len() || n > response.Available() {
+							t.Fatalf("cached response size mismatch for response %d", i)
+						}
+						if bytes.Compare(responseData[testCase.copyPosition:], cachedResponseData.Bytes()) != 0 {
+							t.Fatalf("cached response data mismatch for response %d", i)
+						}
+					} else {
+						atLeastOneFailure = true
+					}
+				}
+
+				if !testCase.expectedSuccess && !atLeastOneFailure {
+					t.Fatalf("CopyFromPosition unexpectedly succeeded for all responses")
+				}
+
+				for _, response := range responses {
+					response.Reset()
+				}
+			}
+		})
+	}
+}
+
+func TestMeekResiliency(t *testing.T) {
+
+	upstreamData := make([]byte, 5*MB)
+	_, _ = rand.Read(upstreamData)
+
+	downstreamData := make([]byte, 5*MB)
+	_, _ = rand.Read(downstreamData)
+
+	minWrite, maxWrite := 1, 128*KB
+	minRead, maxRead := 1, 128*KB
+	minWait, maxWait := 1*time.Millisecond, 500*time.Millisecond
+
+	sendFunc := func(name string, conn net.Conn, data []byte) {
+		for sent := 0; sent < len(data); {
+			wait := minWait + time.Duration(rand.Int63n(int64(maxWait-minWait)+1))
+			time.Sleep(wait)
+			writeLen := minWrite + rand.Intn(maxWrite-minWrite+1)
+			writeLen = min(writeLen, len(data)-sent)
+			_, err := conn.Write(data[sent : sent+writeLen])
+			if err != nil {
+				t.Fatalf("conn.Write failed: %s", err)
+			}
+			sent += writeLen
+			fmt.Printf("%s sent %d/%d...\n", name, sent, len(data))
+		}
+		fmt.Printf("%s send complete\n", name)
+	}
+
+	recvFunc := func(name string, conn net.Conn, expectedData []byte) {
+		data := make([]byte, len(expectedData))
+		for received := 0; received < len(data); {
+			wait := minWait + time.Duration(rand.Int63n(int64(maxWait-minWait)+1))
+			time.Sleep(wait)
+			readLen := minRead + rand.Intn(maxRead-minRead+1)
+			readLen = min(readLen, len(data)-received)
+			n, err := conn.Read(data[received : received+readLen])
+			if err != nil {
+				t.Fatalf("conn.Read failed: %s", err)
+			}
+			received += n
+			if bytes.Compare(data[0:received], expectedData[0:received]) != 0 {
+				fmt.Printf("%s data check has failed...\n", name)
+				additionalInfo := ""
+				index := bytes.Index(expectedData, data[received-n:received])
+				if index != -1 {
+					// Helpful for debugging missing or repeated data...
+					additionalInfo = fmt.Sprintf(
+						" (last read of %d appears at %d)", n, index)
+				}
+				t.Fatalf("%s got unexpected data with %d/%d%s",
+					name, received, len(expectedData), additionalInfo)
+			}
+			fmt.Printf("%s received %d/%d...\n", name, received, len(expectedData))
+		}
+		fmt.Printf("%s receive complete\n", name)
+	}
+
+	// Run meek server
+
+	rawMeekCookieEncryptionPublicKey, rawMeekCookieEncryptionPrivateKey, err := box.GenerateKey(crypto_rand.Reader)
+	if err != nil {
+		t.Fatalf("box.GenerateKey failed: %s", err)
+	}
+	meekCookieEncryptionPublicKey := base64.StdEncoding.EncodeToString(rawMeekCookieEncryptionPublicKey[:])
+	meekCookieEncryptionPrivateKey := base64.StdEncoding.EncodeToString(rawMeekCookieEncryptionPrivateKey[:])
+	meekObfuscatedKey, err := common.MakeRandomStringHex(SSH_OBFUSCATED_KEY_BYTE_LENGTH)
+	if err != nil {
+		t.Fatalf("common.MakeRandomStringHex failed: %s", err)
+	}
+
+	mockSupport := &SupportServices{
+		Config: &Config{
+			MeekObfuscatedKey:              meekObfuscatedKey,
+			MeekCookieEncryptionPrivateKey: meekCookieEncryptionPrivateKey,
+		},
+	}
+
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("net.Listen failed: %s", err)
+	}
+	defer listener.Close()
+
+	serverAddress := listener.Addr().String()
+
+	relayWaitGroup := new(sync.WaitGroup)
+
+	clientHandler := func(_ string, conn net.Conn) {
+		name := "server"
+		relayWaitGroup.Add(1)
+		go func() {
+			defer relayWaitGroup.Done()
+			sendFunc(name, conn, downstreamData)
+		}()
+		relayWaitGroup.Add(1)
+		go func() {
+			defer relayWaitGroup.Done()
+			recvFunc(name, conn, upstreamData)
+		}()
+	}
+
+	stopBroadcast := make(chan struct{})
+
+	useTLS := false
+	useObfuscatedSessionTickets := false
+
+	server, err := NewMeekServer(
+		mockSupport,
+		listener,
+		useTLS,
+		useObfuscatedSessionTickets,
+		clientHandler,
+		stopBroadcast)
+	if err != nil {
+		t.Fatalf("NewMeekServer failed: %s", err)
+	}
+
+	serverWaitGroup := new(sync.WaitGroup)
+
+	serverWaitGroup.Add(1)
+	go func() {
+		defer serverWaitGroup.Done()
+		err := server.Run()
+		if err != nil {
+			t.Fatalf("MeekServer.Run failed: %s", err)
+		}
+	}()
+
+	// Run meek client
+
+	dialConfig := &psiphon.DialConfig{
+		PendingConns:            new(common.Conns),
+		UseIndistinguishableTLS: true,
+		DeviceBinder:            new(fileDescriptorInterruptor),
+	}
+
+	meekConfig := &psiphon.MeekConfig{
+		DialAddress:                   serverAddress,
+		UseHTTPS:                      useTLS,
+		UseObfuscatedSessionTickets:   useObfuscatedSessionTickets,
+		HostHeader:                    "example.com",
+		MeekCookieEncryptionPublicKey: meekCookieEncryptionPublicKey,
+		MeekObfuscatedKey:             meekObfuscatedKey,
+	}
+
+	clientConn, err := psiphon.DialMeek(meekConfig, dialConfig)
+	if err != nil {
+		t.Fatalf("psiphon.DialMeek failed: %s", err)
+	}
+
+	// Relay data through meek while interrupting underlying TCP connections
+
+	name := "client"
+
+	relayWaitGroup.Add(1)
+	go func() {
+		defer relayWaitGroup.Done()
+		sendFunc(name, clientConn, upstreamData)
+	}()
+
+	relayWaitGroup.Add(1)
+	go func() {
+		defer relayWaitGroup.Done()
+		recvFunc(name, clientConn, downstreamData)
+	}()
+
+	relayWaitGroup.Wait()
+
+	// Graceful shutdown
+
+	clientConn.Close()
+
+	listener.Close()
+	close(stopBroadcast)
+
+	// This wait will hang if shutdown is broken, and the test will ultimately panic
+	serverWaitGroup.Wait()
+}
+
+type fileDescriptorInterruptor struct {
+}
+
+func (interruptor *fileDescriptorInterruptor) BindToDevice(fileDescriptor int) error {
+	fdDup, err := syscall.Dup(fileDescriptor)
+	if err != nil {
+		return err
+	}
+	minAfter := 500 * time.Millisecond
+	maxAfter := 1 * time.Second
+	after := minAfter + time.Duration(rand.Int63n(int64(maxAfter-minAfter)+1))
+	time.AfterFunc(after, func() {
+		syscall.Shutdown(fdDup, syscall.SHUT_RDWR)
+		syscall.Close(fdDup)
+		fmt.Printf("interrupted TCP connection\n")
+	})
+	return nil
+}

+ 1 - 1
psiphon/server/net.go

@@ -61,7 +61,7 @@ import (
 // HTTPSServer is a wrapper around http.Server which adds the
 // ServeTLS function.
 type HTTPSServer struct {
-	http.Server
+	*http.Server
 }
 
 // ServeTLS is similar to http.Serve, but uses TLS.

+ 124 - 47
psiphon/server/tunnelServer.go

@@ -322,9 +322,14 @@ func (sshServer *sshServer) getEstablishTunnels() bool {
 func (sshServer *sshServer) runListener(
 	listener net.Listener,
 	listenerError chan<- error,
-	tunnelProtocol string) {
+	listenerTunnelProtocol string) {
 
-	handleClient := func(clientConn net.Conn) {
+	runningProtocols := make([]string, 0)
+	for tunnelProtocol, _ := range sshServer.support.Config.TunnelProtocolPorts {
+		runningProtocols = append(runningProtocols, tunnelProtocol)
+	}
+
+	handleClient := func(clientTunnelProtocol string, clientConn net.Conn) {
 
 		// Note: establish tunnel limiter cannot simply stop TCP
 		// listeners in all cases (e.g., meek) since SSH tunnel can
@@ -336,6 +341,19 @@ func (sshServer *sshServer) runListener(
 			return
 		}
 
+		// The tunnelProtocol passed to handleClient is used for stats,
+		// throttling, etc. When the tunnel protocol can be determined
+		// unambiguously from the listening port, use that protocol and
+		// don't use any client-declared value. Only use the client's
+		// value, if present, in special cases where the listening port
+		// cannot distinguish the protocol.
+		tunnelProtocol := listenerTunnelProtocol
+		if clientTunnelProtocol != "" &&
+			protocol.UseClientTunnelProtocol(
+				clientTunnelProtocol, runningProtocols) {
+			tunnelProtocol = clientTunnelProtocol
+		}
+
 		// process each client connection concurrently
 		go sshServer.handleClient(tunnelProtocol, clientConn)
 	}
@@ -345,14 +363,14 @@ func (sshServer *sshServer) runListener(
 	// TunnelServer.Run will properly shut down instead of remaining
 	// running.
 
-	if protocol.TunnelProtocolUsesMeekHTTP(tunnelProtocol) ||
-		protocol.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
+	if protocol.TunnelProtocolUsesMeekHTTP(listenerTunnelProtocol) ||
+		protocol.TunnelProtocolUsesMeekHTTPS(listenerTunnelProtocol) {
 
 		meekServer, err := NewMeekServer(
 			sshServer.support,
 			listener,
-			protocol.TunnelProtocolUsesMeekHTTPS(tunnelProtocol),
-			protocol.TunnelProtocolUsesObfuscatedSessionTickets(tunnelProtocol),
+			protocol.TunnelProtocolUsesMeekHTTPS(listenerTunnelProtocol),
+			protocol.TunnelProtocolUsesObfuscatedSessionTickets(listenerTunnelProtocol),
 			handleClient,
 			sshServer.shutdownBroadcast)
 		if err != nil {
@@ -393,7 +411,7 @@ func (sshServer *sshServer) runListener(
 				return
 			}
 
-			handleClient(conn)
+			handleClient("", conn)
 		}
 	}
 }
@@ -727,6 +745,7 @@ type trafficState struct {
 	concurrentPortForwardCount            int64
 	peakConcurrentPortForwardCount        int64
 	totalPortForwardCount                 int64
+	availablePortForwardCond              *sync.Cond
 }
 
 // qualityMetrics records upstream TCP dial attempts and
@@ -753,7 +772,7 @@ func newSshClient(
 
 	runContext, stopRunning := context.WithCancel(context.Background())
 
-	return &sshClient{
+	client := &sshClient{
 		sshServer:         sshServer,
 		tunnelProtocol:    tunnelProtocol,
 		geoIPData:         geoIPData,
@@ -762,10 +781,18 @@ func newSshClient(
 		runContext:        runContext,
 		stopRunning:       stopRunning,
 	}
+
+	client.tcpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
+	client.udpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
+
+	return client
 }
 
 func (sshClient *sshClient) run(clientConn net.Conn) {
 
+	// Some conns report additional metrics
+	metricsSource, isMetricsSource := clientConn.(MetricsSource)
+
 	// Set initial traffic rules, pre-handshake, based on currently known info.
 	sshClient.setTrafficRules()
 
@@ -886,7 +913,11 @@ func (sshClient *sshClient) run(clientConn net.Conn) {
 
 	sshClient.sshServer.unregisterEstablishedClient(sshClient)
 
-	sshClient.logTunnel()
+	var additionalMetrics LogFields
+	if isMetricsSource {
+		additionalMetrics = metricsSource.GetMetrics()
+	}
+	sshClient.logTunnel(additionalMetrics)
 
 	// Transfer OSL seed state -- the OSL progress -- from the closing
 	// client to the session cache so the client can resume its progress
@@ -1191,10 +1222,11 @@ func (sshClient *sshClient) runTunnel(
 
 			if sshClient.isTCPDialingPortForwardLimitExceeded() {
 				blockStartTime := monotime.Now()
-				ctx, cancelFunc := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
-				sshClient.setTCPPortForwardDialingAvailableSignal(cancelFunc)
+				ctx, cancelCtx := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
+				sshClient.setTCPPortForwardDialingAvailableSignal(cancelCtx)
 				<-ctx.Done()
 				sshClient.setTCPPortForwardDialingAvailableSignal(nil)
+				cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
 				remainingDialTimeout -= monotime.Since(blockStartTime)
 			}
 
@@ -1305,7 +1337,7 @@ func (sshClient *sshClient) runTunnel(
 	waitGroup.Wait()
 }
 
-func (sshClient *sshClient) logTunnel() {
+func (sshClient *sshClient) logTunnel(additionalMetrics LogFields) {
 
 	// Note: reporting duration based on last confirmed data transfer, which
 	// is reads for sshClient.activityConn.GetActiveDuration(), and not
@@ -1339,6 +1371,16 @@ func (sshClient *sshClient) logTunnel() {
 	logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
 	logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
 
+	// Merge in additional metrics from the optional metrics source
+	if additionalMetrics != nil {
+		for name, value := range additionalMetrics {
+			// Don't overwrite any basic fields
+			if logFields[name] == nil {
+				logFields[name] = value
+			}
+		}
+	}
+
 	sshClient.Unlock()
 
 	log.LogRawFieldsWithTimestamp(logFields)
@@ -1655,7 +1697,7 @@ func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
 	return false
 }
 
-func (sshClient *sshClient) isPortForwardLimitExceeded(
+func (sshClient *sshClient) isAtPortForwardLimit(
 	portForwardType int) bool {
 
 	sshClient.Lock()
@@ -1715,15 +1757,57 @@ func (sshClient *sshClient) abortedTCPPortForward() {
 	sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
 }
 
+// establishedPortForward increments the concurrent port
+// forward counter. closedPortForward decrements it, so it
+// must always be called for each establishedPortForward
+// call.
+//
+// When at the limit of established port forwards, the LRU
+// existing port forward is closed to make way for the newly
+// established one. There can be a minor delay as, in addition
+// to calling Close() on the port forward net.Conn,
+// establishedPortForward waits for the LRU's closedPortForward()
+// call which will decrement the concurrent counter. This
+// ensures all resources associated with the LRU (socket,
+// goroutine) are released or will very soon be released before
+// proceeding.
 func (sshClient *sshClient) establishedPortForward(
-	portForwardType int) {
-
-	sshClient.Lock()
-	defer sshClient.Unlock()
+	portForwardType int, portForwardLRU *common.LRUConns) {
 
 	var state *trafficState
 	if portForwardType == portForwardTypeTCP {
 		state = &sshClient.tcpTrafficState
+	} else {
+		state = &sshClient.udpTrafficState
+	}
+
+	// When the maximum number of port forwards is already
+	// established, close the LRU. CloseOldest will call
+	// Close on the port forward net.Conn. Both TCP and
+	// UDP port forwards have handler goroutines that may
+	// be blocked calling Read on the net.Conn. Close will
+	// eventually interrupt the Read and cause the handlers
+	// to exit, but not immediately. So the following logic
+	// waits for a LRU handler to be interrupted and signal
+	// availability.
+	//
+	// Note: the port forward limit can change via a traffic
+	// rules hot reload; the condition variable handles this
+	// case whereas a channel-based semaphore would not.
+
+	if sshClient.isAtPortForwardLimit(portForwardType) {
+		portForwardLRU.CloseOldest()
+		log.WithContext().Debug("closed LRU port forward")
+		state.availablePortForwardCond.L.Lock()
+		for sshClient.isAtPortForwardLimit(portForwardType) {
+			state.availablePortForwardCond.Wait()
+		}
+		state.availablePortForwardCond.L.Unlock()
+	}
+
+	sshClient.Lock()
+
+	if portForwardType == portForwardTypeTCP {
 
 		// Assumes TCP port forwards called dialingTCPPortForward
 		state.concurrentDialingPortForwardCount -= 1
@@ -1736,8 +1820,6 @@ func (sshClient *sshClient) establishedPortForward(
 			}
 		}
 
-	} else {
-		state = &sshClient.udpTrafficState
 	}
 
 	state.concurrentPortForwardCount += 1
@@ -1745,13 +1827,14 @@ func (sshClient *sshClient) establishedPortForward(
 		state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
 	}
 	state.totalPortForwardCount += 1
+
+	sshClient.Unlock()
 }
 
 func (sshClient *sshClient) closedPortForward(
 	portForwardType int, bytesUp, bytesDown int64) {
 
 	sshClient.Lock()
-	defer sshClient.Unlock()
 
 	var state *trafficState
 	if portForwardType == portForwardTypeTCP {
@@ -1763,6 +1846,12 @@ func (sshClient *sshClient) closedPortForward(
 	state.concurrentPortForwardCount -= 1
 	state.bytesUp += bytesUp
 	state.bytesDown += bytesDown
+
+	sshClient.Unlock()
+
+	// Signal any goroutine waiting in establishedPortForward
+	// that an established port forward slot is available.
+	state.availablePortForwardCond.Signal()
 }
 
 func (sshClient *sshClient) updateQualityMetricsWithDialResult(
@@ -1836,8 +1925,9 @@ func (sshClient *sshClient) handleTCPChannel(
 
 	log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
 
-	ctx, _ := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
+	ctx, cancelCtx := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
 	IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
+	cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
 
 	// TODO: shuffle list to try other IPs?
 	// TODO: IPv6 support
@@ -1894,8 +1984,9 @@ func (sshClient *sshClient) handleTCPChannel(
 
 	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
 
-	ctx, _ = context.WithTimeout(sshClient.runContext, remainingDialTimeout)
+	ctx, cancelCtx = context.WithTimeout(sshClient.runContext, remainingDialTimeout)
 	fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
+	cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
 
 	// Record port forward success or failure
 	sshClient.updateQualityMetricsWithDialResult(err == nil, monotime.Since(dialStartTime))
@@ -1924,10 +2015,20 @@ func (sshClient *sshClient) handleTCPChannel(
 	defer fwdChannel.Close()
 
 	// Release the dialing slot and acquire an established slot.
+	//
+	// establishedPortForward increments the concurrent TCP port
+	// forward counter and closes the LRU existing TCP port forward
+	// when already at the limit.
+	//
+	// Known limitations:
+	//
+	// - Closed LRU TCP sockets will enter the TIME_WAIT state,
+	//   continuing to consume some resources.
+
+	sshClient.establishedPortForward(portForwardTypeTCP, sshClient.tcpPortForwardLRU)
 
 	// "established = true" cancels the deferred abortedTCPPortForward()
 	established = true
-	sshClient.establishedPortForward(portForwardTypeTCP)
 
 	// TODO: 64-bit alignment? https://golang.org/pkg/sync/atomic/#pkg-note-BUG
 	var bytesUp, bytesDown int64
@@ -1936,30 +2037,6 @@ func (sshClient *sshClient) handleTCPChannel(
 			portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
 	}()
 
-	if exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
-
-		// Close the oldest TCP port forward. CloseOldest() closes
-		// the conn and the port forward's goroutines will complete
-		// the cleanup asynchronously.
-		//
-		// Some known limitations:
-		//
-		// - Since CloseOldest() closes the upstream socket but does not
-		//   clean up all resources associated with the port forward. These
-		//   include the goroutine(s) relaying traffic as well as the SSH
-		//   channel. Closing the socket will interrupt the goroutines which
-		//   will then complete the cleanup. But, since the full cleanup is
-		//   asynchronous, there exists a possibility that a client can consume
-		//   more than max port forward resources -- just not upstream sockets.
-		//
-		// - Closed sockets will enter the TIME_WAIT state, consuming some
-		//   resources.
-
-		sshClient.tcpPortForwardLRU.CloseOldest()
-
-		log.WithContext().Debug("closed LRU TCP port forward")
-	}
-
 	lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
 	defer lruEntry.Remove()
 

+ 5 - 13
psiphon/server/udp.go

@@ -171,22 +171,14 @@ func (mux *udpPortForwardMultiplexer) run() {
 
 			// Note: UDP port forward counting has no dialing phase
 
-			mux.sshClient.establishedPortForward(portForwardTypeUDP)
+			// establishedPortForward increments the concurrent UDP port
+			// forward counter and closes the LRU existing UDP port forward
+			// when already at the limit.
+
+			mux.sshClient.establishedPortForward(portForwardTypeUDP, mux.portForwardLRU)
 			// Can't defer sshClient.closedPortForward() here;
 			// relayDownstream will call sshClient.closedPortForward()
 
-			// TOCTOU note: important to increment the port forward count (via
-			// openPortForward) _before_ checking isPortForwardLimitExceeded
-			if exceeded := mux.sshClient.isPortForwardLimitExceeded(portForwardTypeUDP); exceeded {
-
-				// Close the oldest UDP port forward. CloseOldest() closes
-				// the conn and the port forward's goroutine will complete
-				// the cleanup asynchronously.
-				mux.portForwardLRU.CloseOldest()
-
-				log.WithContext().Debug("closed LRU UDP port forward")
-			}
-
 			log.WithContextFields(
 				LogFields{
 					"remoteAddr": fmt.Sprintf("%s:%d", dialIP.String(), dialPort),

+ 16 - 0
psiphon/server/utils.go

@@ -29,6 +29,7 @@ import (
 	"fmt"
 	"io"
 	"math/big"
+	"sync/atomic"
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
@@ -174,3 +175,18 @@ func (w *PanickingLogWriter) Write(p []byte) (n int, err error) {
 	}
 	return
 }
+
+func min(a, b int) int {
+	if a < b {
+		return a
+	}
+	return b
+}
+
+func greaterThanSwapInt64(addr *int64, new int64) bool {
+	old := atomic.LoadInt64(addr)
+	if new > old {
+		return atomic.CompareAndSwapInt64(addr, old, new)
+	}
+	return false
+}

+ 1 - 1
psiphon/server/webServer.go

@@ -90,7 +90,7 @@ func RunWebServer(
 	// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts
 
 	server := &HTTPSServer{
-		http.Server{
+		&http.Server{
 			MaxHeaderBytes: MAX_API_PARAMS_SIZE,
 			Handler:        serveMux,
 			ReadTimeout:    WEB_SERVER_IO_TIMEOUT,

+ 19 - 12
psiphon/tunnel.go

@@ -570,6 +570,7 @@ func initMeekConfig(
 		TransformedHostName:           transformedHostName,
 		PsiphonServerAddress:          psiphonServerAddress,
 		SessionID:                     sessionId,
+		ClientTunnelProtocol:          selectedProtocol,
 		MeekCookieEncryptionPublicKey: serverEntry.MeekCookieEncryptionPublicKey,
 		MeekObfuscatedKey:             serverEntry.MeekObfuscatedKey,
 	}, nil
@@ -642,7 +643,7 @@ func dialSsh(
 	}
 
 	if meekConfig != nil || upstreamProxyType == "http" {
-		dialCustomHeaders, selectedUserAgent = UserAgentIfUnset(config.CustomHeaders)
+		dialCustomHeaders, selectedUserAgent = UserAgentIfUnset(dialCustomHeaders)
 	}
 
 	// Use an asynchronous callback to record the resolved IP address when
@@ -726,22 +727,21 @@ func dialSsh(
 
 		// For some direct connect servers, DialPluginProtocol
 		// will layer on another obfuscation protocol.
+
+		// Use a copy of DialConfig without pendingConns; the
+		// DialPluginProtocol must supply and manage its own
+		// for its base network connections.
+		pluginDialConfig := new(DialConfig)
+		*pluginDialConfig = *dialConfig
+		pluginDialConfig.PendingConns = nil
+
 		var dialedPlugin bool
 		dialedPlugin, dialConn, err = DialPluginProtocol(
 			config,
 			NewNoticeWriter("DialPluginProtocol"),
 			pendingConns,
-			func(_, addr string) (net.Conn, error) {
-
-				// Use a copy of DialConfig without pendingConns
-				// TODO: distinct pendingConns for plugins?
-				pluginDialConfig := new(DialConfig)
-				*pluginDialConfig = *dialConfig
-				pluginDialConfig.PendingConns = nil
-
-				return DialTCP(addr, pluginDialConfig)
-			},
-			directTCPDialAddress)
+			directTCPDialAddress,
+			dialConfig)
 
 		if !dialedPlugin && err != nil {
 			NoticeInfo("DialPluginProtocol intialization failed: %s", err)
@@ -1206,6 +1206,13 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 
 		tunnelDuration := tunnel.conn.GetLastActivityMonotime().Sub(tunnel.establishedTime)
 
+		// tunnelDuration can be < 0 when tunnel.establishedTime is recorded after the
+		// last tunnel.conn.Read() succeeds. In that case, the last read would be the
+		// handshake response, so the tunnel had, essentially, no duration.
+		if tunnelDuration < 0 {
+			tunnelDuration = 0
+		}
+
 		err := RecordTunnelStat(
 			tunnel.serverContext.sessionId,
 			tunnel.serverContext.tunnelNumber,