Browse Source

Fix: revised fix for 3b29161

* In some cases interruptibleTCPClose was closing random files/sockets.
* Change in 3b29161 was insufficient -- fd could still be closed twice
  E.g., deferred Close(socketFd) would happen when syscall.Connect failed.
* Rewrote to ensure fd only accessed through conn.interruptible.socketFd
  variable, with mutex, and clearing to _INVALID_FD on each close case.
Rod Hynes 10 years ago
parent
commit
a01e1350d0
1 changed files with 49 additions and 50 deletions
  1. 49 50
      psiphon/TCPConn_unix.go

+ 49 - 50
psiphon/TCPConn_unix.go

@@ -44,29 +44,7 @@ const _INVALID_FD = -1
 // To implement socket device binding and interruptible connecting, the lower-level
 // syscall APIs are used. The sequence of syscalls in this implementation are
 // taken from: https://code.google.com/p/go/issues/detail?id=6966
-func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err error) {
-
-	// Create a socket and then, before connecting, add a TCPConn with
-	// the unconnected socket to pendingConns. This allows pendingConns to
-	// abort connections in progress.
-	socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
-	if err != nil {
-		return nil, ContextError(err)
-	}
-	defer func() {
-		// Cleanup on error
-		// (socketFd is reset to _INVALID_FD once it should no longer be closed)
-		if err != nil && socketFd != _INVALID_FD {
-			syscall.Close(socketFd)
-		}
-	}()
-
-	if config.DeviceBinder != nil {
-		err = config.DeviceBinder.BindToDevice(socketFd)
-		if err != nil {
-			return nil, ContextError(fmt.Errorf("BindToDevice failed: %s", err))
-		}
-	}
+func interruptibleTCPDial(addr string, config *DialConfig) (*TCPConn, error) {
 
 	// Get the remote IP and port, resolving a domain name if necessary
 	// TODO: domain name resolution isn't interruptible
@@ -100,13 +78,44 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 	var ip [4]byte
 	copy(ip[:], ipAddrs[index].To4())
 
-	// Enable interruption
-	conn = &TCPConn{interruptible: interruptibleTCPSocket{socketFd: socketFd}}
+	// Create a socket and then, before connecting, add a TCPConn with
+	// the unconnected socket to pendingConns. This allows pendingConns to
+	// interrupt/abort connections in progress.
+	socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
+	if err != nil {
+		return nil, ContextError(err)
+	}
 
+	conn := &TCPConn{interruptible: interruptibleTCPSocket{socketFd: socketFd}}
+
+	// Cleanup on error
+	defer func() {
+		// Mutex required since conn may be in pendingConns, through which
+		// conn.Close() may be called from another goroutine. There are two
+		// risks:
+		// 1. standard race conditions reading/writing conn members.
+		// 2. closing the fd more than once, with the chance that other
+		//    concurrent goroutines or threads may have already reused the fd.
+		conn.mutex.Lock()
+		if err != nil && conn.interruptible.socketFd != _INVALID_FD {
+			syscall.Close(conn.interruptible.socketFd)
+			conn.interruptible.socketFd = _INVALID_FD
+		}
+		conn.mutex.Unlock()
+	}()
+
+	// Enable interruption
 	if !config.PendingConns.Add(conn) {
 		return nil, ContextError(errors.New("pending connections already closed"))
 	}
 
+	if config.DeviceBinder != nil {
+		err = config.DeviceBinder.BindToDevice(conn.interruptible.socketFd)
+		if err != nil {
+			return nil, ContextError(fmt.Errorf("BindToDevice failed: %s", err))
+		}
+	}
+
 	// Connect the socket
 	// TODO: adjust the timeout to account for time spent resolving hostname
 	sockAddr := syscall.SockaddrInet4{Addr: ip, Port: port}
@@ -116,52 +125,42 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 			errChannel <- errors.New("connect timeout")
 		})
 		go func() {
-			errChannel <- syscall.Connect(socketFd, &sockAddr)
+			errChannel <- syscall.Connect(conn.interruptible.socketFd, &sockAddr)
 		}()
 		err = <-errChannel
 	} else {
-		err = syscall.Connect(socketFd, &sockAddr)
+		err = syscall.Connect(conn.interruptible.socketFd, &sockAddr)
 	}
-
-	// Mutex required for writing to conn, since conn remains in
-	// pendingConns, through which conn.Close() may be called from
-	// another goroutine.
-
-	conn.mutex.Lock()
-
-	// From this point, ensure conn.interruptible.socketFd is reset
-	// since the fd value may be reused for a different file or socket
-	// before Close() -- and interruptibleTCPClose() -- is called for
-	// this conn.
-	conn.interruptible.socketFd = _INVALID_FD // (requires mutex)
-
-	// This is the syscall.Connect result
 	if err != nil {
-		conn.mutex.Unlock()
 		return nil, ContextError(err)
 	}
 
 	// Convert the socket fd to a net.Conn
+	// See mutex note above.
+	conn.mutex.Lock()
 
-	file := os.NewFile(uintptr(socketFd), "")
-	fileConn, err := net.FileConn(file)
-	file.Close()
-	// No more deferred fd clean up on err
-	socketFd = _INVALID_FD
+	file := os.NewFile(uintptr(conn.interruptible.socketFd), "")
+	fileConn, err := net.FileConn(file) // net.FileConn() dups the fd
+	file.Close()                        // file.Close() closes the fd
+	conn.interruptible.socketFd = _INVALID_FD
 	if err != nil {
 		conn.mutex.Unlock()
 		return nil, ContextError(err)
 	}
-	conn.Conn = fileConn // (requires mutex)
-
+	conn.Conn = fileConn
 	conn.mutex.Unlock()
 
 	return conn, nil
 }
 
 func interruptibleTCPClose(interruptible interruptibleTCPSocket) error {
+
+	// Assumes conn.mutex is held
+
 	if interruptible.socketFd == _INVALID_FD {
 		return nil
 	}
-	return syscall.Close(interruptible.socketFd)
+	err := syscall.Close(interruptible.socketFd)
+	interruptible.socketFd = _INVALID_FD
+	return err
 }