Kaynağa Gözat

Fix inproxy.reliableConn shutdown deadlock

Rod Hynes 10 ay önce
ebeveyn
işleme
cd7d523a45
1 değiştirilmiş dosya ile 60 ekleme ve 49 silme
  1. 60 49
      psiphon/common/inproxy/webrtc.go

+ 60 - 49
psiphon/common/inproxy/webrtc.go

@@ -98,7 +98,7 @@ type webRTCConn struct {
 	mutex                         sync.Mutex
 	udpConn                       net.PacketConn
 	portMapper                    *portMapper
-	isClosed                      bool
+	isClosed                      int32
 	closedSignal                  chan struct{}
 	readyToProxySignal            chan struct{}
 	readyToProxyOnce              sync.Once
@@ -1336,34 +1336,52 @@ func (conn *webRTCConn) recordSelectedICECandidateStats() error {
 }
 
 func (conn *webRTCConn) Close() error {
-	conn.mutex.Lock()
-	defer conn.mutex.Unlock()
 
-	if conn.isClosed {
+	if !atomic.CompareAndSwapInt32(&conn.isClosed, 0, 1) {
 		return nil
 	}
 
-	if conn.portMapper != nil {
-		conn.portMapper.close()
+	// Synchronize reading these conn fields, which may be initialized by
+	// concurrent callbacks such as onDataChannel and onMediaTrack.
+	//
+	// To avoid potential deadlocks, don't continue to hold the lock while
+	// closing individual components. For example, internally, the quic-go
+	// implementation underlying reliableConn can concurrently call through
+	// to writeMediaTrackPacket, which attempts to temporarily lock
+	// conn.mutex, while reliableConn's quicConn.Close will wait on that
+	// write operation.
+
+	conn.mutex.Lock()
+	portMapper := conn.portMapper
+	sendMediaTrackRTP := conn.sendMediaTrackRTP
+	mediaTrackReliabilityLayer := conn.mediaTrackReliabilityLayer
+	dataChannelConn := conn.dataChannelConn
+	dataChannel := conn.dataChannel
+	peerConnection := conn.peerConnection
+	udpConn := conn.udpConn
+	conn.mutex.Unlock()
+
+	if portMapper != nil {
+		portMapper.close()
 	}
 
 	// Neither sendMediaTrack nor receiveMediaTrack have a Close operation.
 
-	if conn.sendMediaTrackRTP != nil {
-		_ = conn.sendMediaTrackRTP.Stop()
+	if sendMediaTrackRTP != nil {
+		_ = sendMediaTrackRTP.Stop()
 	}
-	if conn.mediaTrackReliabilityLayer != nil {
-		_ = conn.mediaTrackReliabilityLayer.Close()
+	if mediaTrackReliabilityLayer != nil {
+		_ = mediaTrackReliabilityLayer.Close()
 	}
-	if conn.dataChannelConn != nil {
-		_ = conn.dataChannelConn.Close()
+	if dataChannelConn != nil {
+		_ = dataChannelConn.Close()
 	}
-	if conn.dataChannel != nil {
-		_ = conn.dataChannel.Close()
+	if dataChannel != nil {
+		_ = dataChannel.Close()
 	}
-	if conn.peerConnection != nil {
+	if peerConnection != nil {
 		// TODO: use PeerConnection.GracefulClose (requires pion/webrtc 3.2.51)?
-		_ = conn.peerConnection.Close()
+		_ = peerConnection.Close()
 	}
 
 	// Close the udpConn to interrupt any blocking DTLS handshake:
@@ -1372,22 +1390,17 @@ func (conn *webRTCConn) Close() error {
 	// Limitation: there is no guarantee that pion sends any closing packets
 	// before the UDP socket is closed here.
 
-	if conn.udpConn != nil {
-		_ = conn.udpConn.Close()
+	if udpConn != nil {
+		_ = udpConn.Close()
 	}
 
 	close(conn.closedSignal)
 
-	conn.isClosed = true
-
 	return nil
 }
 
 func (conn *webRTCConn) IsClosed() bool {
-	conn.mutex.Lock()
-	defer conn.mutex.Unlock()
-
-	return conn.isClosed
+	return atomic.LoadInt32(&conn.isClosed) == 1
 }
 
 func (conn *webRTCConn) Read(p []byte) (int, error) {
@@ -1455,7 +1468,7 @@ func (conn *webRTCConn) SetReadDeadline(t time.Time) error {
 	conn.mutex.Lock()
 	defer conn.mutex.Unlock()
 
-	if conn.isClosed {
+	if conn.IsClosed() {
 		return errors.TraceNew("closed")
 	}
 
@@ -1679,16 +1692,15 @@ func (conn *webRTCConn) readDataChannel(p []byte) (int, error) {
 
 func (conn *webRTCConn) readDataChannelMessage(p []byte) (int, error) {
 
+	if conn.IsClosed() {
+		return 0, errors.TraceNew("closed")
+	}
+
 	// Don't hold this lock, or else concurrent Writes will be blocked.
 	conn.mutex.Lock()
-	isClosed := conn.isClosed
 	dataChannelConn := conn.dataChannelConn
 	conn.mutex.Unlock()
 
-	if isClosed {
-		return 0, errors.TraceNew("closed")
-	}
-
 	if dataChannelConn == nil {
 		return 0, errors.TraceNew("no data channel")
 	}
@@ -1794,21 +1806,22 @@ func (conn *webRTCConn) writeDataChannelMessage(p []byte, decoy bool) (int, erro
 		return 0, nil
 	}
 
+	if conn.IsClosed() {
+		return 0, errors.TraceNew("closed")
+	}
+
 	// Don't hold this lock, or else concurrent Reads will be blocked.
 	conn.mutex.Lock()
-	isClosed := conn.isClosed
-	bufferedAmount := conn.dataChannel.BufferedAmount()
+	dataChannel := conn.dataChannel
 	dataChannelConn := conn.dataChannelConn
 	conn.mutex.Unlock()
 
-	if isClosed {
-		return 0, errors.TraceNew("closed")
-	}
-
-	if dataChannelConn == nil {
+	if dataChannel == nil || dataChannelConn == nil {
 		return 0, errors.TraceNew("no data channel")
 	}
 
+	bufferedAmount := dataChannel.BufferedAmount()
+
 	// Only proceed with a decoy message when no pending writes are buffered.
 	//
 	// This check is made before acquiring conn.writeMutex so that, in most
@@ -1939,7 +1952,7 @@ func (conn *webRTCConn) writeDataChannelMessage(p []byte, decoy bool) (int, erro
 
 	// If the pion write buffer is too full, wait for a signal that sufficient
 	// write data has been consumed before writing more.
-	if !isClosed && bufferedAmount+uint64(writeSize) > dataChannelMaxBufferedAmount {
+	if !conn.IsClosed() && bufferedAmount+uint64(writeSize) > dataChannelMaxBufferedAmount {
 		select {
 		case <-conn.dataChannelWriteBufferSignal:
 		case <-conn.closedSignal:
@@ -2056,16 +2069,15 @@ func (conn *webRTCConn) readMediaTrackPacket(p []byte) (int, error) {
 		return 0, errors.TraceNew("closed")
 	}
 
+	if conn.IsClosed() {
+		return 0, errors.TraceNew("closed")
+	}
+
 	// Don't hold this lock, or else concurrent Writes will be blocked.
 	conn.mutex.Lock()
-	isClosed := conn.isClosed
 	receiveMediaTrack := conn.receiveMediaTrack
 	conn.mutex.Unlock()
 
-	if isClosed {
-		return 0, errors.TraceNew("closed")
-	}
-
 	if receiveMediaTrack == nil {
 		return 0, errors.TraceNew("no media track")
 	}
@@ -2139,16 +2151,15 @@ func (conn *webRTCConn) writeMediaTrackPacket(p []byte, decoy bool) (int, error)
 		return 0, errors.TraceNew("invalid write parameters")
 	}
 
+	if conn.IsClosed() {
+		return 0, errors.TraceNew("closed")
+	}
+
 	// Don't hold this lock, or else concurrent Writes will be blocked.
 	conn.mutex.Lock()
-	isClosed := conn.isClosed
 	sendMediaTrack := conn.sendMediaTrack
 	conn.mutex.Unlock()
 
-	if isClosed {
-		return 0, errors.TraceNew("closed")
-	}
-
 	if sendMediaTrack == nil {
 		return 0, errors.TraceNew("no media track")
 	}
@@ -2666,7 +2677,7 @@ func (conn *reliableConn) Close() error {
 
 	// Close mediaTrackConn first, or else quic-go's Close will attempt to
 	// Write, which leads to deadlock between webRTCConn.writeMediaTrack and
-	// webRTCConn.Close. The graceful QUIC close write will fails, but that's
+	// webRTCConn.Close. The graceful QUIC close write will fail, but that's
 	// not an issue.
 
 	_ = conn.mediaTrackConn.Close()