Przeglądaj źródła

Merge pull request #600 from rod-hynes/master

udpgw fixes
Rod Hynes 4 lat temu
rodzic
commit
a000eb2ba7
3 zmienionych plików z 187 dodań i 66 usunięć
  1. 70 0
      psiphon/server/server_test.go
  2. 18 12
      psiphon/server/tunnelServer.go
  3. 99 54
      psiphon/server/udp.go

+ 70 - 0
psiphon/server/server_test.go

@@ -1338,6 +1338,10 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	expectServerBPFField := ServerBPFEnabled() && doServerTactics
 	expectServerPacketManipulationField := runConfig.doPacketManipulation
 	expectBurstFields := runConfig.doBurstMonitor
+	expectTCPPortForwardDial := runConfig.doTunneledWebRequest
+	expectTCPDataTransfer := runConfig.doTunneledWebRequest && !expectTrafficFailure && !runConfig.doSplitTunnel
+	// Even with expectTrafficFailure, DNS port forwards will succeed
+	expectUDPDataTransfer := runConfig.doTunneledNTPRequest
 
 	select {
 	case logFields := <-serverTunnelLog:
@@ -1347,6 +1351,9 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			expectServerBPFField,
 			expectServerPacketManipulationField,
 			expectBurstFields,
+			expectTCPPortForwardDial,
+			expectTCPDataTransfer,
+			expectUDPDataTransfer,
 			logFields)
 		if err != nil {
 			t.Fatalf("invalid server tunnel log fields: %s", err)
@@ -1404,6 +1411,9 @@ func checkExpectedServerTunnelLogFields(
 	expectServerBPFField bool,
 	expectServerPacketManipulationField bool,
 	expectBurstFields bool,
+	expectTCPPortForwardDial bool,
+	expectTCPDataTransfer bool,
+	expectUDPDataTransfer bool,
 	fields map[string]interface{}) error {
 
 	// Limitations:
@@ -1649,6 +1659,66 @@ func checkExpectedServerTunnelLogFields(
 		return fmt.Errorf("unexpected network_type '%s'", fields["network_type"])
 	}
 
+	var checkTCPMetric func(float64) bool
+	if expectTCPPortForwardDial {
+		checkTCPMetric = func(f float64) bool { return f > 0 }
+	} else {
+		checkTCPMetric = func(f float64) bool { return f == 0 }
+	}
+
+	for _, name := range []string{
+		"peak_concurrent_dialing_port_forward_count_tcp",
+	} {
+		if fields[name] == nil {
+			return fmt.Errorf("missing expected field '%s'", name)
+		}
+		if !checkTCPMetric(fields[name].(float64)) {
+			return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
+		}
+	}
+
+	if expectTCPDataTransfer {
+		checkTCPMetric = func(f float64) bool { return f > 0 }
+	} else {
+		checkTCPMetric = func(f float64) bool { return f == 0 }
+	}
+
+	for _, name := range []string{
+		"bytes_up_tcp",
+		"bytes_down_tcp",
+		"peak_concurrent_port_forward_count_tcp",
+		"total_port_forward_count_tcp",
+	} {
+		if fields[name] == nil {
+			return fmt.Errorf("missing expected field '%s'", name)
+		}
+		if !checkTCPMetric(fields[name].(float64)) {
+			return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
+		}
+	}
+
+	var checkUDPMetric func(float64) bool
+	if expectUDPDataTransfer {
+		checkUDPMetric = func(f float64) bool { return f > 0 }
+	} else {
+		checkUDPMetric = func(f float64) bool { return f == 0 }
+	}
+
+	for _, name := range []string{
+		"bytes_up_udp",
+		"bytes_down_udp",
+		"peak_concurrent_port_forward_count_udp",
+		"total_port_forward_count_udp",
+		"total_udpgw_channel_count",
+	} {
+		if fields[name] == nil {
+			return fmt.Errorf("missing expected field '%s'", name)
+		}
+		if !checkUDPMetric(fields[name].(float64)) {
+			return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
+		}
+	}
+
 	return nil
 }
 

+ 18 - 12
psiphon/server/tunnelServer.go

@@ -1260,8 +1260,10 @@ type sshClient struct {
 	isFirstTunnelInSession               bool
 	supportsServerRequests               bool
 	handshakeState                       handshakeState
-	udpChannel                           ssh.Channel
+	udpgwChannelHandler                  *udpgwPortForwardMultiplexer
+	totalUdpgwChannelCount               int
 	packetTunnelChannel                  ssh.Channel
+	totalPacketTunnelChannelCount        int
 	trafficRules                         TrafficRules
 	tcpTrafficState                      trafficState
 	udpTrafficState                      trafficState
@@ -2495,11 +2497,11 @@ func (sshClient *sshClient) handleNewTCPPortForwardChannel(
 
 	// Intercept TCP port forwards to a specified udpgw server and handle directly.
 	// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
-	isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
+	isUdpgwChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
 		sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
 			net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
 
-	if isUDPChannel {
+	if isUdpgwChannel {
 
 		// Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
 		// own worker goroutine.
@@ -2507,7 +2509,7 @@ func (sshClient *sshClient) handleNewTCPPortForwardChannel(
 		waitGroup.Add(1)
 		go func(channel ssh.NewChannel) {
 			defer waitGroup.Done()
-			sshClient.handleUDPChannel(channel)
+			sshClient.handleUdpgwChannel(channel)
 		}(newChannel)
 
 	} else {
@@ -2558,19 +2560,21 @@ func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) {
 		sshClient.packetTunnelChannel.Close()
 	}
 	sshClient.packetTunnelChannel = channel
+	sshClient.totalPacketTunnelChannelCount += 1
 	sshClient.Unlock()
 }
 
-// setUDPChannel sets the single UDP channel for this sshClient.
-// Each sshClient may have only one concurrent UDP channel. Each
-// UDP channel multiplexes many UDP port forwards via the udpgw
-// protocol. Any existing UDP channel is closed.
-func (sshClient *sshClient) setUDPChannel(channel ssh.Channel) {
+// setUdpgwChannelHandler sets the single udpgw channel handler for this
+// sshClient. Each sshClient may have only one concurrent udpgw
+// channel/handler. Each udpgw channel multiplexes many UDP port forwards via
+// the udpgw protocol. Any existing udpgw channel/handler is closed.
+func (sshClient *sshClient) setUdpgwChannelHandler(udpgwChannelHandler *udpgwPortForwardMultiplexer) {
 	sshClient.Lock()
-	if sshClient.udpChannel != nil {
-		sshClient.udpChannel.Close()
+	if sshClient.udpgwChannelHandler != nil {
+		sshClient.udpgwChannelHandler.stop()
 	}
-	sshClient.udpChannel = channel
+	sshClient.udpgwChannelHandler = udpgwChannelHandler
+	sshClient.totalUdpgwChannelCount += 1
 	sshClient.Unlock()
 }
 
@@ -2616,6 +2620,8 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
 	// sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
 	logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
 	logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
+	logFields["total_udpgw_channel_count"] = sshClient.totalUdpgwChannelCount
+	logFields["total_packet_tunnel_channel_count"] = sshClient.totalPacketTunnelChannelCount
 
 	logFields["pre_handshake_random_stream_count"] = sshClient.preHandshakeRandomStreamMetrics.count
 	logFields["pre_handshake_random_stream_upstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.upstreamBytes

+ 99 - 54
psiphon/server/udp.go

@@ -25,7 +25,6 @@ import (
 	"fmt"
 	"io"
 	"net"
-	"runtime/debug"
 	"sync"
 	"sync/atomic"
 
@@ -35,7 +34,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 )
 
-// handleUDPChannel implements UDP port forwarding. A single UDP
+// handleUdpgwChannel implements UDP port forwarding. A single UDP
 // SSH channel follows the udpgw protocol, which multiplexes many
 // UDP port forwards.
 //
@@ -43,10 +42,10 @@ import (
 // Copyright (c) 2009, Ambroz Bizjak <ambrop7@gmail.com>
 // https://github.com/ambrop72/badvpn
 //
-func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
+func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) {
 
 	// Accept this channel immediately. This channel will replace any
-	// previously existing UDP channel for this client.
+	// previously existing udpgw channel for this client.
 
 	sshChannel, requests, err := newChannel.Accept()
 	if err != nil {
@@ -58,33 +57,75 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 	go ssh.DiscardRequests(requests)
 	defer sshChannel.Close()
 
-	sshClient.setUDPChannel(sshChannel)
-
-	multiplexer := &udpPortForwardMultiplexer{
+	multiplexer := &udpgwPortForwardMultiplexer{
 		sshClient:      sshClient,
 		sshChannel:     sshChannel,
-		portForwards:   make(map[uint16]*udpPortForward),
+		portForwards:   make(map[uint16]*udpgwPortForward),
 		portForwardLRU: common.NewLRUConns(),
 		relayWaitGroup: new(sync.WaitGroup),
+		runWaitGroup:   new(sync.WaitGroup),
 	}
+
+	multiplexer.runWaitGroup.Add(1)
+
+	// setUdpgwChannelHandler will close any existing
+	// udpgwPortForwardMultiplexer, waiting for all run/relayDownstream
+	// goroutines to first terminate and all UDP socket resources to be
+	// cleaned up.
+	//
+	// This synchronous shutdown also ensures that the
+	// concurrentPortForwardCount is reduced to 0 before installing the new
+	// udpgwPortForwardMultiplexer and its LRU object. If the older handler
+	// were to dangle with open port forwards, and concurrentPortForwardCount
+	// were to hit the max, the wrong LRU, the new one, would be used to
+	// close the LRU port forward.
+	//
+	// Call setUdpgwChannelHandler only after runWaitGroup is initialized, to
+	// ensure runWaitGroup.Wait() cannot be invoked (by some subsequent new
+	// udpgw channel) before it is initialized.
+
+	sshClient.setUdpgwChannelHandler(multiplexer)
+
 	multiplexer.run()
+	multiplexer.runWaitGroup.Done()
 }
 
-type udpPortForwardMultiplexer struct {
+type udpgwPortForwardMultiplexer struct {
 	sshClient            *sshClient
 	sshChannelWriteMutex sync.Mutex
 	sshChannel           ssh.Channel
 	portForwardsMutex    sync.Mutex
-	portForwards         map[uint16]*udpPortForward
+	portForwards         map[uint16]*udpgwPortForward
 	portForwardLRU       *common.LRUConns
 	relayWaitGroup       *sync.WaitGroup
+	runWaitGroup         *sync.WaitGroup
 }
 
-func (mux *udpPortForwardMultiplexer) run() {
+func (mux *udpgwPortForwardMultiplexer) stop() {
 
-	// In a loop, read udpgw messages from the client to this channel. Each message is
-	// a UDP packet to send upstream either via a new port forward, or on an existing
-	// port forward.
+	// udpgwPortForwardMultiplexer must be initialized by handleUdpgwChannel.
+	//
+	// stop closes the udpgw SSH channel, which will cause the run goroutine
+	// to exit its message read loop and await closure of all relayDownstream
+	// goroutines. Closing all port forward UDP conns will cause all
+	// relayDownstream to exit.
+
+	_ = mux.sshChannel.Close()
+
+	mux.portForwardsMutex.Lock()
+	for _, portForward := range mux.portForwards {
+		_ = portForward.conn.Close()
+	}
+	mux.portForwardsMutex.Unlock()
+
+	mux.runWaitGroup.Wait()
+}
+
+func (mux *udpgwPortForwardMultiplexer) run() {
+
+	// In a loop, read udpgw messages from the client to this channel. Each
+	// message contains a UDP packet to send upstream either via a new port
+	// forward, or on an existing port forward.
 	//
 	// A goroutine is run to read downstream packets for each UDP port forward. All read
 	// packets are encapsulated in udpgw protocol and sent down the channel to the client.
@@ -92,16 +133,6 @@ func (mux *udpPortForwardMultiplexer) run() {
 	// When the client disconnects or the server shuts down, the channel will close and
 	// readUdpgwMessage will exit with EOF.
 
-	// Recover from and log any unexpected panics caused by udpgw input handling bugs.
-	// Note: this covers the run() goroutine only and not relayDownstream() goroutines.
-	defer func() {
-		if e := recover(); e != nil {
-			err := errors.Tracef(
-				"udpPortForwardMultiplexer panic: %s: %s", e, debug.Stack())
-			log.WithTraceFields(LogFields{"error": err}).Warning("run failed")
-		}
-	}()
-
 	buffer := make([]byte, udpgwProtocolMaxMessageSize)
 	for {
 		// Note: message.packet points to the reusable memory in "buffer".
@@ -119,27 +150,37 @@ func (mux *udpPortForwardMultiplexer) run() {
 		portForward := mux.portForwards[message.connID]
 		mux.portForwardsMutex.Unlock()
 
-		if portForward != nil && message.discardExistingConn {
+		// In the udpgw protocol, an existing port forward is closed when
+		// either the discard flag is set or the remote address has changed.
+
+		if portForward != nil &&
+			(message.discardExistingConn ||
+				!bytes.Equal(portForward.remoteIP, message.remoteIP) ||
+				portForward.remotePort != message.remotePort) {
+
 			// The port forward's goroutine will complete cleanup, including
 			// tallying stats and calling sshClient.closedPortForward.
 			// portForward.conn.Close() will signal this shutdown.
-			// TODO: wait for goroutine to exit before proceeding?
 			portForward.conn.Close()
-			portForward = nil
-		}
 
-		if portForward != nil {
+			// Synchronously await the termination of the relayDownstream
+			// goroutine. This ensures that the previous goroutine won't
+			// invoke removePortForward, with the connID that will be reused
+			// for the new port forward, after this point.
+			//
+			// Limitation: this synchronous shutdown cannot prevent a "wrong
+			// remote address" error on the badvpn udpgw client, which occurs
+			// when the client recycles a port forward (setting discard) but
+			// receives, from the server, a udpgw message containing the old
+			// remote address for the previous port forward with the same
+			// conn ID. That downstream message from the server may be in
+			// flight in the SSH channel when the client discard message arrives.
+			portForward.relayWaitGroup.Wait()
 
-			// Verify that portForward remote address matches latest message
-
-			if !bytes.Equal(portForward.remoteIP, message.remoteIP) ||
-				portForward.remotePort != message.remotePort {
-
-				log.WithTrace().Warning("UDP port forward remote address mismatch")
-				continue
-			}
+			portForward = nil
+		}
 
-		} else {
+		if portForward == nil {
 
 			// Create a new port forward
 
@@ -237,17 +278,18 @@ func (mux *udpPortForwardMultiplexer) run() {
 				continue
 			}
 
-			portForward = &udpPortForward{
-				connID:       message.connID,
-				preambleSize: message.preambleSize,
-				remoteIP:     message.remoteIP,
-				remotePort:   message.remotePort,
-				dialIP:       dialIP,
-				conn:         conn,
-				lruEntry:     lruEntry,
-				bytesUp:      0,
-				bytesDown:    0,
-				mux:          mux,
+			portForward = &udpgwPortForward{
+				connID:         message.connID,
+				preambleSize:   message.preambleSize,
+				remoteIP:       message.remoteIP,
+				remotePort:     message.remotePort,
+				dialIP:         dialIP,
+				conn:           conn,
+				lruEntry:       lruEntry,
+				bytesUp:        0,
+				bytesDown:      0,
+				relayWaitGroup: new(sync.WaitGroup),
+				mux:            mux,
 			}
 
 			if message.forwardDNS {
@@ -258,6 +300,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 			mux.portForwards[portForward.connID] = portForward
 			mux.portForwardsMutex.Unlock()
 
+			portForward.relayWaitGroup.Add(1)
 			mux.relayWaitGroup.Add(1)
 			go portForward.relayDownstream()
 		}
@@ -276,7 +319,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 		atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
 	}
 
-	// Cleanup all UDP port forward workers when exiting
+	// Cleanup all udpgw port forward workers when exiting
 
 	mux.portForwardsMutex.Lock()
 	for _, portForward := range mux.portForwards {
@@ -288,13 +331,13 @@ func (mux *udpPortForwardMultiplexer) run() {
 	mux.relayWaitGroup.Wait()
 }
 
-func (mux *udpPortForwardMultiplexer) removePortForward(connID uint16) {
+func (mux *udpgwPortForwardMultiplexer) removePortForward(connID uint16) {
 	mux.portForwardsMutex.Lock()
 	delete(mux.portForwards, connID)
 	mux.portForwardsMutex.Unlock()
 }
 
-type udpPortForward struct {
+type udpgwPortForward struct {
 	// Note: 64-bit ints used with atomic operations are placed
 	// at the start of struct to ensure 64-bit alignment.
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
@@ -309,10 +352,12 @@ type udpPortForward struct {
 	dialIP            net.IP
 	conn              net.Conn
 	lruEntry          *common.LRUConnsEntry
-	mux               *udpPortForwardMultiplexer
+	relayWaitGroup    *sync.WaitGroup
+	mux               *udpgwPortForwardMultiplexer
 }
 
-func (portForward *udpPortForward) relayDownstream() {
+func (portForward *udpgwPortForward) relayDownstream() {
+	defer portForward.relayWaitGroup.Done()
 	defer portForward.mux.relayWaitGroup.Done()
 
 	// Downstream UDP packets are read into the reusable memory