|
|
@@ -25,7 +25,6 @@ import (
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"net"
|
|
|
- "runtime/debug"
|
|
|
"sync"
|
|
|
"sync/atomic"
|
|
|
|
|
|
@@ -35,7 +34,7 @@ import (
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
|
|
|
)
|
|
|
|
|
|
-// handleUDPChannel implements UDP port forwarding. A single UDP
|
|
|
+// handleUdpgwChannel implements UDP port forwarding. A single UDP
|
|
|
// SSH channel follows the udpgw protocol, which multiplexes many
|
|
|
// UDP port forwards.
|
|
|
//
|
|
|
@@ -43,10 +42,10 @@ import (
|
|
|
// Copyright (c) 2009, Ambroz Bizjak <ambrop7@gmail.com>
|
|
|
// https://github.com/ambrop72/badvpn
|
|
|
//
|
|
|
-func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
+func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) {
|
|
|
|
|
|
// Accept this channel immediately. This channel will replace any
|
|
|
- // previously existing UDP channel for this client.
|
|
|
+ // previously existing udpgw channel for this client.
|
|
|
|
|
|
sshChannel, requests, err := newChannel.Accept()
|
|
|
if err != nil {
|
|
|
@@ -58,33 +57,75 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
go ssh.DiscardRequests(requests)
|
|
|
defer sshChannel.Close()
|
|
|
|
|
|
- sshClient.setUDPChannel(sshChannel)
|
|
|
-
|
|
|
- multiplexer := &udpPortForwardMultiplexer{
|
|
|
+ multiplexer := &udpgwPortForwardMultiplexer{
|
|
|
sshClient: sshClient,
|
|
|
sshChannel: sshChannel,
|
|
|
- portForwards: make(map[uint16]*udpPortForward),
|
|
|
+ portForwards: make(map[uint16]*udpgwPortForward),
|
|
|
portForwardLRU: common.NewLRUConns(),
|
|
|
relayWaitGroup: new(sync.WaitGroup),
|
|
|
+ runWaitGroup: new(sync.WaitGroup),
|
|
|
}
|
|
|
+
|
|
|
+ multiplexer.runWaitGroup.Add(1)
|
|
|
+
|
|
|
+ // setUdpgwChannelHandler will close any existing
|
|
|
+ // udpgwPortForwardMultiplexer, waiting for all run/relayDownstream
|
|
|
+ // goroutines to first terminate and all UDP socket resources to be
|
|
|
+ // cleaned up.
|
|
|
+ //
|
|
|
+ // This synchronous shutdown also ensures that the
|
|
|
+ // concurrentPortForwardCount is reduced to 0 before installing the new
|
|
|
+ // udpgwPortForwardMultiplexer and its LRU object. If the older handler
|
|
|
+ // were to dangle with open port forwards, and concurrentPortForwardCount
|
|
|
+ // were to hit the max, the wrong LRU, the new one, would be used to
|
|
|
+ // close the LRU port forward.
|
|
|
+ //
|
|
|
+ // Call setUdpgwHandler only after runWaitGroup is initialized, to ensure
|
|
|
+ // runWaitGroup.Wait() cannot be invoked (by some subsequent new udpgw
|
|
|
+ // channel) before initialized.
|
|
|
+
|
|
|
+ sshClient.setUdpgwChannelHandler(multiplexer)
|
|
|
+
|
|
|
multiplexer.run()
|
|
|
+ multiplexer.runWaitGroup.Done()
|
|
|
}
|
|
|
|
|
|
-type udpPortForwardMultiplexer struct {
|
|
|
+type udpgwPortForwardMultiplexer struct {
|
|
|
sshClient *sshClient
|
|
|
sshChannelWriteMutex sync.Mutex
|
|
|
sshChannel ssh.Channel
|
|
|
portForwardsMutex sync.Mutex
|
|
|
- portForwards map[uint16]*udpPortForward
|
|
|
+ portForwards map[uint16]*udpgwPortForward
|
|
|
portForwardLRU *common.LRUConns
|
|
|
relayWaitGroup *sync.WaitGroup
|
|
|
+ runWaitGroup *sync.WaitGroup
|
|
|
}
|
|
|
|
|
|
-func (mux *udpPortForwardMultiplexer) run() {
|
|
|
+func (mux *udpgwPortForwardMultiplexer) stop() {
|
|
|
|
|
|
- // In a loop, read udpgw messages from the client to this channel. Each message is
|
|
|
- // a UDP packet to send upstream either via a new port forward, or on an existing
|
|
|
- // port forward.
|
|
|
+ // udpgwPortForwardMultiplexer must be initialized by handleUdpgwChannel.
|
|
|
+ //
|
|
|
+ // stop closes the udpgw SSH channel, which will cause the run goroutine
|
|
|
+ // to exit its message read loop and await closure of all relayDownstream
|
|
|
+ // goroutines. Closing all port forward UDP conns will cause all
|
|
|
+ // relayDownstream to exit.
|
|
|
+
|
|
|
+ _ = mux.sshChannel.Close()
|
|
|
+
|
|
|
+ mux.portForwardsMutex.Lock()
|
|
|
+ for _, portForward := range mux.portForwards {
|
|
|
+ _ = portForward.conn.Close()
|
|
|
+ }
|
|
|
+ mux.portForwardsMutex.Unlock()
|
|
|
+
|
|
|
+ mux.runWaitGroup.Wait()
|
|
|
+}
|
|
|
+
|
|
|
+func (mux *udpgwPortForwardMultiplexer) run() {
|
|
|
+
|
|
|
+ // In a loop, read udpgw messages from the client to this channel. Each
|
|
|
+ // message contains a UDP packet to send upstream either via a new port
|
|
|
+ // forward, or on an existing port forward.
|
|
|
//
|
|
|
// A goroutine is run to read downstream packets for each UDP port forward. All read
|
|
|
// packets are encapsulated in udpgw protocol and sent down the channel to the client.
|
|
|
@@ -92,16 +133,6 @@ func (mux *udpPortForwardMultiplexer) run() {
|
|
|
// When the client disconnects or the server shuts down, the channel will close and
|
|
|
// readUdpgwMessage will exit with EOF.
|
|
|
|
|
|
- // Recover from and log any unexpected panics caused by udpgw input handling bugs.
|
|
|
- // Note: this covers the run() goroutine only and not relayDownstream() goroutines.
|
|
|
- defer func() {
|
|
|
- if e := recover(); e != nil {
|
|
|
- err := errors.Tracef(
|
|
|
- "udpPortForwardMultiplexer panic: %s: %s", e, debug.Stack())
|
|
|
- log.WithTraceFields(LogFields{"error": err}).Warning("run failed")
|
|
|
- }
|
|
|
- }()
|
|
|
-
|
|
|
buffer := make([]byte, udpgwProtocolMaxMessageSize)
|
|
|
for {
|
|
|
// Note: message.packet points to the reusable memory in "buffer".
|
|
|
@@ -119,27 +150,37 @@ func (mux *udpPortForwardMultiplexer) run() {
|
|
|
portForward := mux.portForwards[message.connID]
|
|
|
mux.portForwardsMutex.Unlock()
|
|
|
|
|
|
- if portForward != nil && message.discardExistingConn {
|
|
|
+ // In the udpgw protocol, an existing port forward is closed when
|
|
|
+ // either the discard flag is set or the remote address has changed.
|
|
|
+
|
|
|
+ if portForward != nil &&
|
|
|
+ (message.discardExistingConn ||
|
|
|
+ !bytes.Equal(portForward.remoteIP, message.remoteIP) ||
|
|
|
+ portForward.remotePort != message.remotePort) {
|
|
|
+
|
|
|
// The port forward's goroutine will complete cleanup, including
|
|
|
// tallying stats and calling sshClient.closedPortForward.
|
|
|
// portForward.conn.Close() will signal this shutdown.
|
|
|
- // TODO: wait for goroutine to exit before proceeding?
|
|
|
portForward.conn.Close()
|
|
|
- portForward = nil
|
|
|
- }
|
|
|
|
|
|
- if portForward != nil {
|
|
|
+ // Synchronously await the termination of the relayDownstream
|
|
|
+ // goroutine. This ensures that the previous goroutine won't
|
|
|
+ // invoke removePortForward, with the connID that will be reused
|
|
|
+ // for the new port forward, after this point.
|
|
|
+ //
|
|
|
+ // Limitation: this synchronous shutdown cannot prevent a "wrong
|
|
|
+ // remote address" error on the badvpn udpgw client, which occurs
|
|
|
+ // when the client recycles a port forward (setting discard) but
|
|
|
+ // receives, from the server, a udpgw message containing the old
|
|
|
+ // remote address for the previous port forward with the same
|
|
|
+ // conn ID. That downstream message from the server may be in
|
|
|
+ // flight in the SSH channel when the client discard message arrives.
|
|
|
+ portForward.relayWaitGroup.Wait()
|
|
|
|
|
|
- // Verify that portForward remote address matches latest message
|
|
|
-
|
|
|
- if !bytes.Equal(portForward.remoteIP, message.remoteIP) ||
|
|
|
- portForward.remotePort != message.remotePort {
|
|
|
-
|
|
|
- log.WithTrace().Warning("UDP port forward remote address mismatch")
|
|
|
- continue
|
|
|
- }
|
|
|
+ portForward = nil
|
|
|
+ }
|
|
|
|
|
|
- } else {
|
|
|
+ if portForward == nil {
|
|
|
|
|
|
// Create a new port forward
|
|
|
|
|
|
@@ -237,17 +278,18 @@ func (mux *udpPortForwardMultiplexer) run() {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- portForward = &udpPortForward{
|
|
|
- connID: message.connID,
|
|
|
- preambleSize: message.preambleSize,
|
|
|
- remoteIP: message.remoteIP,
|
|
|
- remotePort: message.remotePort,
|
|
|
- dialIP: dialIP,
|
|
|
- conn: conn,
|
|
|
- lruEntry: lruEntry,
|
|
|
- bytesUp: 0,
|
|
|
- bytesDown: 0,
|
|
|
- mux: mux,
|
|
|
+ portForward = &udpgwPortForward{
|
|
|
+ connID: message.connID,
|
|
|
+ preambleSize: message.preambleSize,
|
|
|
+ remoteIP: message.remoteIP,
|
|
|
+ remotePort: message.remotePort,
|
|
|
+ dialIP: dialIP,
|
|
|
+ conn: conn,
|
|
|
+ lruEntry: lruEntry,
|
|
|
+ bytesUp: 0,
|
|
|
+ bytesDown: 0,
|
|
|
+ relayWaitGroup: new(sync.WaitGroup),
|
|
|
+ mux: mux,
|
|
|
}
|
|
|
|
|
|
if message.forwardDNS {
|
|
|
@@ -258,6 +300,7 @@ func (mux *udpPortForwardMultiplexer) run() {
|
|
|
mux.portForwards[portForward.connID] = portForward
|
|
|
mux.portForwardsMutex.Unlock()
|
|
|
|
|
|
+ portForward.relayWaitGroup.Add(1)
|
|
|
mux.relayWaitGroup.Add(1)
|
|
|
go portForward.relayDownstream()
|
|
|
}
|
|
|
@@ -276,7 +319,7 @@ func (mux *udpPortForwardMultiplexer) run() {
|
|
|
atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
|
|
|
}
|
|
|
|
|
|
- // Cleanup all UDP port forward workers when exiting
|
|
|
+ // Cleanup all udpgw port forward workers when exiting
|
|
|
|
|
|
mux.portForwardsMutex.Lock()
|
|
|
for _, portForward := range mux.portForwards {
|
|
|
@@ -288,13 +331,13 @@ func (mux *udpPortForwardMultiplexer) run() {
|
|
|
mux.relayWaitGroup.Wait()
|
|
|
}
|
|
|
|
|
|
-func (mux *udpPortForwardMultiplexer) removePortForward(connID uint16) {
|
|
|
+func (mux *udpgwPortForwardMultiplexer) removePortForward(connID uint16) {
|
|
|
mux.portForwardsMutex.Lock()
|
|
|
delete(mux.portForwards, connID)
|
|
|
mux.portForwardsMutex.Unlock()
|
|
|
}
|
|
|
|
|
|
-type udpPortForward struct {
|
|
|
+type udpgwPortForward struct {
|
|
|
// Note: 64-bit ints used with atomic operations are placed
|
|
|
// at the start of struct to ensure 64-bit alignment.
|
|
|
// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
|
|
|
@@ -309,10 +352,12 @@ type udpPortForward struct {
|
|
|
dialIP net.IP
|
|
|
conn net.Conn
|
|
|
lruEntry *common.LRUConnsEntry
|
|
|
- mux *udpPortForwardMultiplexer
|
|
|
+ relayWaitGroup *sync.WaitGroup
|
|
|
+ mux *udpgwPortForwardMultiplexer
|
|
|
}
|
|
|
|
|
|
-func (portForward *udpPortForward) relayDownstream() {
|
|
|
+func (portForward *udpgwPortForward) relayDownstream() {
|
|
|
+ defer portForward.relayWaitGroup.Done()
|
|
|
defer portForward.mux.relayWaitGroup.Done()
|
|
|
|
|
|
// Downstream UDP packets are read into the reusable memory
|