|
|
@@ -61,15 +61,34 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
// Accept this channel immediately. This channel will replace any
|
|
|
// previously existing UDP channel for this client.
|
|
|
|
|
|
- fwdChannel, requests, err := newChannel.Accept()
|
|
|
+ sshChannel, requests, err := newChannel.Accept()
|
|
|
if err != nil {
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
|
|
|
return
|
|
|
}
|
|
|
go ssh.DiscardRequests(requests)
|
|
|
- defer fwdChannel.Close()
|
|
|
+ defer sshChannel.Close()
|
|
|
|
|
|
- sshClient.setUDPChannel(fwdChannel)
|
|
|
+ sshClient.setUDPChannel(sshChannel)
|
|
|
+
|
|
|
+ multiplexer := &udpPortForwardMultiplexer{
|
|
|
+ sshClient: sshClient,
|
|
|
+ sshChannel: sshChannel,
|
|
|
+ portForwards: make(map[uint16]*udpPortForward),
|
|
|
+ relayWaitGroup: new(sync.WaitGroup),
|
|
|
+ }
|
|
|
+ multiplexer.run()
|
|
|
+}
|
|
|
+
|
|
|
+type udpPortForwardMultiplexer struct {
|
|
|
+ sshClient *sshClient
|
|
|
+ sshChannel ssh.Channel
|
|
|
+ portForwardsMutex sync.Mutex
|
|
|
+ portForwards map[uint16]*udpPortForward
|
|
|
+ relayWaitGroup *sync.WaitGroup
|
|
|
+}
|
|
|
+
|
|
|
+func (mux *udpPortForwardMultiplexer) run() {
|
|
|
|
|
|
// In a loop, read udpgw messages from the client to this channel. Each message is
|
|
|
// a UDP packet to send upstream either via a new port forward, or on an existing
|
|
|
@@ -81,26 +100,11 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
// When the client disconnects or the server shuts down, the channel will close and
|
|
|
// readUdpgwMessage will exit with EOF.
|
|
|
|
|
|
- type udpPortForward struct {
|
|
|
- connID uint16
|
|
|
- preambleSize int
|
|
|
- remoteIP []byte
|
|
|
- remotePort uint16
|
|
|
- conn *net.UDPConn
|
|
|
- lastActivity int64
|
|
|
- bytesUp int64
|
|
|
- bytesDown int64
|
|
|
- }
|
|
|
-
|
|
|
- var portForwardsMutex sync.Mutex
|
|
|
- portForwards := make(map[uint16]*udpPortForward)
|
|
|
- relayWaitGroup := new(sync.WaitGroup)
|
|
|
buffer := make([]byte, udpgwProtocolMaxMessageSize)
|
|
|
-
|
|
|
for {
|
|
|
// Note: message.packet points to the reusable memory in "buffer".
|
|
|
// Each readUdpgwMessage call will overwrite the last message.packet.
|
|
|
- message, err := readUdpgwMessage(fwdChannel, buffer)
|
|
|
+ message, err := readUdpgwMessage(mux.sshChannel, buffer)
|
|
|
if err != nil {
|
|
|
if err != io.EOF {
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("readUpdgwMessage failed")
|
|
|
@@ -108,9 +112,9 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
break
|
|
|
}
|
|
|
|
|
|
- portForwardsMutex.Lock()
|
|
|
- portForward := portForwards[message.connID]
|
|
|
- portForwardsMutex.Unlock()
|
|
|
+ mux.portForwardsMutex.Lock()
|
|
|
+ portForward := mux.portForwards[message.connID]
|
|
|
+ mux.portForwardsMutex.Unlock()
|
|
|
|
|
|
if portForward != nil && message.discardExistingConn {
|
|
|
// The port forward's goroutine will complete cleanup, including
|
|
|
@@ -136,55 +140,48 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
|
|
|
// Create a new port forward
|
|
|
|
|
|
- if !sshClient.isPortForwardPermitted(
|
|
|
+ if !mux.sshClient.isPortForwardPermitted(
|
|
|
int(message.remotePort),
|
|
|
- sshClient.trafficRules.AllowUDPPorts,
|
|
|
- sshClient.trafficRules.DenyUDPPorts) {
|
|
|
+ mux.sshClient.trafficRules.AllowUDPPorts,
|
|
|
+ mux.sshClient.trafficRules.DenyUDPPorts) {
|
|
|
// The udpgw protocol has no error response, so
|
|
|
// we just discard the message and read another.
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- if sshClient.isPortForwardLimitExceeded(
|
|
|
- sshClient.tcpTrafficState,
|
|
|
- sshClient.trafficRules.MaxUDPPortForwardCount) {
|
|
|
+ mux.sshClient.openedPortForward(mux.sshClient.udpTrafficState)
|
|
|
+ // Note: can't defer sshClient.closedPortForward() here
|
|
|
+
|
|
|
+ // TOCTOU note: important to increment the port forward count (via
|
|
|
+ // openPortForward) _before_ checking isPortForwardLimitExceeded
|
|
|
+ if mux.sshClient.isPortForwardLimitExceeded(
|
|
|
+ mux.sshClient.tcpTrafficState,
|
|
|
+ mux.sshClient.trafficRules.MaxUDPPortForwardCount) {
|
|
|
|
|
|
// When the UDP port forward limit is exceeded, we
|
|
|
- // select the least recently used (red from or written
|
|
|
+ // select the least recently used (read from or written
|
|
|
// to) port forward and discard it.
|
|
|
-
|
|
|
- // TODO: use "container/list" and avoid a linear scan?
|
|
|
- portForwardsMutex.Lock()
|
|
|
- oldestActivity := int64(math.MaxInt64)
|
|
|
- var oldestPortForward *udpPortForward
|
|
|
- for _, nextPortForward := range portForwards {
|
|
|
- if nextPortForward.lastActivity < oldestActivity {
|
|
|
- oldestPortForward = nextPortForward
|
|
|
- }
|
|
|
- }
|
|
|
- if oldestPortForward != nil {
|
|
|
- // The port forward's goroutine will complete cleanup
|
|
|
- oldestPortForward.conn.Close()
|
|
|
- }
|
|
|
- portForwardsMutex.Unlock()
|
|
|
+ mux.closeLeastRecentlyUsedPortForward()
|
|
|
}
|
|
|
|
|
|
- dialIP := message.remoteIP
|
|
|
+ dialIP := net.IP(message.remoteIP)
|
|
|
dialPort := int(message.remotePort)
|
|
|
|
|
|
// Transparent DNS forwarding
|
|
|
- if message.forwardDNS && sshClient.sshServer.config.DNSServerAddress != "" {
|
|
|
- // Note: DNSServerAddress is validated in LoadConfig
|
|
|
- host, portStr, _ := net.SplitHostPort(
|
|
|
- sshClient.sshServer.config.DNSServerAddress)
|
|
|
- dialIP = net.ParseIP(host)
|
|
|
- dialPort, _ = strconv.Atoi(portStr)
|
|
|
+ if message.forwardDNS {
|
|
|
+ dialIP, dialPort = mux.transparentDNSAddress(dialIP, dialPort)
|
|
|
}
|
|
|
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{
|
|
|
+ "remoteAddr": fmt.Sprintf("%s:%d", dialIP.String(), dialPort),
|
|
|
+ "connID": message.connID}).Debug("dialing")
|
|
|
+
|
|
|
// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
|
|
|
updConn, err := net.DialUDP(
|
|
|
"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
|
|
|
if err != nil {
|
|
|
+ mux.sshClient.closedPortForward(mux.sshClient.udpTrafficState, 0, 0)
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed")
|
|
|
continue
|
|
|
}
|
|
|
@@ -198,76 +195,17 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
lastActivity: time.Now().UnixNano(),
|
|
|
bytesUp: 0,
|
|
|
bytesDown: 0,
|
|
|
+ mux: mux,
|
|
|
}
|
|
|
- portForwardsMutex.Lock()
|
|
|
- portForwards[portForward.connID] = portForward
|
|
|
- portForwardsMutex.Unlock()
|
|
|
+ mux.portForwardsMutex.Lock()
|
|
|
+ mux.portForwards[portForward.connID] = portForward
|
|
|
+ mux.portForwardsMutex.Unlock()
|
|
|
|
|
|
// TODO: timeout inactive UDP port forwards
|
|
|
|
|
|
- sshClient.establishedPortForward(sshClient.udpTrafficState)
|
|
|
-
|
|
|
- relayWaitGroup.Add(1)
|
|
|
- go func(portForward *udpPortForward) {
|
|
|
- defer relayWaitGroup.Done()
|
|
|
-
|
|
|
- // Downstream UDP packets are read into the reusable memory
|
|
|
- // in "buffer" starting at the offset past the udpgw message
|
|
|
- // header and address, leaving enough space to write the udpgw
|
|
|
- // values into the same buffer and use for writing to the ssh
|
|
|
- // channel.
|
|
|
- //
|
|
|
- // Note: there is one downstream buffer per UDP port forward,
|
|
|
- // while for upstream there is one buffer per client.
|
|
|
- // TODO: is the buffer size larger than necessary?
|
|
|
- buffer := make([]byte, udpgwProtocolMaxMessageSize)
|
|
|
- packetBuffer := buffer[portForward.preambleSize:udpgwProtocolMaxMessageSize]
|
|
|
- for {
|
|
|
- // TODO: if read buffer is too small, excess bytes are discarded?
|
|
|
- packetSize, err := portForward.conn.Read(packetBuffer)
|
|
|
- if packetSize > udpgwProtocolMaxPayloadSize {
|
|
|
- err = fmt.Errorf("unexpected packet size: %d", packetSize)
|
|
|
- }
|
|
|
- if err != nil {
|
|
|
- if err != io.EOF {
|
|
|
- log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
|
|
|
- }
|
|
|
- break
|
|
|
- }
|
|
|
-
|
|
|
- err = writeUdpgwPreamble(
|
|
|
- portForward.preambleSize,
|
|
|
- portForward.connID,
|
|
|
- portForward.remoteIP,
|
|
|
- portForward.remotePort,
|
|
|
- uint16(packetSize),
|
|
|
- buffer)
|
|
|
- if err == nil {
|
|
|
- _, err = fwdChannel.Write(buffer[0 : portForward.preambleSize+packetSize])
|
|
|
- }
|
|
|
-
|
|
|
- if err != nil {
|
|
|
- // Close the channel, which will interrupt the main loop.
|
|
|
- fwdChannel.Close()
|
|
|
- log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
|
|
|
- break
|
|
|
- }
|
|
|
-
|
|
|
- atomic.StoreInt64(&portForward.lastActivity, time.Now().UnixNano())
|
|
|
- atomic.AddInt64(&portForward.bytesDown, int64(packetSize))
|
|
|
- }
|
|
|
-
|
|
|
- portForwardsMutex.Lock()
|
|
|
- delete(portForwards, portForward.connID)
|
|
|
- portForwardsMutex.Unlock()
|
|
|
-
|
|
|
- portForward.conn.Close()
|
|
|
-
|
|
|
- bytesUp := atomic.LoadInt64(&portForward.bytesUp)
|
|
|
- bytesDown := atomic.LoadInt64(&portForward.bytesDown)
|
|
|
- sshClient.closedPortForward(sshClient.udpTrafficState, bytesUp, bytesDown)
|
|
|
-
|
|
|
- }(portForward)
|
|
|
+ // relayDownstream will call sshClient.closedPortForward()
|
|
|
+ mux.relayWaitGroup.Add(1)
|
|
|
+ go portForward.relayDownstream()
|
|
|
}
|
|
|
|
|
|
// Note: assumes UDP writes won't block (https://golang.org/pkg/net/#UDPConn.WriteToUDP)
|
|
|
@@ -283,14 +221,129 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
|
|
|
|
|
|
// Cleanup all UDP port forward workers when exiting
|
|
|
|
|
|
- portForwardsMutex.Lock()
|
|
|
- for _, portForward := range portForwards {
|
|
|
+ mux.portForwardsMutex.Lock()
|
|
|
+ for _, portForward := range mux.portForwards {
|
|
|
// The port forward's goroutine will complete cleanup
|
|
|
portForward.conn.Close()
|
|
|
}
|
|
|
- portForwardsMutex.Unlock()
|
|
|
+ mux.portForwardsMutex.Unlock()
|
|
|
+
|
|
|
+ mux.relayWaitGroup.Wait()
|
|
|
+}
|
|
|
+
|
|
|
+func (mux *udpPortForwardMultiplexer) closeLeastRecentlyUsedPortForward() {
|
|
|
+ // TODO: use "container/list" and avoid a linear scan?
|
|
|
+ mux.portForwardsMutex.Lock()
|
|
|
+ oldestActivity := int64(math.MaxInt64)
|
|
|
+ var oldestPortForward *udpPortForward
|
|
|
+ for _, nextPortForward := range mux.portForwards {
|
|
|
+ if nextPortForward.lastActivity < oldestActivity {
|
|
|
+ oldestPortForward = nextPortForward
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if oldestPortForward != nil {
|
|
|
+ // The port forward's goroutine will complete cleanup
|
|
|
+ oldestPortForward.conn.Close()
|
|
|
+ }
|
|
|
+ mux.portForwardsMutex.Unlock()
|
|
|
+}
|
|
|
+
|
|
|
+func (mux *udpPortForwardMultiplexer) transparentDNSAddress(
|
|
|
+ dialIP net.IP, dialPort int) (net.IP, int) {
|
|
|
+
|
|
|
+ if mux.sshClient.sshServer.config.DNSServerAddress != "" {
|
|
|
+ // Note: DNSServerAddress is validated in LoadConfig
|
|
|
+ host, portStr, _ := net.SplitHostPort(
|
|
|
+ mux.sshClient.sshServer.config.DNSServerAddress)
|
|
|
+ dialIP = net.ParseIP(host)
|
|
|
+ dialPort, _ = strconv.Atoi(portStr)
|
|
|
+ }
|
|
|
+ return dialIP, dialPort
|
|
|
+}
|
|
|
+
|
|
|
+func (mux *udpPortForwardMultiplexer) removePortForward(connID uint16) {
|
|
|
+ mux.portForwardsMutex.Lock()
|
|
|
+ delete(mux.portForwards, connID)
|
|
|
+ mux.portForwardsMutex.Unlock()
|
|
|
+}
|
|
|
+
|
|
|
+type udpPortForward struct {
|
|
|
+ connID uint16
|
|
|
+ preambleSize int
|
|
|
+ remoteIP []byte
|
|
|
+ remotePort uint16
|
|
|
+ conn *net.UDPConn
|
|
|
+ lastActivity int64
|
|
|
+ bytesUp int64
|
|
|
+ bytesDown int64
|
|
|
+ mux *udpPortForwardMultiplexer
|
|
|
+}
|
|
|
+
|
|
|
+func (portForward *udpPortForward) relayDownstream() {
|
|
|
+ defer portForward.mux.relayWaitGroup.Done()
|
|
|
+
|
|
|
+ // Downstream UDP packets are read into the reusable memory
|
|
|
+ // in "buffer" starting at the offset past the udpgw message
|
|
|
+ // header and address, leaving enough space to write the udpgw
|
|
|
+ // values into the same buffer and use for writing to the ssh
|
|
|
+ // channel.
|
|
|
+ //
|
|
|
+ // Note: there is one downstream buffer per UDP port forward,
|
|
|
+ // while for upstream there is one buffer per client.
|
|
|
+ // TODO: is the buffer size larger than necessary?
|
|
|
+ buffer := make([]byte, udpgwProtocolMaxMessageSize)
|
|
|
+ packetBuffer := buffer[portForward.preambleSize:udpgwProtocolMaxMessageSize]
|
|
|
+ for {
|
|
|
+ // TODO: if read buffer is too small, excess bytes are discarded?
|
|
|
+ packetSize, err := portForward.conn.Read(packetBuffer)
|
|
|
+ if packetSize > udpgwProtocolMaxPayloadSize {
|
|
|
+ err = fmt.Errorf("unexpected packet size: %d", packetSize)
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ if err != io.EOF {
|
|
|
+ log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
|
|
|
+ }
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ err = writeUdpgwPreamble(
|
|
|
+ portForward.preambleSize,
|
|
|
+ portForward.connID,
|
|
|
+ portForward.remoteIP,
|
|
|
+ portForward.remotePort,
|
|
|
+ uint16(packetSize),
|
|
|
+ buffer)
|
|
|
+ if err == nil {
|
|
|
+ _, err = portForward.mux.sshChannel.Write(buffer[0 : portForward.preambleSize+packetSize])
|
|
|
+ }
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ // Close the channel, which will interrupt the main loop.
|
|
|
+ portForward.mux.sshChannel.Close()
|
|
|
+ log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ atomic.StoreInt64(&portForward.lastActivity, time.Now().UnixNano())
|
|
|
+ atomic.AddInt64(&portForward.bytesDown, int64(packetSize))
|
|
|
+ }
|
|
|
+
|
|
|
+ portForward.mux.removePortForward(portForward.connID)
|
|
|
+
|
|
|
+ portForward.conn.Close()
|
|
|
+
|
|
|
+ bytesUp := atomic.LoadInt64(&portForward.bytesUp)
|
|
|
+ bytesDown := atomic.LoadInt64(&portForward.bytesDown)
|
|
|
+ portForward.mux.sshClient.closedPortForward(
|
|
|
+ portForward.mux.sshClient.udpTrafficState, bytesUp, bytesDown)
|
|
|
|
|
|
- relayWaitGroup.Wait()
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{
|
|
|
+ "remoteAddr": fmt.Sprintf("%s:%d",
|
|
|
+ net.IP(portForward.remoteIP).String(), portForward.remotePort),
|
|
|
+ "bytesUp": bytesUp,
|
|
|
+ "bytesDown": bytesDown,
|
|
|
+ "connID": portForward.connID}).Debug("exiting")
|
|
|
}
|
|
|
|
|
|
// TODO: express and/or calculate udpgwProtocolMaxPayloadSize as function of MTU?
|