|
@@ -20,7 +20,6 @@
|
|
|
package psiphon
|
|
package psiphon
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
- "context"
|
|
|
|
|
"net"
|
|
"net"
|
|
|
"sync"
|
|
"sync"
|
|
|
"sync/atomic"
|
|
"sync/atomic"
|
|
@@ -41,8 +40,7 @@ type PacketTunnelTransport struct {
|
|
|
// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
|
|
// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
|
|
|
readTimeout int64
|
|
readTimeout int64
|
|
|
readDeadline int64
|
|
readDeadline int64
|
|
|
- runCtx context.Context
|
|
|
|
|
- stopRunning context.CancelFunc
|
|
|
|
|
|
|
+ closed int32
|
|
|
workers *sync.WaitGroup
|
|
workers *sync.WaitGroup
|
|
|
readMutex sync.Mutex
|
|
readMutex sync.Mutex
|
|
|
writeMutex sync.Mutex
|
|
writeMutex sync.Mutex
|
|
@@ -54,12 +52,7 @@ type PacketTunnelTransport struct {
|
|
|
|
|
|
|
|
// NewPacketTunnelTransport initializes a PacketTunnelTransport.
|
|
// NewPacketTunnelTransport initializes a PacketTunnelTransport.
|
|
|
func NewPacketTunnelTransport() *PacketTunnelTransport {
|
|
func NewPacketTunnelTransport() *PacketTunnelTransport {
|
|
|
-
|
|
|
|
|
- runCtx, stopRunning := context.WithCancel(context.Background())
|
|
|
|
|
-
|
|
|
|
|
return &PacketTunnelTransport{
|
|
return &PacketTunnelTransport{
|
|
|
- runCtx: runCtx,
|
|
|
|
|
- stopRunning: stopRunning,
|
|
|
|
|
workers: new(sync.WaitGroup),
|
|
workers: new(sync.WaitGroup),
|
|
|
channelReady: sync.NewCond(new(sync.Mutex)),
|
|
channelReady: sync.NewCond(new(sync.Mutex)),
|
|
|
}
|
|
}
|
|
@@ -73,8 +66,8 @@ func (p *PacketTunnelTransport) Read(data []byte) (int, error) {
|
|
|
p.readMutex.Lock()
|
|
p.readMutex.Lock()
|
|
|
defer p.readMutex.Unlock()
|
|
defer p.readMutex.Unlock()
|
|
|
|
|
|
|
|
- // getChannel will block if there's no channel.
|
|
|
|
|
-
|
|
|
|
|
|
|
+ // getChannel will block if there's no channel, or return an error when
|
|
|
|
|
+ // closed.
|
|
|
channelConn, channelTunnel, err := p.getChannel()
|
|
channelConn, channelTunnel, err := p.getChannel()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return 0, errors.Trace(err)
|
|
return 0, errors.Trace(err)
|
|
@@ -109,6 +102,8 @@ func (p *PacketTunnelTransport) Write(data []byte) (int, error) {
|
|
|
p.writeMutex.Lock()
|
|
p.writeMutex.Lock()
|
|
|
defer p.writeMutex.Unlock()
|
|
defer p.writeMutex.Unlock()
|
|
|
|
|
|
|
|
|
|
+ // getChannel will block if there's no channel, or return an error when
|
|
|
|
|
+ // closed.
|
|
|
channelConn, channelTunnel, err := p.getChannel()
|
|
channelConn, channelTunnel, err := p.getChannel()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return 0, errors.Trace(err)
|
|
return 0, errors.Trace(err)
|
|
@@ -186,12 +181,14 @@ func (p *PacketTunnelTransport) Write(data []byte) (int, error) {
|
|
|
// closed and any blocking Read/Write calls will be interrupted.
|
|
// closed and any blocking Read/Write calls will be interrupted.
|
|
|
func (p *PacketTunnelTransport) Close() error {
|
|
func (p *PacketTunnelTransport) Close() error {
|
|
|
|
|
|
|
|
- p.stopRunning()
|
|
|
|
|
|
|
+ if !atomic.CompareAndSwapInt32(&p.closed, 0, 1) {
|
|
|
|
|
+ return nil
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
p.workers.Wait()
|
|
p.workers.Wait()
|
|
|
|
|
|
|
|
// This broadcast is to wake up reads or writes blocking in getChannel; those
|
|
// This broadcast is to wake up reads or writes blocking in getChannel; those
|
|
|
- // getChannel calls should then abort on the p.runCtx.Done() check.
|
|
|
|
|
|
|
+ // getChannel calls should then abort on the p.closed check.
|
|
|
p.channelReady.Broadcast()
|
|
p.channelReady.Broadcast()
|
|
|
|
|
|
|
|
p.channelMutex.Lock()
|
|
p.channelMutex.Lock()
|
|
@@ -207,8 +204,22 @@ func (p *PacketTunnelTransport) Close() error {
|
|
|
// UseTunnel sets the PacketTunnelTransport to use a new transport channel within
|
|
// UseTunnel sets the PacketTunnelTransport to use a new transport channel within
|
|
|
// the specified tunnel. UseTunnel does not block on the open channel call; it spawns
|
|
// the specified tunnel. UseTunnel does not block on the open channel call; it spawns
|
|
|
// a worker that calls tunnel.DialPacketTunnelChannel and uses the resulting channel.
|
|
// a worker that calls tunnel.DialPacketTunnelChannel and uses the resulting channel.
|
|
|
|
|
+// UseTunnel has no effect once Close is called.
|
|
|
|
|
+//
|
|
|
|
|
+// Note that a blocked tunnel.DialPacketTunnelChannel with block Close;
|
|
|
|
|
+// callers should arrange for DialPacketTunnelChannel to be interrupted when
|
|
|
|
|
+// calling Close.
|
|
|
func (p *PacketTunnelTransport) UseTunnel(tunnel *Tunnel) {
|
|
func (p *PacketTunnelTransport) UseTunnel(tunnel *Tunnel) {
|
|
|
|
|
|
|
|
|
|
+ // Don't start a worker when closed, after which workers.Wait may be called.
|
|
|
|
|
+ if atomic.LoadInt32(&p.closed) == 1 {
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Spawning a new worker ensures that the latest tunnel is used to dial a
|
|
|
|
|
+ // new channel without delaying, as might happen if using a single worker
|
|
|
|
|
+ // that consumes a channel of tunnels.
|
|
|
|
|
+
|
|
|
p.workers.Add(1)
|
|
p.workers.Add(1)
|
|
|
go func(tunnel *Tunnel) {
|
|
go func(tunnel *Tunnel) {
|
|
|
defer p.workers.Done()
|
|
defer p.workers.Done()
|
|
@@ -241,11 +252,9 @@ func (p *PacketTunnelTransport) setChannel(
|
|
|
// Concurrency note: this check is within the mutex to ensure that a
|
|
// Concurrency note: this check is within the mutex to ensure that a
|
|
|
// UseTunnel call concurrent with a Close call doesn't leave a channel
|
|
// UseTunnel call concurrent with a Close call doesn't leave a channel
|
|
|
// set.
|
|
// set.
|
|
|
- select {
|
|
|
|
|
- case <-p.runCtx.Done():
|
|
|
|
|
|
|
+ if atomic.LoadInt32(&p.closed) == 1 {
|
|
|
p.channelMutex.Unlock()
|
|
p.channelMutex.Unlock()
|
|
|
return
|
|
return
|
|
|
- default:
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Interrupt Read/Write calls blocking on any previous channel.
|
|
// Interrupt Read/Write calls blocking on any previous channel.
|
|
@@ -279,10 +288,8 @@ func (p *PacketTunnelTransport) getChannel() (net.Conn, *Tunnel, error) {
|
|
|
defer p.channelReady.L.Unlock()
|
|
defer p.channelReady.L.Unlock()
|
|
|
for {
|
|
for {
|
|
|
|
|
|
|
|
- select {
|
|
|
|
|
- case <-p.runCtx.Done():
|
|
|
|
|
|
|
+ if atomic.LoadInt32(&p.closed) == 1 {
|
|
|
return nil, nil, errors.TraceNew("already closed")
|
|
return nil, nil, errors.TraceNew("already closed")
|
|
|
- default:
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
p.channelMutex.Lock()
|
|
p.channelMutex.Lock()
|