Sfoglia il codice sorgente

Add mutex around transform from socket fd to Conn to prevent concurrent Close()

Rod Hynes 11 anni fa
parent
commit
594546af53
2 ha cambiato i file con 19 aggiunte e 11 eliminazioni
  1. 13 5
      psiphon/TCPConn_unix.go
  2. 6 6
      psiphon/tunnel.go

+ 13 - 5
psiphon/TCPConn_unix.go

@@ -116,18 +116,26 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 		return nil, ContextError(err)
 	}
 
+	// Mutex required for:
+	// 1. preventing concurrent interruptibleTCPClose (via conn.Close())
+	//    while performing os.NewFile/net.FileConn transformation
+	// 2. writing conn.Conn, since conn remains in pendingConns, from
+	//    where conn.Close() may be called in another goroutine
+
+	conn.mutex.Lock()
+
 	// Convert the syscall socket to a net.Conn
 	file := os.NewFile(uintptr(conn.interruptible.socketFd), "")
-	defer file.Close()
 	fileConn, err := net.FileConn(file)
+	file.Close()
 	if err != nil {
+		// TODO: syscall.Close(conn.interruptible.socketFd)?
+		conn.mutex.Unlock()
 		return nil, ContextError(err)
 	}
-
-	// Need mutex to write conn.Conn since conn remains in pendingConns, from
-	// where conn.Close() may be called in another goroutine
-	conn.mutex.Lock()
+	conn.interruptible.socketFd = -1
 	conn.Conn = fileConn
+
 	conn.mutex.Unlock()
 
 	// Going through upstream HTTP proxy

+ 6 - 6
psiphon/tunnel.go

@@ -336,12 +336,6 @@ func dialSsh(
 func (tunnel *Tunnel) operateTunnel(config *Config, tunnelOwner TunnelOwner) {
 	defer tunnel.operateWaitGroup.Done()
 
-	tunnelClosedSignal := make(chan struct{}, 1)
-	err := tunnel.conn.SetClosedSignal(tunnelClosedSignal)
-	if err != nil {
-		err = fmt.Errorf("failed to set closed signal: %s", err)
-	}
-
 	// Note: not using a Ticker since NextSendPeriod() is not a fixed time period
 	statsTimer := time.NewTimer(NextSendPeriod())
 	defer statsTimer.Stop()
@@ -349,6 +343,12 @@ func (tunnel *Tunnel) operateTunnel(config *Config, tunnelOwner TunnelOwner) {
 	sshKeepAliveTicker := time.NewTicker(TUNNEL_SSH_KEEP_ALIVE_PERIOD)
 	defer sshKeepAliveTicker.Stop()
 
+	tunnelClosedSignal := make(chan struct{}, 1)
+	err := tunnel.conn.SetClosedSignal(tunnelClosedSignal)
+	if err != nil {
+		err = fmt.Errorf("failed to set closed signal: %s", err)
+	}
+
 	for err == nil {
 		select {
 		case <-statsTimer.C: