Explorar el Código

Timeout idle UDP port forwards

Rod Hynes hace 9 años
padre
commit
0e648ccfde
Se han modificado 1 ficheros con 20 adiciones y 6 borrados
  1. 20 6
      psiphon/server/udp.go

+ 20 - 6
psiphon/server/udp.go

@@ -178,7 +178,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 					"connID":     message.connID}).Debug("dialing")
 					"connID":     message.connID}).Debug("dialing")
 
 
 			// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
 			// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
-			updConn, err := net.DialUDP(
+			udpConn, err := net.DialUDP(
 				"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
 				"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
 			if err != nil {
 			if err != nil {
 				mux.sshClient.closedPortForward(mux.sshClient.udpTrafficState, 0, 0)
 				mux.sshClient.closedPortForward(mux.sshClient.udpTrafficState, 0, 0)
@@ -186,12 +186,27 @@ func (mux *udpPortForwardMultiplexer) run() {
 				continue
 				continue
 			}
 			}
 
 
+			// When idle port forward traffic rules are in place, wrap updConn
+			// in an IdleTimeoutConn configured to reset idle on writes as well
+			// as reads. This ensures the port forward idle timeout only happens
+			// when both upstream and downstream directions are are idle.
+
+			var conn net.Conn
+			if mux.sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 {
+				conn = psiphon.NewIdleTimeoutConn(
+					udpConn,
+					time.Duration(mux.sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond,
+					true)
+			} else {
+				conn = udpConn
+			}
+
 			portForward = &udpPortForward{
 			portForward = &udpPortForward{
 				connID:       message.connID,
 				connID:       message.connID,
 				preambleSize: message.preambleSize,
 				preambleSize: message.preambleSize,
 				remoteIP:     message.remoteIP,
 				remoteIP:     message.remoteIP,
 				remotePort:   message.remotePort,
 				remotePort:   message.remotePort,
-				conn:         updConn,
+				conn:         conn,
 				lastActivity: time.Now().UnixNano(),
 				lastActivity: time.Now().UnixNano(),
 				bytesUp:      0,
 				bytesUp:      0,
 				bytesDown:    0,
 				bytesDown:    0,
@@ -201,8 +216,6 @@ func (mux *udpPortForwardMultiplexer) run() {
 			mux.portForwards[portForward.connID] = portForward
 			mux.portForwards[portForward.connID] = portForward
 			mux.portForwardsMutex.Unlock()
 			mux.portForwardsMutex.Unlock()
 
 
-			// TODO: timeout inactive UDP port forwards
-
 			// relayDownstream will call sshClient.closedPortForward()
 			// relayDownstream will call sshClient.closedPortForward()
 			mux.relayWaitGroup.Add(1)
 			mux.relayWaitGroup.Add(1)
 			go portForward.relayDownstream()
 			go portForward.relayDownstream()
@@ -233,7 +246,8 @@ func (mux *udpPortForwardMultiplexer) run() {
 }
 }
 
 
 func (mux *udpPortForwardMultiplexer) closeLeastRecentlyUsedPortForward() {
 func (mux *udpPortForwardMultiplexer) closeLeastRecentlyUsedPortForward() {
-	// TODO: use "container/list" and avoid a linear scan?
+	// TODO: use "container/list" and avoid a linear scan? However,
+	// move-to-front on each read/write would require mutex lock?
 	mux.portForwardsMutex.Lock()
 	mux.portForwardsMutex.Lock()
 	oldestActivity := int64(math.MaxInt64)
 	oldestActivity := int64(math.MaxInt64)
 	var oldestPortForward *udpPortForward
 	var oldestPortForward *udpPortForward
@@ -273,7 +287,7 @@ type udpPortForward struct {
 	preambleSize int
 	preambleSize int
 	remoteIP     []byte
 	remoteIP     []byte
 	remotePort   uint16
 	remotePort   uint16
-	conn         *net.UDPConn
+	conn         net.Conn
 	lastActivity int64
 	lastActivity int64
 	bytesUp      int64
 	bytesUp      int64
 	bytesDown    int64
 	bytesDown    int64