浏览代码

Fix: pendingConns race condition

Don't allow adding a conn to pendingConns once pendingConns has
been closed. This interrupts an EstablishTunnel which starts after
pendingConns is closed but before it adds its conn to pendingConns.
Rod Hynes 11 年之前
父节点
当前提交
5566dde0d7
共有 4 个文件被更改,包括 43 次插入15 次删除
  1. 7 3
      psiphon/TCPConn_unix.go
  2. 4 1
      psiphon/TCPConn_windows.go
  3. 18 3
      psiphon/conn.go
  4. 14 8
      psiphon/controller.go

+ 7 - 3
psiphon/TCPConn_unix.go

@@ -51,8 +51,8 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 		return nil, ContextError(err)
 	}
 	defer func() {
-		// Cleanup on error
-		if err != nil {
+		// Cleanup on error (fd isset to -1 when it should no longer be closed)
+		if err != nil && socketFd != -1 {
 			syscall.Close(socketFd)
 		}
 	}()
@@ -95,7 +95,10 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 		interruptible: interruptibleTCPSocket{socketFd: socketFd},
 		readTimeout:   config.ReadTimeout,
 		writeTimeout:  config.WriteTimeout}
-	config.PendingConns.Add(conn)
+
+	if !config.PendingConns.Add(conn) {
+		return nil, ContextError(errors.New("pending connections already closed"))
+	}
 
 	// Connect the socket
 	// TODO: adjust the timeout to account for time spent resolving hostname
@@ -134,6 +137,7 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 		return nil, ContextError(err)
 	}
 	conn.interruptible.socketFd = -1
+	socketFd = -1
 	conn.Conn = fileConn
 
 	conn.mutex.Unlock()

+ 4 - 1
psiphon/TCPConn_windows.go

@@ -55,7 +55,10 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 		interruptible: interruptibleTCPSocket{results: make(chan *interruptibleDialResult, 2)},
 		readTimeout:   config.ReadTimeout,
 		writeTimeout:  config.WriteTimeout}
-	config.PendingConns.Add(conn)
+
+	if !config.PendingConns.Add(conn) {
+		return nil, ContextError(errors.New("pending connections already closed"))
+	}
 
 	// Call the blocking Dial in a goroutine
 	results := conn.interruptible.results

+ 18 - 3
psiphon/conn.go

@@ -87,18 +87,32 @@ type Conn interface {
 // Conns is a synchronized list of Conns that is used to coordinate
 // interrupting a set of goroutines establishing connections, or
 // close a set of open connections, etc.
+// Once the list is closed, no more items may be added to the
+// list (unless it is reset).
 type Conns struct {
-	mutex sync.Mutex
-	conns map[net.Conn]bool
+	mutex    sync.Mutex
+	isClosed bool
+	conns    map[net.Conn]bool
 }
 
-func (conns *Conns) Add(conn net.Conn) {
+func (conns *Conns) Reset() {
 	conns.mutex.Lock()
 	defer conns.mutex.Unlock()
+	conns.isClosed = false
+	conns.conns = make(map[net.Conn]bool)
+}
+
+func (conns *Conns) Add(conn net.Conn) bool {
+	conns.mutex.Lock()
+	defer conns.mutex.Unlock()
+	if conns.isClosed {
+		return false
+	}
 	if conns.conns == nil {
 		conns.conns = make(map[net.Conn]bool)
 	}
 	conns.conns[conn] = true
+	return true
 }
 
 func (conns *Conns) Remove(conn net.Conn) {
@@ -110,6 +124,7 @@ func (conns *Conns) Remove(conn net.Conn) {
 func (conns *Conns) CloseAll() {
 	conns.mutex.Lock()
 	defer conns.mutex.Unlock()
+	conns.isClosed = true
 	for conn, _ := range conns.conns {
 		conn.Close()
 	}

+ 14 - 8
psiphon/controller.go

@@ -377,6 +377,7 @@ func (controller *Controller) startEstablishing() {
 	controller.establishWaitGroup = new(sync.WaitGroup)
 	controller.stopEstablishingBroadcast = make(chan struct{})
 	controller.candidateServerEntries = make(chan *ServerEntry)
+	controller.pendingConns.Reset()
 
 	for i := 0; i < controller.config.ConnectionWorkerPoolSize; i++ {
 		controller.establishWaitGroup.Add(1)
@@ -470,12 +471,10 @@ func (controller *Controller) establishTunnelWorker() {
 	defer controller.establishWaitGroup.Done()
 loop:
 	for serverEntry := range controller.candidateServerEntries {
-		// Note: don't receive from candidateQueue and broadcastStopWorkers in the same
-		// select, since we want to prioritize receiving the stop signal
-		select {
-		case <-controller.stopEstablishingBroadcast:
+		// Note: don't receive from candidateServerEntries and stopEstablishingBroadcast
+		// in the same select, since we want to prioritize receiving the stop signal
+		if controller.isStopEstablishingBroadcast() {
 			break loop
-		default:
 		}
 
 		// There may already be a tunnel to this candidate. If so, skip it.
@@ -491,10 +490,8 @@ loop:
 		if err != nil {
 			// Before emitting error, check if establish interrupted, in which
 			// case the error is noise.
-			select {
-			case <-controller.stopEstablishingBroadcast:
+			if controller.isStopEstablishingBroadcast() {
 				break loop
-			default:
 			}
 			Notice(NOTICE_INFO, "failed to connect to %s: %s", serverEntry.IpAddress, err)
 			continue
@@ -512,3 +509,12 @@ loop:
 	}
 	Notice(NOTICE_INFO, "stopped establish worker")
 }
+
+func (controller *Controller) isStopEstablishingBroadcast() bool {
+	select {
+	case <-controller.stopEstablishingBroadcast:
+		return true
+	default:
+	}
+	return false
+}