Răsfoiți Sursa

Fix: psiphond shutdown delay due to incomplete TLS handshake

- Update vendored github.com/Psiphon-Labs/tls-tris

- Add automated test

github.com/Psiphon-Labs/tls-tris commit message:

Apply upstream deadlock fix

- https://github.com/golang/go/commit/e5b13401c6b19f58a8439f1019a80fe540c0c687

- Due to the handshake/Close mutex, psiphond shutdown would get delayed
  when shutdown is initiated while a client has established a TCP connection
  but not yet competed the TLS handshake.

- Note that the upstream changes to readRecord are not applied to tls-tris
  since tls-tris has deviated from upstream and was no longer checking
  c.handshakeComplete in readRecord.
Rod Hynes 6 ani în urmă
părinte
comite
656d5bf52d

+ 37 - 1
psiphon/server/server_test.go

@@ -128,6 +128,7 @@ func TestSSH(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    true,
 		})
 }
 
@@ -146,6 +147,7 @@ func TestOSSH(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    true,
 		})
 }
 
@@ -164,6 +166,7 @@ func TestFragmentedOSSH(t *testing.T) {
 			forceFragmenting:     true,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    true,
 		})
 }
 
@@ -182,6 +185,7 @@ func TestUnfrontedMeek(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    true,
 		})
 }
 
@@ -201,6 +205,7 @@ func TestUnfrontedMeekHTTPS(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    true,
 		})
 }
 
@@ -220,6 +225,7 @@ func TestUnfrontedMeekHTTPSTLS13(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    true,
 		})
 }
 
@@ -239,6 +245,7 @@ func TestUnfrontedMeekSessionTicket(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    true,
 		})
 }
 
@@ -258,6 +265,7 @@ func TestUnfrontedMeekSessionTicketTLS13(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    true,
 		})
 }
 
@@ -276,6 +284,7 @@ func TestQUICOSSH(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -297,6 +306,7 @@ func TestMarionetteOSSH(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -315,6 +325,7 @@ func TestWebTransportAPIRequests(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -333,6 +344,7 @@ func TestHotReload(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -351,6 +363,7 @@ func TestDefaultSponsorID(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -369,6 +382,7 @@ func TestDenyTrafficRules(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -387,6 +401,7 @@ func TestOmitAuthorization(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -405,6 +420,7 @@ func TestNoAuthorization(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -423,6 +439,7 @@ func TestUnusedAuthorization(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -441,6 +458,7 @@ func TestTCPOnlySLOK(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -459,6 +477,7 @@ func TestUDPOnlySLOK(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    false,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -477,6 +496,7 @@ func TestLivenessTest(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    true,
 			doPruneServerEntries: false,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -495,6 +515,7 @@ func TestPruneServerEntries(t *testing.T) {
 			forceFragmenting:     false,
 			forceLivenessTest:    true,
 			doPruneServerEntries: true,
+			doDanglingTCPConn:    false,
 		})
 }
 
@@ -512,6 +533,7 @@ type runServerConfig struct {
 	forceFragmenting     bool
 	forceLivenessTest    bool
 	doPruneServerEntries bool
+	doDanglingTCPConn    bool
 }
 
 var (
@@ -573,12 +595,13 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		// Workaround for macOS firewall.
 		psiphonServerIPAddress = "127.0.0.1"
 	}
+	psiphonServerPort := 4000
 
 	generateConfigParams := &GenerateConfigParams{
 		ServerIPAddress:      psiphonServerIPAddress,
 		EnableSSHAPIRequests: runConfig.enableSSHAPIRequests,
 		WebServerPort:        8000,
-		TunnelProtocolPorts:  map[string]int{runConfig.tunnelProtocol: 4000},
+		TunnelProtocolPorts:  map[string]int{runConfig.tunnelProtocol: psiphonServerPort},
 	}
 
 	if protocol.TunnelProtocolUsesMarionette(runConfig.tunnelProtocol) {
@@ -1062,6 +1085,19 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			"prune server entries timeout exceeded")
 	}
 
+	if runConfig.doDanglingTCPConn {
+
+		// Test: client that has established TCP connection but not completed
+		// any handshakes must not block/delay server shutdown
+
+		danglingConn, err := net.Dial(
+			"tcp", net.JoinHostPort(psiphonServerIPAddress, strconv.Itoa(psiphonServerPort)))
+		if err != nil {
+			t.Fatalf("TCP dial failed: %s", err)
+		}
+		defer danglingConn.Close()
+	}
+
 	// Shutdown to ensure logs/notices are flushed
 
 	stopClient()

+ 22 - 17
vendor/github.com/Psiphon-Labs/tls-tris/conn.go

@@ -34,6 +34,14 @@ type Conn struct {
 	// confirmMutex is held by any read operation before handshakeConfirmed
 	confirmMutex sync.Mutex
 
+	// [Psiphon]
+	// https://github.com/golang/go/commit/e5b13401c6b19f58a8439f1019a80fe540c0c687
+	//
+	// handshakeStatus is 1 if the connection is currently transferring
+	// application data (i.e. is not currently processing a handshake).
+	// This field is only to be accessed with sync/atomic.
+	handshakeStatus uint32
+
 	// constant after handshake; protected by handshakeMutex
 	handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
 	handshakeErr   error      // error resulting from handshake
@@ -42,9 +50,6 @@ type Conn struct {
 	vers           uint16     // TLS version
 	haveVers       bool       // version has been negotiated
 	config         *Config    // configuration passed to constructor
-	// handshakeComplete is true if the connection reached application data
-	// and it's equivalent to phase > handshakeRunning
-	handshakeComplete bool
 	// handshakes counts the number of handshakes performed on the
 	// connection so far. If renegotiation is disabled then this is either
 	// zero or one.
@@ -1241,7 +1246,7 @@ func (c *Conn) Write(b []byte) (int, error) {
 		return 0, err
 	}
 
-	if !c.handshakeComplete {
+	if !c.handshakeComplete() {
 		return 0, alertInternalError
 	}
 
@@ -1325,7 +1330,7 @@ func (c *Conn) handleRenegotiation(*helloRequestMsg) error {
 	defer c.handshakeMutex.Unlock()
 
 	c.phase = handshakeRunning
-	c.handshakeComplete = false
+	atomic.StoreUint32(&c.handshakeStatus, 0)
 	if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
 		c.handshakes++
 	}
@@ -1573,11 +1578,9 @@ func (c *Conn) Close() error {
 
 	var alertErr error
 
-	c.handshakeMutex.Lock()
-	if c.handshakeComplete {
+	if c.handshakeComplete() {
 		alertErr = c.closeNotify()
 	}
-	c.handshakeMutex.Unlock()
 
 	if err := c.conn.Close(); err != nil {
 		return err
@@ -1591,9 +1594,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
 // called once the handshake has completed and does not call CloseWrite on the
 // underlying connection. Most callers should just use Close.
 func (c *Conn) CloseWrite() error {
-	c.handshakeMutex.Lock()
-	defer c.handshakeMutex.Unlock()
-	if !c.handshakeComplete {
+	if !c.handshakeComplete() {
 		return errEarlyCloseWrite
 	}
 
@@ -1625,7 +1626,7 @@ func (c *Conn) Handshake() error {
 	if err := c.handshakeErr; err != nil {
 		return err
 	}
-	if c.handshakeComplete {
+	if c.handshakeComplete() {
 		return nil
 	}
 
@@ -1634,7 +1635,7 @@ func (c *Conn) Handshake() error {
 
 	// The handshake cannot have completed when handshakeMutex was unlocked
 	// because this goroutine set handshakeCond.
-	if c.handshakeErr != nil || c.handshakeComplete {
+	if c.handshakeErr != nil || c.handshakeComplete() {
 		panic("handshake should not have been able to complete after handshakeCond was set")
 	}
 
@@ -1656,7 +1657,7 @@ func (c *Conn) Handshake() error {
 		c.flush()
 	}
 
-	if c.handshakeErr == nil && !c.handshakeComplete {
+	if c.handshakeErr == nil && !c.handshakeComplete() {
 		panic("handshake should have had a result.")
 	}
 
@@ -1669,10 +1670,10 @@ func (c *Conn) ConnectionState() ConnectionState {
 	defer c.handshakeMutex.Unlock()
 
 	var state ConnectionState
-	state.HandshakeComplete = c.handshakeComplete
+	state.HandshakeComplete = c.handshakeComplete()
 	state.ServerName = c.serverName
 
-	if c.handshakeComplete {
+	if state.HandshakeComplete {
 		state.ConnectionID = c.connID
 		state.ClientHello = c.clientHello
 		state.Version = c.vers
@@ -1721,7 +1722,7 @@ func (c *Conn) VerifyHostname(host string) error {
 	if !c.isClient {
 		return errors.New("tls: VerifyHostname called on TLS server connection")
 	}
-	if !c.handshakeComplete {
+	if !c.handshakeComplete() {
 		return errors.New("tls: handshake has not yet been performed")
 	}
 	if len(c.verifiedChains) == 0 {
@@ -1729,3 +1730,7 @@ func (c *Conn) VerifyHostname(host string) error {
 	}
 	return c.peerCertificates[0].VerifyHostname(host)
 }
+
+func (c *Conn) handshakeComplete() bool {
+	return atomic.LoadUint32(&c.handshakeStatus) == 1
+}

+ 4 - 1
vendor/github.com/Psiphon-Labs/tls-tris/handshake_client.go

@@ -331,7 +331,10 @@ func (hs *clientHandshakeState) handshake() error {
 	c.didResume = isResume
 	c.phase = handshakeConfirmed
 	atomic.StoreInt32(&c.handshakeConfirmed, 1)
-	c.handshakeComplete = true
+
+	// [Psiphon]
+	// https://github.com/golang/go/commit/e5b13401c6b19f58a8439f1019a80fe540c0c687
+	atomic.StoreUint32(&c.handshakeStatus, 1)
 
 	return nil
 }

+ 5 - 2
vendor/github.com/Psiphon-Labs/tls-tris/handshake_server.go

@@ -87,7 +87,7 @@ func (c *Conn) serverHandshake() error {
 				return err
 			}
 		}
-		c.handshakeComplete = true
+		atomic.StoreUint32(&c.handshakeStatus, 1)
 		return nil
 	} else if isResume {
 		// The client has included a session ticket and so we do an abbreviated handshake.
@@ -145,7 +145,10 @@ func (c *Conn) serverHandshake() error {
 	}
 	c.phase = handshakeConfirmed
 	atomic.StoreInt32(&c.handshakeConfirmed, 1)
-	c.handshakeComplete = true
+
+	// [Psiphon]
+	// https://github.com/golang/go/commit/e5b13401c6b19f58a8439f1019a80fe540c0c687
+	atomic.StoreUint32(&c.handshakeStatus, 1)
 
 	return nil
 }

+ 3 - 3
vendor/vendor.json

@@ -63,10 +63,10 @@
 			"revisionTime": "2019-12-04T18:36:04Z"
 		},
 		{
-			"checksumSHA1": "DmQc8vfP44VMUTIZJGM4WslOyk0=",
+			"checksumSHA1": "GkQMbmKt3ls0vmJ33GdcxOqikXA=",
 			"path": "github.com/Psiphon-Labs/tls-tris",
-			"revision": "b5083341bf6cb581f3319c6dfbb39dd6ae3a97ea",
-			"revisionTime": "2019-03-21T17:45:24Z"
+			"revision": "e98b032bc3ced03cc324827b86c8bb3802401d3d",
+			"revisionTime": "2019-12-05T15:29:33Z"
 		},
 		{
 			"checksumSHA1": "30PBqj9BW03KCVqASvLg3bR+xYc=",