|
@@ -44,29 +44,7 @@ const _INVALID_FD = -1
|
|
|
// To implement socket device binding and interruptible connecting, the lower-level
|
|
// To implement socket device binding and interruptible connecting, the lower-level
|
|
|
// syscall APIs are used. The sequence of syscalls in this implementation are
|
|
// syscall APIs are used. The sequence of syscalls in this implementation are
|
|
|
// taken from: https://code.google.com/p/go/issues/detail?id=6966
|
|
// taken from: https://code.google.com/p/go/issues/detail?id=6966
|
|
|
-func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err error) {
|
|
|
|
|
-
|
|
|
|
|
- // Create a socket and then, before connecting, add a TCPConn with
|
|
|
|
|
- // the unconnected socket to pendingConns. This allows pendingConns to
|
|
|
|
|
- // abort connections in progress.
|
|
|
|
|
- socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return nil, ContextError(err)
|
|
|
|
|
- }
|
|
|
|
|
- defer func() {
|
|
|
|
|
- // Cleanup on error
|
|
|
|
|
- // (socketFd is reset to _INVALID_FD once it should no longer be closed)
|
|
|
|
|
- if err != nil && socketFd != _INVALID_FD {
|
|
|
|
|
- syscall.Close(socketFd)
|
|
|
|
|
- }
|
|
|
|
|
- }()
|
|
|
|
|
-
|
|
|
|
|
- if config.DeviceBinder != nil {
|
|
|
|
|
- err = config.DeviceBinder.BindToDevice(socketFd)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return nil, ContextError(fmt.Errorf("BindToDevice failed: %s", err))
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+func interruptibleTCPDial(addr string, config *DialConfig) (*TCPConn, error) {
|
|
|
|
|
|
|
|
// Get the remote IP and port, resolving a domain name if necessary
|
|
// Get the remote IP and port, resolving a domain name if necessary
|
|
|
// TODO: domain name resolution isn't interruptible
|
|
// TODO: domain name resolution isn't interruptible
|
|
@@ -100,13 +78,44 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
|
|
|
var ip [4]byte
|
|
var ip [4]byte
|
|
|
copy(ip[:], ipAddrs[index].To4())
|
|
copy(ip[:], ipAddrs[index].To4())
|
|
|
|
|
|
|
|
- // Enable interruption
|
|
|
|
|
- conn = &TCPConn{interruptible: interruptibleTCPSocket{socketFd: socketFd}}
|
|
|
|
|
|
|
+ // Create a socket and then, before connecting, add a TCPConn with
|
|
|
|
|
+ // the unconnected socket to pendingConns. This allows pendingConns to
|
|
|
|
|
+ // interrupt/abort connections in progress.
|
|
|
|
|
+ socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, ContextError(err)
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
|
|
+ conn := &TCPConn{interruptible: interruptibleTCPSocket{socketFd: socketFd}}
|
|
|
|
|
+
|
|
|
|
|
+ // Cleanup on error
|
|
|
|
|
+ defer func() {
|
|
|
|
|
+ // Mutex required since conn may be in pendingConns, through which
|
|
|
|
|
+ // conn.Close() may be called from another goroutine. There are two
|
|
|
|
|
+ // risks:
|
|
|
|
|
+ // 1. standard race conditions reading/writing conn members.
|
|
|
|
|
+ // 2. closing the fd more than once, with the chance that other
|
|
|
|
|
+ // concurrent goroutines or threads may have already reused the fd.
|
|
|
|
|
+ conn.mutex.Lock()
|
|
|
|
|
+ if err != nil && conn.interruptible.socketFd != _INVALID_FD {
|
|
|
|
|
+ syscall.Close(conn.interruptible.socketFd)
|
|
|
|
|
+ conn.interruptible.socketFd = _INVALID_FD
|
|
|
|
|
+ }
|
|
|
|
|
+ conn.mutex.Unlock()
|
|
|
|
|
+ }()
|
|
|
|
|
+
|
|
|
|
|
+ // Enable interruption
|
|
|
if !config.PendingConns.Add(conn) {
|
|
if !config.PendingConns.Add(conn) {
|
|
|
return nil, ContextError(errors.New("pending connections already closed"))
|
|
return nil, ContextError(errors.New("pending connections already closed"))
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if config.DeviceBinder != nil {
|
|
|
|
|
+ err = config.DeviceBinder.BindToDevice(conn.interruptible.socketFd)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, ContextError(fmt.Errorf("BindToDevice failed: %s", err))
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
// Connect the socket
|
|
// Connect the socket
|
|
|
// TODO: adjust the timeout to account for time spent resolving hostname
|
|
// TODO: adjust the timeout to account for time spent resolving hostname
|
|
|
sockAddr := syscall.SockaddrInet4{Addr: ip, Port: port}
|
|
sockAddr := syscall.SockaddrInet4{Addr: ip, Port: port}
|
|
@@ -116,52 +125,42 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
|
|
|
errChannel <- errors.New("connect timeout")
|
|
errChannel <- errors.New("connect timeout")
|
|
|
})
|
|
})
|
|
|
go func() {
|
|
go func() {
|
|
|
- errChannel <- syscall.Connect(socketFd, &sockAddr)
|
|
|
|
|
|
|
+ errChannel <- syscall.Connect(conn.interruptible.socketFd, &sockAddr)
|
|
|
}()
|
|
}()
|
|
|
err = <-errChannel
|
|
err = <-errChannel
|
|
|
} else {
|
|
} else {
|
|
|
- err = syscall.Connect(socketFd, &sockAddr)
|
|
|
|
|
|
|
+ err = syscall.Connect(conn.interruptible.socketFd, &sockAddr)
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
- // Mutex required for writing to conn, since conn remains in
|
|
|
|
|
- // pendingConns, through which conn.Close() may be called from
|
|
|
|
|
- // another goroutine.
|
|
|
|
|
-
|
|
|
|
|
- conn.mutex.Lock()
|
|
|
|
|
-
|
|
|
|
|
- // From this point, ensure conn.interruptible.socketFd is reset
|
|
|
|
|
- // since the fd value may be reused for a different file or socket
|
|
|
|
|
- // before Close() -- and interruptibleTCPClose() -- is called for
|
|
|
|
|
- // this conn.
|
|
|
|
|
- conn.interruptible.socketFd = _INVALID_FD // (requires mutex)
|
|
|
|
|
-
|
|
|
|
|
- // This is the syscall.Connect result
|
|
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- conn.mutex.Unlock()
|
|
|
|
|
return nil, ContextError(err)
|
|
return nil, ContextError(err)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Convert the socket fd to a net.Conn
|
|
// Convert the socket fd to a net.Conn
|
|
|
|
|
+ // See mutex note above.
|
|
|
|
|
+ conn.mutex.Lock()
|
|
|
|
|
|
|
|
- file := os.NewFile(uintptr(socketFd), "")
|
|
|
|
|
- fileConn, err := net.FileConn(file)
|
|
|
|
|
- file.Close()
|
|
|
|
|
- // No more deferred fd clean up on err
|
|
|
|
|
- socketFd = _INVALID_FD
|
|
|
|
|
|
|
+ file := os.NewFile(uintptr(conn.interruptible.socketFd), "")
|
|
|
|
|
+ fileConn, err := net.FileConn(file) // net.FileConn() dups the fd
|
|
|
|
|
+ file.Close() // file.Close() closes the fd
|
|
|
|
|
+ conn.interruptible.socketFd = _INVALID_FD
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
conn.mutex.Unlock()
|
|
conn.mutex.Unlock()
|
|
|
return nil, ContextError(err)
|
|
return nil, ContextError(err)
|
|
|
}
|
|
}
|
|
|
- conn.Conn = fileConn // (requires mutex)
|
|
|
|
|
-
|
|
|
|
|
|
|
+ conn.Conn = fileConn
|
|
|
conn.mutex.Unlock()
|
|
conn.mutex.Unlock()
|
|
|
|
|
|
|
|
return conn, nil
|
|
return conn, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func interruptibleTCPClose(interruptible interruptibleTCPSocket) error {
|
|
func interruptibleTCPClose(interruptible interruptibleTCPSocket) error {
|
|
|
|
|
+
|
|
|
|
|
+ // Assumes conn.mutex is held
|
|
|
|
|
+
|
|
|
if interruptible.socketFd == _INVALID_FD {
|
|
if interruptible.socketFd == _INVALID_FD {
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
- return syscall.Close(interruptible.socketFd)
|
|
|
|
|
|
|
+ err := syscall.Close(interruptible.socketFd)
|
|
|
|
|
+ interruptible.socketFd = _INVALID_FD
|
|
|
|
|
+ return err
|
|
|
}
|
|
}
|