Просмотр исходного кода

Fixes and cleanup of UDP port forwarding

Rod Hynes 9 лет назад
Родитель
Сommit
ef52ed0775
2 измененных файлов с 148 добавлено и 125 удалено
  1. 5 2
      psiphon/server/config.go
  2. 143 123
      psiphon/server/udpChannel.go

+ 5 - 2
psiphon/server/config.go

@@ -323,9 +323,12 @@ func LoadConfig(configJSONs [][]byte) (*Config, error) {
 	}
 
 	validateNetworkAddress := func(address string) error {
-		_, portStr, err := net.SplitHostPort(config.DNSServerAddress)
+		host, port, err := net.SplitHostPort(address)
+		if err == nil && net.ParseIP(host) == nil {
+			err = errors.New("Host must be an IP address")
+		}
 		if err == nil {
-			_, err = strconv.Atoi(portStr)
+			_, err = strconv.Atoi(port)
 		}
 		return err
 	}

+ 143 - 123
psiphon/server/udpChannel.go

@@ -21,7 +21,6 @@ package server
 
 import (
 	"bytes"
-	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
@@ -84,6 +83,9 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 
 	type udpPortForward struct {
 		connID       uint16
+		preambleSize int
+		remoteIP     []byte
+		remotePort   uint16
 		conn         *net.UDPConn
 		lastActivity int64
 		bytesUp      int64
@@ -96,11 +98,9 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 	buffer := make([]byte, udpgwProtocolMaxMessageSize)
 
 	for {
-		// Note: udpProtocolMessage.packet points to the resuable
-		// memory in "buffer". Each readUdpgwMessage call will overwrite
-		// the last udpProtocolMessage.packet.
-		udpProtocolMessage, err := readUdpgwMessage(
-			sshClient.sshServer.config, fwdChannel, buffer)
+		// Note: message.packet points to the reusable memory in "buffer".
+		// Each readUdpgwMessage call will overwrite the last message.packet.
+		message, err := readUdpgwMessage(fwdChannel, buffer)
 		if err != nil {
 			if err != io.EOF {
 				log.WithContextFields(LogFields{"error": err}).Warning("readUpdgwMessage failed")
@@ -109,10 +109,10 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 		}
 
 		portForwardsMutex.Lock()
-		portForward := portForwards[udpProtocolMessage.connID]
+		portForward := portForwards[message.connID]
 		portForwardsMutex.Unlock()
 
-		if portForward != nil && udpProtocolMessage.discardExistingConn {
+		if portForward != nil && message.discardExistingConn {
 			// The port forward's goroutine will complete cleanup, including
 			// tallying stats and calling sshClient.closedPortForward.
 			// portForward.conn.Close() will signal this shutdown.
@@ -121,10 +121,23 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 			portForward = nil
 		}
 
-		if portForward == nil {
+		if portForward != nil {
+
+			// Verify that portForward remote address matches latest message
+
+			if 0 != bytes.Compare(portForward.remoteIP, message.remoteIP) ||
+				portForward.remotePort != message.remotePort {
+
+				log.WithContext().Warning("UDP port forward remote address mismatch")
+				continue
+			}
+
+		} else {
+
+			// Create a new port forward
 
 			if !sshClient.isPortForwardPermitted(
-				udpProtocolMessage.portToConnect,
+				int(message.remotePort),
 				sshClient.trafficRules.AllowUDPPorts,
 				sshClient.trafficRules.DenyUDPPorts) {
 				// The udpgw protocol has no error response, so
@@ -150,25 +163,38 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 					}
 				}
 				if oldestPortForward != nil {
-					// *** comment: let goro call closePortForward
+					// The port forward's goroutine will complete cleanup
 					oldestPortForward.conn.Close()
 				}
 				portForwardsMutex.Unlock()
 			}
 
+			dialIP := message.remoteIP
+			dialPort := int(message.remotePort)
+
+			// Transparent DNS forwarding
+			if message.forwardDNS && sshClient.sshServer.config.DNSServerAddress != "" {
+				// Note: DNSServerAddress is validated in LoadConfig
+				host, portStr, _ := net.SplitHostPort(
+					sshClient.sshServer.config.DNSServerAddress)
+				dialIP = net.ParseIP(host)
+				dialPort, _ = strconv.Atoi(portStr)
+			}
+
 			// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
-			// TODO: IPv6 support
-			updConn, err := net.Dial(
-				"udp4",
-				fmt.Sprintf("%s:%d", udpProtocolMessage.hostToConnect, udpProtocolMessage.portToConnect))
+			updConn, err := net.DialUDP(
+				"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
 			if err != nil {
 				log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed")
 				continue
 			}
 
-			portForward := &udpPortForward{
-				connID:       udpProtocolMessage.connID,
-				conn:         updConn.(*net.UDPConn),
+			portForward = &udpPortForward{
+				connID:       message.connID,
+				preambleSize: message.preambleSize,
+				remoteIP:     message.remoteIP,
+				remotePort:   message.remotePort,
+				conn:         updConn,
 				lastActivity: time.Now().UnixNano(),
 				bytesUp:      0,
 				bytesDown:    0,
@@ -186,18 +212,19 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 				defer relayWaitGroup.Done()
 
 				// Downstream UDP packets are read into the reusable memory
-				// in "buffer" starting at the offset udpgwProtocolHeaderSize,
-				// leaving enough space to write the udpgw header into the
-				// same buffer and use for writing to the ssh channel.
+				// in "buffer" starting at the offset past the udpgw message
+				// header and address, leaving enough space to write the udpgw
+				// values into the same buffer and use for writing to the ssh
+				// channel.
 				//
 				// Note: there is one downstream buffer per UDP port forward,
 				// while for upstream there is one buffer per client.
 				// TODO: is the buffer size larger than necessary?
 				buffer := make([]byte, udpgwProtocolMaxMessageSize)
-				packetBuffer := buffer[udpgwProtocolHeaderSize:udpgwProtocolMaxMessageSize]
+				packetBuffer := buffer[portForward.preambleSize:udpgwProtocolMaxMessageSize]
 				for {
 					// TODO: if read buffer is too small, excess bytes are discarded?
-					packetSize, _, err := portForward.conn.ReadFrom(packetBuffer)
+					packetSize, err := portForward.conn.Read(packetBuffer)
 					if packetSize > udpgwProtocolMaxPayloadSize {
 						err = fmt.Errorf("unexpected packet size: %d", packetSize)
 					}
@@ -208,9 +235,17 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 						break
 					}
 
-					writeUdpgwHeader(buffer, uint16(packetSize), portForward.connID)
+					err = writeUdpgwPreamble(
+						portForward.preambleSize,
+						portForward.connID,
+						portForward.remoteIP,
+						portForward.remotePort,
+						uint16(packetSize),
+						buffer)
+					if err == nil {
+						_, err = fwdChannel.Write(buffer[0 : portForward.preambleSize+packetSize])
+					}
 
-					_, err = fwdChannel.Write(buffer[0 : udpgwProtocolHeaderSize+packetSize])
 					if err != nil {
 						// Close the channel, which will interrupt the main loop.
 						fwdChannel.Close()
@@ -236,14 +271,14 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 		}
 
 		// Note: assumes UDP writes won't block (https://golang.org/pkg/net/#UDPConn.WriteToUDP)
-		_, err = portForward.conn.WriteTo(udpProtocolMessage.packet, nil)
+		_, err = portForward.conn.Write(message.packet)
 		if err != nil {
 			log.WithContextFields(LogFields{"error": err}).Warning("upstream UDP relay failed")
 			// The port forward's goroutine will complete cleanup
 			portForward.conn.Close()
 		}
 		atomic.StoreInt64(&portForward.lastActivity, time.Now().UnixNano())
-		atomic.AddInt64(&portForward.bytesUp, int64(len(udpProtocolMessage.packet)))
+		atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
 	}
 
 	// Cleanup all UDP port forward workers when exiting
@@ -265,149 +300,134 @@ const (
 	udpgwProtocolFlagDNS       = 1 << 2
 	udpgwProtocolFlagIPv6      = 1 << 3
 
-	udpgwProtocolHeaderSize     = 3
-	udpgwProtocolIPv4AddrSize   = 6
-	udpgwProtocolIPv6AddrSize   = 18
-	udpgwProtocolMaxPayloadSize = 32768
-	udpgwProtocolMaxMessageSize = udpgwProtocolHeaderSize +
-		udpgwProtocolIPv6AddrSize +
-		udpgwProtocolMaxPayloadSize
+	udpgwProtocolMaxPreambleSize = 23
+	udpgwProtocolMaxPayloadSize  = 32768
+	udpgwProtocolMaxMessageSize  = udpgwProtocolMaxPreambleSize + udpgwProtocolMaxPayloadSize
 )
 
-type udpgwHeader struct {
-	Size   uint16
-	Flags  uint8
-	ConnID uint16
-}
-
-type udpgwAddrIPv4 struct {
-	IP   uint32
-	Port uint16
-}
-
-type udpgwAddrIPv6 struct {
-	IP   [16]uint8
-	Port uint16
-}
-
 type udpProtocolMessage struct {
 	connID              uint16
+	preambleSize        int
+	remoteIP            []byte
+	remotePort          uint16
 	discardExistingConn bool
-	hostToConnect       string
-	portToConnect       int
+	forwardDNS          bool
 	packet              []byte
 }
 
 func readUdpgwMessage(
-	config *Config, reader io.Reader, buffer []byte) (*udpProtocolMessage, error) {
+	reader io.Reader, buffer []byte) (*udpProtocolMessage, error) {
+
+	// udpgw message layout:
+	//
+	// | 2 byte size | 3 byte header | 6 or 18 byte address | variable length packet |
 
 	for {
-		// Read udpgwHeader
+		// Read message
 
-		_, err := io.ReadFull(reader, buffer[0:udpgwProtocolHeaderSize])
+		_, err := io.ReadFull(reader, buffer[0:2])
 		if err != nil {
 			return nil, psiphon.ContextError(err)
 		}
 
-		var header udpgwHeader
-		err = binary.Read(
-			bytes.NewReader(buffer[0:udpgwProtocolHeaderSize]), binary.BigEndian, &header)
-		if err != nil {
-			return nil, psiphon.ContextError(err)
-		}
+		size := uint16(buffer[0]) + uint16(buffer[1])<<8
 
-		if int(header.Size) < udpgwProtocolHeaderSize || int(header.Size) > len(buffer) {
+		if int(size) > len(buffer)-2 {
 			return nil, psiphon.ContextError(errors.New("invalid udpgw message size"))
 		}
 
-		_, err = io.ReadFull(reader, buffer[udpgwProtocolHeaderSize:header.Size])
+		_, err = io.ReadFull(reader, buffer[2:2+size])
 		if err != nil {
 			return nil, psiphon.ContextError(err)
 		}
 
+		flags := buffer[2]
+
+		connID := uint16(buffer[3]) + uint16(buffer[4])<<8
+
 		// Ignore udpgw keep-alive messages -- read another message
 
-		if header.Flags&udpgwProtocolFlagKeepalive == udpgwProtocolFlagKeepalive {
+		if flags&udpgwProtocolFlagKeepalive == udpgwProtocolFlagKeepalive {
 			continue
 		}
 
-		// Read udpgwAddrIPv4 or udpgwAddrIPv6
+		// Read address
 
-		var hostToConnect string
-		var portToConnect int
-		var packetOffset int
+		var remoteIP []byte
+		var remotePort uint16
+		var packetStart, packetEnd int
 
-		if header.Flags&udpgwProtocolFlagIPv6 == udpgwProtocolFlagIPv6 {
+		if flags&udpgwProtocolFlagIPv6 == udpgwProtocolFlagIPv6 {
 
-			var addr udpgwAddrIPv6
-			err = binary.Read(
-				bytes.NewReader(
-					buffer[udpgwProtocolHeaderSize:udpgwProtocolHeaderSize+udpgwProtocolIPv6AddrSize]),
-				binary.BigEndian, &addr)
-			if err != nil {
-				return nil, psiphon.ContextError(err)
+			if size < 21 {
+				return nil, psiphon.ContextError(errors.New("invalid udpgw message size"))
 			}
 
-			ip := make(net.IP, 16)
-			copy(ip, addr.IP[:])
-
-			hostToConnect = ip.String()
-			portToConnect = int(addr.Port)
-			packetOffset = udpgwProtocolHeaderSize + udpgwProtocolIPv6AddrSize
+			remoteIP = make([]byte, 16)
+			copy(remoteIP, buffer[5:21])
+			remotePort = uint16(buffer[21]) + uint16(buffer[22])<<8
+			packetStart = 23
+			packetEnd = 23 + int(size) - 2
 
 		} else {
 
-			var addr udpgwAddrIPv4
-			err = binary.Read(
-				bytes.NewReader(
-					buffer[udpgwProtocolHeaderSize:udpgwProtocolHeaderSize+udpgwProtocolIPv4AddrSize]),
-				binary.BigEndian, &addr)
-
-			ip := make(net.IP, 4)
-			binary.BigEndian.PutUint32(ip, addr.IP)
+			if size < 9 {
+				return nil, psiphon.ContextError(errors.New("invalid udpgw message size"))
+			}
 
-			hostToConnect = net.IP(ip).String()
-			portToConnect = int(addr.Port)
-			packetOffset = udpgwProtocolHeaderSize + udpgwProtocolIPv4AddrSize
+			remoteIP = make([]byte, 4)
+			copy(remoteIP, buffer[5:9])
+			remotePort = uint16(buffer[9]) + uint16(buffer[10])<<8
+			packetStart = 11
+			packetEnd = 11 + int(size) - 2
 		}
 
 		// Assemble message
 		// Note: udpProtocolMessage.packet references memory in the input buffer
 
-		udpProtocolMessage := &udpProtocolMessage{
-			connID:              header.ConnID,
-			discardExistingConn: header.Flags&udpgwProtocolFlagRebind == udpgwProtocolFlagRebind,
-			hostToConnect:       hostToConnect,
-			portToConnect:       portToConnect,
-			packet:              buffer[packetOffset : int(header.Size)-packetOffset],
+		message := &udpProtocolMessage{
+			connID:              connID,
+			preambleSize:        packetStart,
+			remoteIP:            remoteIP,
+			remotePort:          remotePort,
+			discardExistingConn: flags&udpgwProtocolFlagRebind == udpgwProtocolFlagRebind,
+			forwardDNS:          flags&udpgwProtocolFlagDNS == udpgwProtocolFlagDNS,
+			packet:              buffer[packetStart:packetEnd],
 		}
 
-		// Transparent DNS forwarding
-
-		if (header.Flags&udpgwProtocolFlagDNS == udpgwProtocolFlagDNS) &&
-			config.DNSServerAddress != "" {
+		return message, nil
+	}
+}
 
-			// Note: DNSServerAddress SplitHostPort is checked in LoadConfig
-			host, portStr, _ := net.SplitHostPort(config.DNSServerAddress)
-			port, _ := strconv.Atoi(portStr)
-			udpProtocolMessage.hostToConnect = host
-			udpProtocolMessage.portToConnect = port
-		}
+func writeUdpgwPreamble(
+	preambleSize int,
+	connID uint16,
+	remoteIP []byte,
+	remotePort uint16,
+	packetSize uint16,
+	buffer []byte) error {
 
-		return udpProtocolMessage, nil
+	if preambleSize != 7+len(remoteIP) {
+		return errors.New("invalid udpgw preamble size")
 	}
-}
 
-func writeUdpgwHeader(
-	buffer []byte, packetSize uint16, connID uint16) {
-	// TODO: write directly into buffer
-	header := make([]byte, 0, udpgwProtocolHeaderSize)
-	binary.Write(
-		bytes.NewBuffer(header),
-		binary.BigEndian,
-		&udpgwHeader{
-			Size:   udpgwProtocolHeaderSize + packetSize,
-			Flags:  0,
-			ConnID: connID})
-	copy(buffer[0:udpgwProtocolHeaderSize], header)
+	size := uint16(preambleSize-2) + packetSize
+
+	// size
+	buffer[0] = byte(size & 0xFF)
+	buffer[1] = byte(size >> 8)
+
+	// flags
+	buffer[2] = 0
+
+	// connID
+	buffer[3] = byte(connID & 0xFF)
+	buffer[4] = byte(connID >> 8)
+
+	// addr
+	copy(buffer[5:5+len(remoteIP)], remoteIP)
+	buffer[5+len(remoteIP)] = byte(remotePort & 0xFF)
+	buffer[6+len(remoteIP)] = byte(remotePort >> 8)
+
+	return nil
 }