|
@@ -45,7 +45,7 @@ import (
|
|
|
func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) {
|
|
func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) {
|
|
|
|
|
|
|
|
// Accept this channel immediately. This channel will replace any
|
|
// Accept this channel immediately. This channel will replace any
|
|
|
- // previously existing UDP channel for this client.
|
|
|
|
|
|
|
+ // previously existing udpgw channel for this client.
|
|
|
|
|
|
|
|
sshChannel, requests, err := newChannel.Accept()
|
|
sshChannel, requests, err := newChannel.Accept()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -57,16 +57,37 @@ func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) {
|
|
|
go ssh.DiscardRequests(requests)
|
|
go ssh.DiscardRequests(requests)
|
|
|
defer sshChannel.Close()
|
|
defer sshChannel.Close()
|
|
|
|
|
|
|
|
- sshClient.setUdpgwChannel(sshChannel)
|
|
|
|
|
-
|
|
|
|
|
multiplexer := &udpgwPortForwardMultiplexer{
|
|
multiplexer := &udpgwPortForwardMultiplexer{
|
|
|
sshClient: sshClient,
|
|
sshClient: sshClient,
|
|
|
sshChannel: sshChannel,
|
|
sshChannel: sshChannel,
|
|
|
portForwards: make(map[uint16]*udpgwPortForward),
|
|
portForwards: make(map[uint16]*udpgwPortForward),
|
|
|
portForwardLRU: common.NewLRUConns(),
|
|
portForwardLRU: common.NewLRUConns(),
|
|
|
relayWaitGroup: new(sync.WaitGroup),
|
|
relayWaitGroup: new(sync.WaitGroup),
|
|
|
|
|
+ runWaitGroup: new(sync.WaitGroup),
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ multiplexer.runWaitGroup.Add(1)
|
|
|
|
|
+
|
|
|
|
|
+ // setUdpgwChannelHandler will close any existing
|
|
|
|
|
+ // udpgwPortForwardMultiplexer, waiting for all run/relayDownstream
|
|
|
|
|
+ // goroutines to first terminate and all UDP socket resources to be
|
|
|
|
|
+ // cleaned up.
|
|
|
|
|
+ //
|
|
|
|
|
+ // This synchronous shutdown also ensures that the
|
|
|
|
|
+ // concurrentPortForwardCount is reduced to 0 before installing the new
|
|
|
|
|
+ // udpgwPortForwardMultiplexer and its LRU object. If the older handler
|
|
|
|
|
+ // were to dangle with open port forwards, and concurrentPortForwardCount
|
|
|
|
|
+ // were to hit the max, the wrong LRU, the new one, would be used to
|
|
|
|
|
+ // close the LRU port forward.
|
|
|
|
|
+ //
|
|
|
|
|
+ // Call setUdpgwHandler only after runWaitGroup is initialized, to ensure
|
|
|
|
|
+ // runWaitGroup.Wait() cannot be invoked (by some subsequent new udpgw
|
|
|
|
|
+ // channel) before initialized.
|
|
|
|
|
+
|
|
|
|
|
+ sshClient.setUdpgwChannelHandler(multiplexer)
|
|
|
|
|
+
|
|
|
multiplexer.run()
|
|
multiplexer.run()
|
|
|
|
|
+ multiplexer.runWaitGroup.Done()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
type udpgwPortForwardMultiplexer struct {
|
|
type udpgwPortForwardMultiplexer struct {
|
|
@@ -77,6 +98,27 @@ type udpgwPortForwardMultiplexer struct {
|
|
|
portForwards map[uint16]*udpgwPortForward
|
|
portForwards map[uint16]*udpgwPortForward
|
|
|
portForwardLRU *common.LRUConns
|
|
portForwardLRU *common.LRUConns
|
|
|
relayWaitGroup *sync.WaitGroup
|
|
relayWaitGroup *sync.WaitGroup
|
|
|
|
|
+ runWaitGroup *sync.WaitGroup
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (mux *udpgwPortForwardMultiplexer) stop() {
|
|
|
|
|
+
|
|
|
|
|
+ // udpgwPortForwardMultiplexer must be initialized by handleUdpgwChannel.
|
|
|
|
|
+ //
|
|
|
|
|
+ // stop closes the udpgw SSH channel, which will cause the run goroutine
|
|
|
|
|
+ // to exit its message read loop and await closure of all relayDownstream
|
|
|
|
|
+ // goroutines. Closing all port forward UDP conns will cause all
|
|
|
|
|
+ // relayDownstream to exit.
|
|
|
|
|
+
|
|
|
|
|
+ _ = mux.sshChannel.Close()
|
|
|
|
|
+
|
|
|
|
|
+ mux.portForwardsMutex.Lock()
|
|
|
|
|
+ for _, portForward := range mux.portForwards {
|
|
|
|
|
+ _ = portForward.conn.Close()
|
|
|
|
|
+ }
|
|
|
|
|
+ mux.portForwardsMutex.Unlock()
|
|
|
|
|
+
|
|
|
|
|
+ mux.runWaitGroup.Wait()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (mux *udpgwPortForwardMultiplexer) run() {
|
|
func (mux *udpgwPortForwardMultiplexer) run() {
|
|
@@ -277,7 +319,7 @@ func (mux *udpgwPortForwardMultiplexer) run() {
|
|
|
atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
|
|
atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Cleanup all UDP port forward workers when exiting
|
|
|
|
|
|
|
+ // Cleanup all udpgw port forward workers when exiting
|
|
|
|
|
|
|
|
mux.portForwardsMutex.Lock()
|
|
mux.portForwardsMutex.Lock()
|
|
|
for _, portForward := range mux.portForwards {
|
|
for _, portForward := range mux.portForwards {
|