Просмотр исходного кода

Synchronize udpgw channel replacement

Rod Hynes 4 лет назад
Родитель
Сommit
b2aea840d5
2 измененных файлов с 55 добавлено и 13 удалено
  1. 9 9
      psiphon/server/tunnelServer.go
  2. 46 4
      psiphon/server/udp.go

+ 9 - 9
psiphon/server/tunnelServer.go

@@ -1260,7 +1260,7 @@ type sshClient struct {
 	isFirstTunnelInSession               bool
 	isFirstTunnelInSession               bool
 	supportsServerRequests               bool
 	supportsServerRequests               bool
 	handshakeState                       handshakeState
 	handshakeState                       handshakeState
-	udpgwChannel                         ssh.Channel
+	udpgwChannelHandler                  *udpgwPortForwardMultiplexer
 	packetTunnelChannel                  ssh.Channel
 	packetTunnelChannel                  ssh.Channel
 	trafficRules                         TrafficRules
 	trafficRules                         TrafficRules
 	tcpTrafficState                      trafficState
 	tcpTrafficState                      trafficState
@@ -2561,16 +2561,16 @@ func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) {
 	sshClient.Unlock()
 	sshClient.Unlock()
 }
 }
 
 
-// setUdpgwChannel sets the single udpgw channel for this sshClient.
-// Each sshClient may have only one concurrent udpgw channel. Each
-// udpgw channel multiplexes many UDP port forwards via the udpgw
-// protocol. Any existing udpgw channel is closed.
-func (sshClient *sshClient) setUdpgwChannel(channel ssh.Channel) {
+// setUdpgwChannelHandler sets the single udpgw channel handler for this
+// sshClient. Each sshClient may have only one concurrent udpgw
+// channel/handler. Each udpgw channel multiplexes many UDP port forwards via
+// the udpgw protocol. Any existing udpgw channel/handler is closed.
+func (sshClient *sshClient) setUdpgwChannelHandler(udpgwChannelHandler *udpgwPortForwardMultiplexer) {
 	sshClient.Lock()
 	sshClient.Lock()
-	if sshClient.udpgwChannel != nil {
-		sshClient.udpgwChannel.Close()
+	if sshClient.udpgwChannelHandler != nil {
+		sshClient.udpgwChannelHandler.stop()
 	}
 	}
-	sshClient.udpgwChannel = channel
+	sshClient.udpgwChannelHandler = udpgwChannelHandler
 	sshClient.Unlock()
 	sshClient.Unlock()
 }
 }
 
 

+ 46 - 4
psiphon/server/udp.go

@@ -45,7 +45,7 @@ import (
 func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) {
 func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) {
 
 
 	// Accept this channel immediately. This channel will replace any
 	// Accept this channel immediately. This channel will replace any
-	// previously existing UDP channel for this client.
+	// previously existing udpgw channel for this client.
 
 
 	sshChannel, requests, err := newChannel.Accept()
 	sshChannel, requests, err := newChannel.Accept()
 	if err != nil {
 	if err != nil {
@@ -57,16 +57,37 @@ func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) {
 	go ssh.DiscardRequests(requests)
 	go ssh.DiscardRequests(requests)
 	defer sshChannel.Close()
 	defer sshChannel.Close()
 
 
-	sshClient.setUdpgwChannel(sshChannel)
-
 	multiplexer := &udpgwPortForwardMultiplexer{
 	multiplexer := &udpgwPortForwardMultiplexer{
 		sshClient:      sshClient,
 		sshClient:      sshClient,
 		sshChannel:     sshChannel,
 		sshChannel:     sshChannel,
 		portForwards:   make(map[uint16]*udpgwPortForward),
 		portForwards:   make(map[uint16]*udpgwPortForward),
 		portForwardLRU: common.NewLRUConns(),
 		portForwardLRU: common.NewLRUConns(),
 		relayWaitGroup: new(sync.WaitGroup),
 		relayWaitGroup: new(sync.WaitGroup),
+		runWaitGroup:   new(sync.WaitGroup),
 	}
 	}
+
+	multiplexer.runWaitGroup.Add(1)
+
+	// setUdpgwChannelHandler will close any existing
+	// udpgwPortForwardMultiplexer, waiting for all run/relayDownstream
+	// goroutines to first terminate and all UDP socket resources to be
+	// cleaned up.
+	//
+	// This synchronous shutdown also ensures that the
+	// concurrentPortForwardCount is reduced to 0 before installing the new
+	// udpgwPortForwardMultiplexer and its LRU object. If the older handler
+	// were to dangle with open port forwards, and concurrentPortForwardCount
+	// were to hit the max, the wrong LRU, the new one, would be used to
+	// close the LRU port forward.
+	//
+	// Call setUdpgwHandler only after runWaitGroup is initialized, to ensure
+	// runWaitGroup.Wait() cannot be invoked (by some subsequent new udpgw
+	// channel) before initialized.
+
+	sshClient.setUdpgwChannelHandler(multiplexer)
+
 	multiplexer.run()
 	multiplexer.run()
+	multiplexer.runWaitGroup.Done()
 }
 }
 
 
 type udpgwPortForwardMultiplexer struct {
 type udpgwPortForwardMultiplexer struct {
@@ -77,6 +98,27 @@ type udpgwPortForwardMultiplexer struct {
 	portForwards         map[uint16]*udpgwPortForward
 	portForwards         map[uint16]*udpgwPortForward
 	portForwardLRU       *common.LRUConns
 	portForwardLRU       *common.LRUConns
 	relayWaitGroup       *sync.WaitGroup
 	relayWaitGroup       *sync.WaitGroup
+	runWaitGroup         *sync.WaitGroup
+}
+
+func (mux *udpgwPortForwardMultiplexer) stop() {
+
+	// udpgwPortForwardMultiplexer must be initialized by handleUdpgwChannel.
+	//
+	// stop closes the udpgw SSH channel, which will cause the run goroutine
+	// to exit its message read loop and await closure of all relayDownstream
+	// goroutines. Closing all port forward UDP conns will cause all
+	// relayDownstream to exit.
+
+	_ = mux.sshChannel.Close()
+
+	mux.portForwardsMutex.Lock()
+	for _, portForward := range mux.portForwards {
+		_ = portForward.conn.Close()
+	}
+	mux.portForwardsMutex.Unlock()
+
+	mux.runWaitGroup.Wait()
 }
 }
 
 
 func (mux *udpgwPortForwardMultiplexer) run() {
 func (mux *udpgwPortForwardMultiplexer) run() {
@@ -277,7 +319,7 @@ func (mux *udpgwPortForwardMultiplexer) run() {
 		atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
 		atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
 	}
 	}
 
 
-	// Cleanup all UDP port forward workers when exiting
+	// Cleanup all udpgw port forward workers when exiting
 
 
 	mux.portForwardsMutex.Lock()
 	mux.portForwardsMutex.Lock()
 	for _, portForward := range mux.portForwards {
 	for _, portForward := range mux.portForwards {