|
|
@@ -21,7 +21,6 @@ package server
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
- "encoding/binary"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
@@ -84,6 +83,9 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
|
|
|
type udpPortForward struct {
|
|
|
connID uint16
|
|
|
+ preambleSize int
|
|
|
+ remoteIP []byte
|
|
|
+ remotePort uint16
|
|
|
conn *net.UDPConn
|
|
|
lastActivity int64
|
|
|
bytesUp int64
|
|
|
@@ -96,11 +98,9 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
buffer := make([]byte, udpgwProtocolMaxMessageSize)
|
|
|
|
|
|
for {
|
|
|
- // Note: udpProtocolMessage.packet points to the resuable
|
|
|
- // memory in "buffer". Each readUdpgwMessage call will overwrite
|
|
|
- // the last udpProtocolMessage.packet.
|
|
|
- udpProtocolMessage, err := readUdpgwMessage(
|
|
|
- sshClient.sshServer.config, fwdChannel, buffer)
|
|
|
+ // Note: message.packet points to the reusable memory in "buffer".
|
|
|
+ // Each readUdpgwMessage call will overwrite the last message.packet.
|
|
|
+ message, err := readUdpgwMessage(fwdChannel, buffer)
|
|
|
if err != nil {
|
|
|
if err != io.EOF {
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("readUpdgwMessage failed")
|
|
|
@@ -109,10 +109,10 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
}
|
|
|
|
|
|
portForwardsMutex.Lock()
|
|
|
- portForward := portForwards[udpProtocolMessage.connID]
|
|
|
+ portForward := portForwards[message.connID]
|
|
|
portForwardsMutex.Unlock()
|
|
|
|
|
|
- if portForward != nil && udpProtocolMessage.discardExistingConn {
|
|
|
+ if portForward != nil && message.discardExistingConn {
|
|
|
// The port forward's goroutine will complete cleanup, including
|
|
|
// tallying stats and calling sshClient.closedPortForward.
|
|
|
// portForward.conn.Close() will signal this shutdown.
|
|
|
@@ -121,10 +121,23 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
portForward = nil
|
|
|
}
|
|
|
|
|
|
- if portForward == nil {
|
|
|
+ if portForward != nil {
|
|
|
+
|
|
|
+ // Verify that portForward remote address matches latest message
|
|
|
+
|
|
|
+ if 0 != bytes.Compare(portForward.remoteIP, message.remoteIP) ||
|
|
|
+ portForward.remotePort != message.remotePort {
|
|
|
+
|
|
|
+ log.WithContext().Warning("UDP port forward remote address mismatch")
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ } else {
|
|
|
+
|
|
|
+ // Create a new port forward
|
|
|
|
|
|
if !sshClient.isPortForwardPermitted(
|
|
|
- udpProtocolMessage.portToConnect,
|
|
|
+ int(message.remotePort),
|
|
|
sshClient.trafficRules.AllowUDPPorts,
|
|
|
sshClient.trafficRules.DenyUDPPorts) {
|
|
|
// The udpgw protocol has no error response, so
|
|
|
@@ -150,25 +163,38 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
}
|
|
|
}
|
|
|
if oldestPortForward != nil {
|
|
|
- // *** comment: let goro call closePortForward
|
|
|
+ // The port forward's goroutine will complete cleanup
|
|
|
oldestPortForward.conn.Close()
|
|
|
}
|
|
|
portForwardsMutex.Unlock()
|
|
|
}
|
|
|
|
|
|
+ dialIP := message.remoteIP
|
|
|
+ dialPort := int(message.remotePort)
|
|
|
+
|
|
|
+ // Transparent DNS forwarding
|
|
|
+ if message.forwardDNS && sshClient.sshServer.config.DNSServerAddress != "" {
|
|
|
+ // Note: DNSServerAddress is validated in LoadConfig
|
|
|
+ host, portStr, _ := net.SplitHostPort(
|
|
|
+ sshClient.sshServer.config.DNSServerAddress)
|
|
|
+ dialIP = net.ParseIP(host)
|
|
|
+ dialPort, _ = strconv.Atoi(portStr)
|
|
|
+ }
|
|
|
+
|
|
|
// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
|
|
|
- // TODO: IPv6 support
|
|
|
- updConn, err := net.Dial(
|
|
|
- "udp4",
|
|
|
- fmt.Sprintf("%s:%d", udpProtocolMessage.hostToConnect, udpProtocolMessage.portToConnect))
|
|
|
+ updConn, err := net.DialUDP(
|
|
|
+ "udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
|
|
|
if err != nil {
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed")
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- portForward := &udpPortForward{
|
|
|
- connID: udpProtocolMessage.connID,
|
|
|
- conn: updConn.(*net.UDPConn),
|
|
|
+ portForward = &udpPortForward{
|
|
|
+ connID: message.connID,
|
|
|
+ preambleSize: message.preambleSize,
|
|
|
+ remoteIP: message.remoteIP,
|
|
|
+ remotePort: message.remotePort,
|
|
|
+ conn: updConn,
|
|
|
lastActivity: time.Now().UnixNano(),
|
|
|
bytesUp: 0,
|
|
|
bytesDown: 0,
|
|
|
@@ -186,18 +212,19 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
defer relayWaitGroup.Done()
|
|
|
|
|
|
// Downstream UDP packets are read into the reusable memory
|
|
|
- // in "buffer" starting at the offset udpgwProtocolHeaderSize,
|
|
|
- // leaving enough space to write the udpgw header into the
|
|
|
- // same buffer and use for writing to the ssh channel.
|
|
|
+ // in "buffer" starting at the offset past the udpgw message
|
|
|
+ // header and address, leaving enough space to write the udpgw
|
|
|
+ // values into the same buffer and use for writing to the ssh
|
|
|
+ // channel.
|
|
|
//
|
|
|
// Note: there is one downstream buffer per UDP port forward,
|
|
|
// while for upstream there is one buffer per client.
|
|
|
// TODO: is the buffer size larger than necessary?
|
|
|
buffer := make([]byte, udpgwProtocolMaxMessageSize)
|
|
|
- packetBuffer := buffer[udpgwProtocolHeaderSize:udpgwProtocolMaxMessageSize]
|
|
|
+ packetBuffer := buffer[portForward.preambleSize:udpgwProtocolMaxMessageSize]
|
|
|
for {
|
|
|
// TODO: if read buffer is too small, excess bytes are discarded?
|
|
|
- packetSize, _, err := portForward.conn.ReadFrom(packetBuffer)
|
|
|
+ packetSize, err := portForward.conn.Read(packetBuffer)
|
|
|
if packetSize > udpgwProtocolMaxPayloadSize {
|
|
|
err = fmt.Errorf("unexpected packet size: %d", packetSize)
|
|
|
}
|
|
|
@@ -208,9 +235,17 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
break
|
|
|
}
|
|
|
|
|
|
- writeUdpgwHeader(buffer, uint16(packetSize), portForward.connID)
|
|
|
+ err = writeUdpgwPreamble(
|
|
|
+ portForward.preambleSize,
|
|
|
+ portForward.connID,
|
|
|
+ portForward.remoteIP,
|
|
|
+ portForward.remotePort,
|
|
|
+ uint16(packetSize),
|
|
|
+ buffer)
|
|
|
+ if err == nil {
|
|
|
+ _, err = fwdChannel.Write(buffer[0 : portForward.preambleSize+packetSize])
|
|
|
+ }
|
|
|
|
|
|
- _, err = fwdChannel.Write(buffer[0 : udpgwProtocolHeaderSize+packetSize])
|
|
|
if err != nil {
|
|
|
// Close the channel, which will interrupt the main loop.
|
|
|
fwdChannel.Close()
|
|
|
@@ -236,14 +271,14 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
}
|
|
|
|
|
|
// Note: assumes UDP writes won't block (https://golang.org/pkg/net/#UDPConn.WriteToUDP)
|
|
|
- _, err = portForward.conn.WriteTo(udpProtocolMessage.packet, nil)
|
|
|
+ _, err = portForward.conn.Write(message.packet)
|
|
|
if err != nil {
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("upstream UDP relay failed")
|
|
|
// The port forward's goroutine will complete cleanup
|
|
|
portForward.conn.Close()
|
|
|
}
|
|
|
atomic.StoreInt64(&portForward.lastActivity, time.Now().UnixNano())
|
|
|
- atomic.AddInt64(&portForward.bytesUp, int64(len(udpProtocolMessage.packet)))
|
|
|
+ atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
|
|
|
}
|
|
|
|
|
|
// Cleanup all UDP port forward workers when exiting
|
|
|
@@ -265,149 +300,134 @@ const (
|
|
|
udpgwProtocolFlagDNS = 1 << 2
|
|
|
udpgwProtocolFlagIPv6 = 1 << 3
|
|
|
|
|
|
- udpgwProtocolHeaderSize = 3
|
|
|
- udpgwProtocolIPv4AddrSize = 6
|
|
|
- udpgwProtocolIPv6AddrSize = 18
|
|
|
- udpgwProtocolMaxPayloadSize = 32768
|
|
|
- udpgwProtocolMaxMessageSize = udpgwProtocolHeaderSize +
|
|
|
- udpgwProtocolIPv6AddrSize +
|
|
|
- udpgwProtocolMaxPayloadSize
|
|
|
+ udpgwProtocolMaxPreambleSize = 23
|
|
|
+ udpgwProtocolMaxPayloadSize = 32768
|
|
|
+ udpgwProtocolMaxMessageSize = udpgwProtocolMaxPreambleSize + udpgwProtocolMaxPayloadSize
|
|
|
)
|
|
|
|
|
|
-type udpgwHeader struct {
|
|
|
- Size uint16
|
|
|
- Flags uint8
|
|
|
- ConnID uint16
|
|
|
-}
|
|
|
-
|
|
|
-type udpgwAddrIPv4 struct {
|
|
|
- IP uint32
|
|
|
- Port uint16
|
|
|
-}
|
|
|
-
|
|
|
-type udpgwAddrIPv6 struct {
|
|
|
- IP [16]uint8
|
|
|
- Port uint16
|
|
|
-}
|
|
|
-
|
|
|
type udpProtocolMessage struct {
|
|
|
connID uint16
|
|
|
+ preambleSize int
|
|
|
+ remoteIP []byte
|
|
|
+ remotePort uint16
|
|
|
discardExistingConn bool
|
|
|
- hostToConnect string
|
|
|
- portToConnect int
|
|
|
+ forwardDNS bool
|
|
|
packet []byte
|
|
|
}
|
|
|
|
|
|
func readUdpgwMessage(
|
|
|
- config *Config, reader io.Reader, buffer []byte) (*udpProtocolMessage, error) {
|
|
|
+ reader io.Reader, buffer []byte) (*udpProtocolMessage, error) {
|
|
|
+
|
|
|
+ // udpgw message layout:
|
|
|
+ //
|
|
|
+ // | 2 byte size | 3 byte header | 6 or 18 byte address | variable length packet |
|
|
|
|
|
|
for {
|
|
|
- // Read udpgwHeader
|
|
|
+ // Read message
|
|
|
|
|
|
- _, err := io.ReadFull(reader, buffer[0:udpgwProtocolHeaderSize])
|
|
|
+ _, err := io.ReadFull(reader, buffer[0:2])
|
|
|
if err != nil {
|
|
|
return nil, psiphon.ContextError(err)
|
|
|
}
|
|
|
|
|
|
- var header udpgwHeader
|
|
|
- err = binary.Read(
|
|
|
- bytes.NewReader(buffer[0:udpgwProtocolHeaderSize]), binary.BigEndian, &header)
|
|
|
- if err != nil {
|
|
|
- return nil, psiphon.ContextError(err)
|
|
|
- }
|
|
|
+ size := uint16(buffer[0]) + uint16(buffer[1])<<8
|
|
|
|
|
|
- if int(header.Size) < udpgwProtocolHeaderSize || int(header.Size) > len(buffer) {
|
|
|
+ if int(size) > len(buffer)-2 {
|
|
|
return nil, psiphon.ContextError(errors.New("invalid udpgw message size"))
|
|
|
}
|
|
|
|
|
|
- _, err = io.ReadFull(reader, buffer[udpgwProtocolHeaderSize:header.Size])
|
|
|
+ _, err = io.ReadFull(reader, buffer[2:2+size])
|
|
|
if err != nil {
|
|
|
return nil, psiphon.ContextError(err)
|
|
|
}
|
|
|
|
|
|
+ flags := buffer[2]
|
|
|
+
|
|
|
+ connID := uint16(buffer[3]) + uint16(buffer[4])<<8
|
|
|
+
|
|
|
// Ignore udpgw keep-alive messages -- read another message
|
|
|
|
|
|
- if header.Flags&udpgwProtocolFlagKeepalive == udpgwProtocolFlagKeepalive {
|
|
|
+ if flags&udpgwProtocolFlagKeepalive == udpgwProtocolFlagKeepalive {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- // Read udpgwAddrIPv4 or udpgwAddrIPv6
|
|
|
+ // Read address
|
|
|
|
|
|
- var hostToConnect string
|
|
|
- var portToConnect int
|
|
|
- var packetOffset int
|
|
|
+ var remoteIP []byte
|
|
|
+ var remotePort uint16
|
|
|
+ var packetStart, packetEnd int
|
|
|
|
|
|
- if header.Flags&udpgwProtocolFlagIPv6 == udpgwProtocolFlagIPv6 {
|
|
|
+ if flags&udpgwProtocolFlagIPv6 == udpgwProtocolFlagIPv6 {
|
|
|
|
|
|
- var addr udpgwAddrIPv6
|
|
|
- err = binary.Read(
|
|
|
- bytes.NewReader(
|
|
|
- buffer[udpgwProtocolHeaderSize:udpgwProtocolHeaderSize+udpgwProtocolIPv6AddrSize]),
|
|
|
- binary.BigEndian, &addr)
|
|
|
- if err != nil {
|
|
|
- return nil, psiphon.ContextError(err)
|
|
|
+ if size < 21 {
|
|
|
+ return nil, psiphon.ContextError(errors.New("invalid udpgw message size"))
|
|
|
}
|
|
|
|
|
|
- ip := make(net.IP, 16)
|
|
|
- copy(ip, addr.IP[:])
|
|
|
-
|
|
|
- hostToConnect = ip.String()
|
|
|
- portToConnect = int(addr.Port)
|
|
|
- packetOffset = udpgwProtocolHeaderSize + udpgwProtocolIPv6AddrSize
|
|
|
+ remoteIP = make([]byte, 16)
|
|
|
+ copy(remoteIP, buffer[5:21])
|
|
|
+ remotePort = uint16(buffer[21]) + uint16(buffer[22])<<8
|
|
|
+ packetStart = 23
|
|
|
+ packetEnd = 23 + int(size) - 2
|
|
|
|
|
|
} else {
|
|
|
|
|
|
- var addr udpgwAddrIPv4
|
|
|
- err = binary.Read(
|
|
|
- bytes.NewReader(
|
|
|
- buffer[udpgwProtocolHeaderSize:udpgwProtocolHeaderSize+udpgwProtocolIPv4AddrSize]),
|
|
|
- binary.BigEndian, &addr)
|
|
|
-
|
|
|
- ip := make(net.IP, 4)
|
|
|
- binary.BigEndian.PutUint32(ip, addr.IP)
|
|
|
+ if size < 9 {
|
|
|
+ return nil, psiphon.ContextError(errors.New("invalid udpgw message size"))
|
|
|
+ }
|
|
|
|
|
|
- hostToConnect = net.IP(ip).String()
|
|
|
- portToConnect = int(addr.Port)
|
|
|
- packetOffset = udpgwProtocolHeaderSize + udpgwProtocolIPv4AddrSize
|
|
|
+ remoteIP = make([]byte, 4)
|
|
|
+ copy(remoteIP, buffer[5:9])
|
|
|
+ remotePort = uint16(buffer[9]) + uint16(buffer[10])<<8
|
|
|
+ packetStart = 11
|
|
|
+ packetEnd = 11 + int(size) - 2
|
|
|
}
|
|
|
|
|
|
// Assemble message
|
|
|
// Note: udpProtocolMessage.packet references memory in the input buffer
|
|
|
|
|
|
- udpProtocolMessage := &udpProtocolMessage{
|
|
|
- connID: header.ConnID,
|
|
|
- discardExistingConn: header.Flags&udpgwProtocolFlagRebind == udpgwProtocolFlagRebind,
|
|
|
- hostToConnect: hostToConnect,
|
|
|
- portToConnect: portToConnect,
|
|
|
- packet: buffer[packetOffset : int(header.Size)-packetOffset],
|
|
|
+ message := &udpProtocolMessage{
|
|
|
+ connID: connID,
|
|
|
+ preambleSize: packetStart,
|
|
|
+ remoteIP: remoteIP,
|
|
|
+ remotePort: remotePort,
|
|
|
+ discardExistingConn: flags&udpgwProtocolFlagRebind == udpgwProtocolFlagRebind,
|
|
|
+ forwardDNS: flags&udpgwProtocolFlagDNS == udpgwProtocolFlagDNS,
|
|
|
+ packet: buffer[packetStart:packetEnd],
|
|
|
}
|
|
|
|
|
|
- // Transparent DNS forwarding
|
|
|
-
|
|
|
- if (header.Flags&udpgwProtocolFlagDNS == udpgwProtocolFlagDNS) &&
|
|
|
- config.DNSServerAddress != "" {
|
|
|
+ return message, nil
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
- // Note: DNSServerAddress SplitHostPort is checked in LoadConfig
|
|
|
- host, portStr, _ := net.SplitHostPort(config.DNSServerAddress)
|
|
|
- port, _ := strconv.Atoi(portStr)
|
|
|
- udpProtocolMessage.hostToConnect = host
|
|
|
- udpProtocolMessage.portToConnect = port
|
|
|
- }
|
|
|
+func writeUdpgwPreamble(
|
|
|
+ preambleSize int,
|
|
|
+ connID uint16,
|
|
|
+ remoteIP []byte,
|
|
|
+ remotePort uint16,
|
|
|
+ packetSize uint16,
|
|
|
+ buffer []byte) error {
|
|
|
|
|
|
- return udpProtocolMessage, nil
|
|
|
+ if preambleSize != 7+len(remoteIP) {
|
|
|
+ return errors.New("invalid udpgw preamble size")
|
|
|
}
|
|
|
-}
|
|
|
|
|
|
-func writeUdpgwHeader(
|
|
|
- buffer []byte, packetSize uint16, connID uint16) {
|
|
|
- // TODO: write directly into buffer
|
|
|
- header := make([]byte, 0, udpgwProtocolHeaderSize)
|
|
|
- binary.Write(
|
|
|
- bytes.NewBuffer(header),
|
|
|
- binary.BigEndian,
|
|
|
- &udpgwHeader{
|
|
|
- Size: udpgwProtocolHeaderSize + packetSize,
|
|
|
- Flags: 0,
|
|
|
- ConnID: connID})
|
|
|
- copy(buffer[0:udpgwProtocolHeaderSize], header)
|
|
|
+ size := uint16(preambleSize-2) + packetSize
|
|
|
+
|
|
|
+ // size
|
|
|
+ buffer[0] = byte(size & 0xFF)
|
|
|
+ buffer[1] = byte(size >> 8)
|
|
|
+
|
|
|
+ // flags
|
|
|
+ buffer[2] = 0
|
|
|
+
|
|
|
+ // connID
|
|
|
+ buffer[3] = byte(connID & 0xFF)
|
|
|
+ buffer[4] = byte(connID >> 8)
|
|
|
+
|
|
|
+ // addr
|
|
|
+ copy(buffer[5:5+len(remoteIP)], remoteIP)
|
|
|
+ buffer[5+len(remoteIP)] = byte(remotePort & 0xFF)
|
|
|
+ buffer[6+len(remoteIP)] = byte(remotePort >> 8)
|
|
|
+
|
|
|
+ return nil
|
|
|
}
|