Эх сурвалжийг харах

Fix potential sync.WaitGroup race condition

- In the previous code, it was possible for WaitGroup.Add to be called, with a
  counter of zero, after WaitGroup.Wait.
Rod Hynes 4 жил өмнө
parent
commit
4ab7f1d5f0

+ 25 - 18
psiphon/packetTunnelTransport.go

@@ -20,7 +20,6 @@
 package psiphon
 package psiphon
 
 
 import (
 import (
-	"context"
 	"net"
 	"net"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
@@ -41,8 +40,7 @@ type PacketTunnelTransport struct {
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
 	readTimeout   int64
 	readTimeout   int64
 	readDeadline  int64
 	readDeadline  int64
-	runCtx        context.Context
-	stopRunning   context.CancelFunc
+	closed        int32
 	workers       *sync.WaitGroup
 	workers       *sync.WaitGroup
 	readMutex     sync.Mutex
 	readMutex     sync.Mutex
 	writeMutex    sync.Mutex
 	writeMutex    sync.Mutex
@@ -54,12 +52,7 @@ type PacketTunnelTransport struct {
 
 
 // NewPacketTunnelTransport initializes a PacketTunnelTransport.
 // NewPacketTunnelTransport initializes a PacketTunnelTransport.
 func NewPacketTunnelTransport() *PacketTunnelTransport {
 func NewPacketTunnelTransport() *PacketTunnelTransport {
-
-	runCtx, stopRunning := context.WithCancel(context.Background())
-
 	return &PacketTunnelTransport{
 	return &PacketTunnelTransport{
-		runCtx:       runCtx,
-		stopRunning:  stopRunning,
 		workers:      new(sync.WaitGroup),
 		workers:      new(sync.WaitGroup),
 		channelReady: sync.NewCond(new(sync.Mutex)),
 		channelReady: sync.NewCond(new(sync.Mutex)),
 	}
 	}
@@ -73,8 +66,8 @@ func (p *PacketTunnelTransport) Read(data []byte) (int, error) {
 	p.readMutex.Lock()
 	p.readMutex.Lock()
 	defer p.readMutex.Unlock()
 	defer p.readMutex.Unlock()
 
 
-	// getChannel will block if there's no channel.
-
+	// getChannel will block if there's no channel, or return an error when
+	// closed.
 	channelConn, channelTunnel, err := p.getChannel()
 	channelConn, channelTunnel, err := p.getChannel()
 	if err != nil {
 	if err != nil {
 		return 0, errors.Trace(err)
 		return 0, errors.Trace(err)
@@ -109,6 +102,8 @@ func (p *PacketTunnelTransport) Write(data []byte) (int, error) {
 	p.writeMutex.Lock()
 	p.writeMutex.Lock()
 	defer p.writeMutex.Unlock()
 	defer p.writeMutex.Unlock()
 
 
+	// getChannel will block if there's no channel, or return an error when
+	// closed.
 	channelConn, channelTunnel, err := p.getChannel()
 	channelConn, channelTunnel, err := p.getChannel()
 	if err != nil {
 	if err != nil {
 		return 0, errors.Trace(err)
 		return 0, errors.Trace(err)
@@ -186,12 +181,14 @@ func (p *PacketTunnelTransport) Write(data []byte) (int, error) {
 // closed and any blocking Read/Write calls will be interrupted.
 // closed and any blocking Read/Write calls will be interrupted.
 func (p *PacketTunnelTransport) Close() error {
 func (p *PacketTunnelTransport) Close() error {
 
 
-	p.stopRunning()
+	if !atomic.CompareAndSwapInt32(&p.closed, 0, 1) {
+		return nil
+	}
 
 
 	p.workers.Wait()
 	p.workers.Wait()
 
 
 	// This broadcast is to wake up reads or writes blocking in getChannel; those
 	// This broadcast is to wake up reads or writes blocking in getChannel; those
-	// getChannel calls should then abort on the p.runCtx.Done() check.
+	// getChannel calls should then abort on the p.closed check.
 	p.channelReady.Broadcast()
 	p.channelReady.Broadcast()
 
 
 	p.channelMutex.Lock()
 	p.channelMutex.Lock()
@@ -207,8 +204,22 @@ func (p *PacketTunnelTransport) Close() error {
 // UseTunnel sets the PacketTunnelTransport to use a new transport channel within
 // UseTunnel sets the PacketTunnelTransport to use a new transport channel within
 // the specified tunnel. UseTunnel does not block on the open channel call; it spawns
 // the specified tunnel. UseTunnel does not block on the open channel call; it spawns
 // a worker that calls tunnel.DialPacketTunnelChannel and uses the resulting channel.
 // a worker that calls tunnel.DialPacketTunnelChannel and uses the resulting channel.
+// UseTunnel has no effect once Close is called.
+//
+// Note that a blocked tunnel.DialPacketTunnelChannel with block Close;
+// callers should arrange for DialPacketTunnelChannel to be interrupted when
+// calling Close.
 func (p *PacketTunnelTransport) UseTunnel(tunnel *Tunnel) {
 func (p *PacketTunnelTransport) UseTunnel(tunnel *Tunnel) {
 
 
+	// Don't start a worker when closed, after which workers.Wait may be called.
+	if atomic.LoadInt32(&p.closed) == 1 {
+		return
+	}
+
+	// Spawning a new worker ensures that the latest tunnel is used to dial a
+	// new channel without delaying, as might happen if using a single worker
+	// that consumes a channel of tunnels.
+
 	p.workers.Add(1)
 	p.workers.Add(1)
 	go func(tunnel *Tunnel) {
 	go func(tunnel *Tunnel) {
 		defer p.workers.Done()
 		defer p.workers.Done()
@@ -241,11 +252,9 @@ func (p *PacketTunnelTransport) setChannel(
 	// Concurrency note: this check is within the mutex to ensure that a
 	// Concurrency note: this check is within the mutex to ensure that a
 	// UseTunnel call concurrent with a Close call doesn't leave a channel
 	// UseTunnel call concurrent with a Close call doesn't leave a channel
 	// set.
 	// set.
-	select {
-	case <-p.runCtx.Done():
+	if atomic.LoadInt32(&p.closed) == 1 {
 		p.channelMutex.Unlock()
 		p.channelMutex.Unlock()
 		return
 		return
-	default:
 	}
 	}
 
 
 	// Interrupt Read/Write calls blocking on any previous channel.
 	// Interrupt Read/Write calls blocking on any previous channel.
@@ -279,10 +288,8 @@ func (p *PacketTunnelTransport) getChannel() (net.Conn, *Tunnel, error) {
 	defer p.channelReady.L.Unlock()
 	defer p.channelReady.L.Unlock()
 	for {
 	for {
 
 
-		select {
-		case <-p.runCtx.Done():
+		if atomic.LoadInt32(&p.closed) == 1 {
 			return nil, nil, errors.TraceNew("already closed")
 			return nil, nil, errors.TraceNew("already closed")
-		default:
 		}
 		}
 
 
 		p.channelMutex.Lock()
 		p.channelMutex.Lock()