|
|
@@ -23,6 +23,7 @@ import (
|
|
|
"errors"
|
|
|
"net"
|
|
|
"os"
|
|
|
+ "sync"
|
|
|
"syscall"
|
|
|
"time"
|
|
|
)
|
|
|
@@ -36,19 +37,26 @@ import (
|
|
|
// routing compatibility, for example).
|
|
|
type Conn struct {
|
|
|
net.Conn
|
|
|
- socketFd int
|
|
|
- needCloseSocketFd bool
|
|
|
- isDisconnected bool
|
|
|
- disconnectionSignal chan bool
|
|
|
- readTimeout time.Duration
|
|
|
- writeTimeout time.Duration
|
|
|
+ mutex sync.Mutex
|
|
|
+ socketFd int
|
|
|
+ isClosed bool
|
|
|
+ closedSignal chan bool
|
|
|
+ readTimeout time.Duration
|
|
|
+ writeTimeout time.Duration
|
|
|
}
|
|
|
|
|
|
-// NewConn creates a new, configured Conn. Unlike standard Dial
|
|
|
-// functions, this does not return a connected net.Conn. Call the Connect function
|
|
|
-// to complete the connection establishment. To implement device binding and
|
|
|
-// interruptible connecting, the lower-level syscall APIs are used.
|
|
|
-func NewConn(readTimeout, writeTimeout time.Duration, deviceName string) (*Conn, error) {
|
|
|
+// NewConn creates a new, connected Conn. The connection can be interrupted
|
|
|
+// using pendingConns.interrupt(): the new Conn is added to pendingConns
|
|
|
+// before the socket connect beings. The caller is responsible for removing the
|
|
|
+// returned Conn from pendingConns.
|
|
|
+// To implement device binding and interruptible connecting, the lower-level
|
|
|
+// syscall APIs are used. The sequence of syscalls in this implementation are
|
|
|
+// taken from: https://code.google.com/p/go/issues/detail?id=6966
|
|
|
+func Dial(
|
|
|
+ ipAddress string, port int,
|
|
|
+ readTimeout, writeTimeout time.Duration,
|
|
|
+ pendingConns *PendingConns) (conn *Conn, err error) {
|
|
|
+
|
|
|
socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
@@ -58,7 +66,7 @@ func NewConn(readTimeout, writeTimeout time.Duration, deviceName string) (*Conn,
|
|
|
syscall.Close(socketFd)
|
|
|
return nil, err
|
|
|
}
|
|
|
- if deviceName != "" {
|
|
|
+ /*
|
|
|
// TODO: requires root, which we won't have on Android in VpnService mode
|
|
|
// an alternative may be to use http://golang.org/pkg/syscall/#UnixRights to
|
|
|
// send the fd to the main Android process which receives the fd with
|
|
|
@@ -69,74 +77,74 @@ func NewConn(readTimeout, writeTimeout time.Duration, deviceName string) (*Conn,
|
|
|
// https://code.google.com/p/ics-openvpn/source/browse/main/src/main/java/de/blinkt/openvpn/core/OpenVpnManagementThread.java#164
|
|
|
const SO_BINDTODEVICE = 0x19 // only defined for Linux
|
|
|
err = syscall.SetsockoptString(socketFd, syscall.SOL_SOCKET, SO_BINDTODEVICE, deviceName)
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- return &Conn{
|
|
|
- socketFd: socketFd,
|
|
|
- needCloseSocketFd: true,
|
|
|
- readTimeout: readTimeout,
|
|
|
- writeTimeout: writeTimeout}, nil
|
|
|
-}
|
|
|
-
|
|
|
-// Connect establishes a connection to the specified host. The sequence of
|
|
|
-// syscalls in this implementation are taken from: https://code.google.com/p/go/issues/detail?id=6966
|
|
|
-func (conn *Conn) Connect(ipAddress string, port int) (err error) {
|
|
|
+ */
|
|
|
+ conn = &Conn{
|
|
|
+ socketFd: socketFd,
|
|
|
+ readTimeout: readTimeout,
|
|
|
+ writeTimeout: writeTimeout}
|
|
|
+ pendingConns.Add(conn)
|
|
|
// TODO: domain name resolution (for meek)
|
|
|
var addr [4]byte
|
|
|
copy(addr[:], net.ParseIP(ipAddress).To4())
|
|
|
sockAddr := syscall.SockaddrInet4{Addr: addr, Port: port}
|
|
|
err = syscall.Connect(conn.socketFd, &sockAddr)
|
|
|
if err != nil {
|
|
|
- return err
|
|
|
+ return nil, err
|
|
|
}
|
|
|
file := os.NewFile(uintptr(conn.socketFd), "")
|
|
|
defer file.Close()
|
|
|
- fileConn, err := net.FileConn(file)
|
|
|
+ conn.Conn, err = net.FileConn(file)
|
|
|
if err != nil {
|
|
|
- return err
|
|
|
+ return nil, err
|
|
|
}
|
|
|
- conn.Conn = fileConn
|
|
|
- conn.needCloseSocketFd = false
|
|
|
- return nil
|
|
|
+ return conn, nil
|
|
|
}
|
|
|
|
|
|
-// SetDisconnectionSignal sets the channel which will be signaled
|
|
|
-// when the connection terminates. This function returns an error
|
|
|
-// if the connection is already disconnected (and would never send
|
|
|
+// SetClosedSignal sets the channel which will be signaled
|
|
|
+// when the connection is closed. This function returns an error
|
|
|
+// if the connection is already closed (and would never send
|
|
|
// the signal).
|
|
|
-func (conn *Conn) SetDisconnectionSignal(disconnectionSignal chan bool) (err error) {
|
|
|
- if conn.isDisconnected {
|
|
|
- return errors.New("connection is already disconnected")
|
|
|
+func (conn *Conn) SetClosedSignal(closedSignal chan bool) (err error) {
|
|
|
+ // TEMP **** needs comments
|
|
|
+ conn.mutex.Lock()
|
|
|
+ defer conn.mutex.Unlock()
|
|
|
+ if conn.isClosed {
|
|
|
+ return errors.New("connection is already closed")
|
|
|
}
|
|
|
- conn.disconnectionSignal = disconnectionSignal
|
|
|
+ conn.closedSignal = closedSignal
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-// Close terminates down an established (net.Conn) or establishing (socketFd) connection.
|
|
|
+// Close terminates a connected (net.Conn) or connecting (socketFd) Conn.
|
|
|
+// A mutex syncs access to conn struct, allowing Close() to be called
|
|
|
+// from a goroutine that wants to interrupt the primary goroutine using
|
|
|
+// the connection.
|
|
|
func (conn *Conn) Close() (err error) {
|
|
|
- if conn.needCloseSocketFd {
|
|
|
- err = syscall.Close(conn.socketFd)
|
|
|
- conn.needCloseSocketFd = false
|
|
|
- }
|
|
|
- if conn.Conn != nil {
|
|
|
- err = conn.Conn.Close()
|
|
|
+ var closedSignal chan bool
|
|
|
+ conn.mutex.Lock()
|
|
|
+ if !conn.isClosed {
|
|
|
+ if conn.Conn == nil {
|
|
|
+ err = syscall.Close(conn.socketFd)
|
|
|
+ } else {
|
|
|
+ err = conn.Conn.Close()
|
|
|
+ }
|
|
|
+ closedSignal = conn.closedSignal
|
|
|
+ conn.isClosed = true
|
|
|
}
|
|
|
- if conn.disconnectionSignal != nil {
|
|
|
+ conn.mutex.Unlock()
|
|
|
+ if closedSignal != nil {
|
|
|
select {
|
|
|
- case conn.disconnectionSignal <- true:
|
|
|
+ case closedSignal <- true:
|
|
|
default:
|
|
|
}
|
|
|
}
|
|
|
- conn.isDisconnected = true
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
// Read wraps standard Read to add an idle timeout. The connection
|
|
|
-// is explicitly terminated on timeout.
|
|
|
+// is explicitly closed on timeout.
|
|
|
func (conn *Conn) Read(buffer []byte) (n int, err error) {
|
|
|
- if conn.Conn == nil {
|
|
|
- return 0, errors.New("not connected")
|
|
|
- }
|
|
|
+ // Note: no mutex on the conn.readTimeout access
|
|
|
if conn.readTimeout != 0 {
|
|
|
err = conn.Conn.SetReadDeadline(time.Now().Add(conn.readTimeout))
|
|
|
if err != nil {
|
|
|
@@ -151,11 +159,9 @@ func (conn *Conn) Read(buffer []byte) (n int, err error) {
|
|
|
}
|
|
|
|
|
|
// Write wraps standard Write to add an idle timeout The connection
|
|
|
-// is explicitly terminated on timeout.
|
|
|
+// is explicitly closed on timeout.
|
|
|
func (conn *Conn) Write(buffer []byte) (n int, err error) {
|
|
|
- if conn.Conn == nil {
|
|
|
- return 0, errors.New("not connected")
|
|
|
- }
|
|
|
+ // Note: no mutex on the conn.writeTimeout access
|
|
|
if conn.writeTimeout != 0 {
|
|
|
err = conn.Conn.SetWriteDeadline(time.Now().Add(conn.writeTimeout))
|
|
|
if err != nil {
|
|
|
@@ -168,3 +174,35 @@ func (conn *Conn) Write(buffer []byte) (n int, err error) {
|
|
|
}
|
|
|
return
|
|
|
}
|
|
|
+
|
|
|
+// PendingConns is a synchronized list of Conns that's used to coordinate
|
|
|
+// interrupting a set of goroutines establishing connections.
|
|
|
+type PendingConns struct {
|
|
|
+ mutex sync.Mutex
|
|
|
+ conns []*Conn
|
|
|
+}
|
|
|
+
|
|
|
+func (pendingConns *PendingConns) Add(conn *Conn) {
|
|
|
+ pendingConns.mutex.Lock()
|
|
|
+ defer pendingConns.mutex.Unlock()
|
|
|
+ pendingConns.conns = append(pendingConns.conns, conn)
|
|
|
+}
|
|
|
+
|
|
|
+func (pendingConns *PendingConns) Remove(conn *Conn) {
|
|
|
+ pendingConns.mutex.Lock()
|
|
|
+ defer pendingConns.mutex.Unlock()
|
|
|
+ for index, pendingConn := range pendingConns.conns {
|
|
|
+ if conn == pendingConn {
|
|
|
+ pendingConns.conns = append(pendingConns.conns[:index], pendingConns.conns[index+1:]...)
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (pendingConns *PendingConns) Interrupt() {
|
|
|
+ pendingConns.mutex.Lock()
|
|
|
+ defer pendingConns.mutex.Unlock()
|
|
|
+ for _, conn := range pendingConns.conns {
|
|
|
+ conn.Close()
|
|
|
+ }
|
|
|
+}
|