Просмотр исходного кода

Synchronize udpgw port forward conn ID reuse

Rod Hynes 4 лет назад
Родитель
Сommit
060d7fdc05
1 измененных файлов с 29 добавлено и 11 удалено
  1. 29 11
      psiphon/server/udp.go

+ 29 - 11
psiphon/server/udp.go

@@ -119,8 +119,22 @@ func (mux *udpgwPortForwardMultiplexer) run() {
 			// The port forward's goroutine will complete cleanup, including
 			// tallying stats and calling sshClient.closedPortForward.
 			// portForward.conn.Close() will signal this shutdown.
-			// TODO: wait for goroutine to exit before proceeding?
 			portForward.conn.Close()
+
+			// Synchronously await the termination of the relayDownstream
+			// goroutine. This ensures that the previous goroutine won't
+			// invoke removePortForward, with the connID that will be reused
+			// for the new port forward, after this point.
+			//
+			// Limitation: this synchronous shutdown cannot prevent a "wrong
+			// remote address" error on the badvpn udpgw client, which occurs
+			// when the client recycles a port forward (setting discard) but
+			// receives, from the server, a udpgw message containing the old
+			// remote address for the previous port forward with the same
+			// conn ID. That downstream message from the server may be in
+			// flight in the SSH channel when the client discard message arrives.
+			portForward.relayWaitGroup.Wait()
+
 			portForward = nil
 		}
 
@@ -223,16 +237,17 @@ func (mux *udpgwPortForwardMultiplexer) run() {
 			}
 
 			portForward = &udpgwPortForward{
-				connID:       message.connID,
-				preambleSize: message.preambleSize,
-				remoteIP:     message.remoteIP,
-				remotePort:   message.remotePort,
-				dialIP:       dialIP,
-				conn:         conn,
-				lruEntry:     lruEntry,
-				bytesUp:      0,
-				bytesDown:    0,
-				mux:          mux,
+				connID:         message.connID,
+				preambleSize:   message.preambleSize,
+				remoteIP:       message.remoteIP,
+				remotePort:     message.remotePort,
+				dialIP:         dialIP,
+				conn:           conn,
+				lruEntry:       lruEntry,
+				bytesUp:        0,
+				bytesDown:      0,
+				relayWaitGroup: new(sync.WaitGroup),
+				mux:            mux,
 			}
 
 			if message.forwardDNS {
@@ -243,6 +258,7 @@ func (mux *udpgwPortForwardMultiplexer) run() {
 			mux.portForwards[portForward.connID] = portForward
 			mux.portForwardsMutex.Unlock()
 
+			portForward.relayWaitGroup.Add(1)
 			mux.relayWaitGroup.Add(1)
 			go portForward.relayDownstream()
 		}
@@ -294,10 +310,12 @@ type udpgwPortForward struct {
 	dialIP            net.IP
 	conn              net.Conn
 	lruEntry          *common.LRUConnsEntry
+	relayWaitGroup    *sync.WaitGroup
 	mux               *udpgwPortForwardMultiplexer
 }
 
 func (portForward *udpgwPortForward) relayDownstream() {
+	defer portForward.relayWaitGroup.Done()
 	defer portForward.mux.relayWaitGroup.Done()
 
 	// Downstream UDP packets are read into the reusable memory