Browse Source

Cleaner implementation of idle port forward timeout

Rod Hynes 10 years ago
parent
commit
2963f2e2ba
2 changed files with 29 additions and 55 deletions
  1. 14 39
      psiphon/net.go
  2. 15 16
      psiphon/server/sshService.go

+ 14 - 39
psiphon/net.go

@@ -392,17 +392,23 @@ func IPAddressFromAddr(addr net.Addr) string {
 }
 
 // IdleTimeoutConn wraps a net.Conn and sets an initial ReadDeadline. The
-// deadline is reset whenever data is received from the connection.
+// deadline is extended whenever data is received from the connection.
+// Optionally, IdleTimeoutConn will also extend the deadline when data is
+// written to the connection.
 type IdleTimeoutConn struct {
 	net.Conn
-	deadline time.Duration
+	deadline     time.Duration
+	resetOnWrite bool
 }
 
-func NewIdleTimeoutConn(conn net.Conn, deadline time.Duration) *IdleTimeoutConn {
+func NewIdleTimeoutConn(
+	conn net.Conn, deadline time.Duration, resetOnWrite bool) *IdleTimeoutConn {
+
 	conn.SetReadDeadline(time.Now().Add(deadline))
 	return &IdleTimeoutConn{
-		Conn:     conn,
-		deadline: deadline,
+		Conn:         conn,
+		deadline:     deadline,
+		resetOnWrite: resetOnWrite,
 	}
 }
 
@@ -414,41 +420,10 @@ func (conn *IdleTimeoutConn) Read(buffer []byte) (int, error) {
 	return n, err
 }
 
-// JointIdleTimeoutConn wraps a pair of net.Conns, implementing an idle
-// timeout using SetReadDeadline. The read deadline for both conns is
-// extended when either one complete a read.
-type JointIdleTimeoutConn struct {
-	net.Conn
-	deadline time.Duration
-	peer     net.Conn
-}
-
-func NewJointIdleTimeoutConn(
-	conn1, conn2 net.Conn, deadline time.Duration) (
-	*JointIdleTimeoutConn, *JointIdleTimeoutConn) {
-
-	conn1.SetReadDeadline(time.Now().Add(deadline))
-	joint1 := &JointIdleTimeoutConn{
-		Conn:     conn1,
-		deadline: deadline,
-		peer:     conn2,
-	}
-
-	conn2.SetReadDeadline(time.Now().Add(deadline))
-	joint2 := &JointIdleTimeoutConn{
-		Conn:     conn2,
-		deadline: deadline,
-		peer:     conn1,
-	}
-
-	return joint1, joint2
-}
-
-func (conn *JointIdleTimeoutConn) Read(buffer []byte) (int, error) {
-	n, err := conn.Conn.Read(buffer)
-	if err == nil {
+func (conn *IdleTimeoutConn) Write(buffer []byte) (int, error) {
+	n, err := conn.Conn.Write(buffer)
+	if err == nil && conn.resetOnWrite {
 		conn.Conn.SetReadDeadline(time.Now().Add(conn.deadline))
-		conn.peer.SetReadDeadline(time.Now().Add(conn.deadline))
 	}
 	return n, err
 }

+ 15 - 16
psiphon/server/sshService.go

@@ -247,12 +247,12 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 	sshClient.trafficRules = sshServer.config.GetTrafficRules(sshClient.geoIPData.Country)
 
 	// Wrap the base TCP connection with an IdleTimeoutConn which will terminate
-	// the connection if it's idle for too long. This timeout is in effect for
-	// the entire duration of the SSH connection. Clients must actively use
-	// the connection or send SSH keep alive requests to keep the connection
+	// the connection if no data is received before the deadline. This timeout is
+	// in effect for the entire duration of the SSH connection. Clients must actively
+	// use the connection or send SSH keep alive requests to keep the connection
 	// active.
 
-	conn := psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE)
+	conn := psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE, false)
 
 	// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
 	// respect shutdownBroadcast and implement a specific handshake timeout.
@@ -435,18 +435,17 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
 
 	defer fwdChannel.Close()
 
-	// TODO: Fix -- fwdChannel is a ssh.Channel which is not a net.Conn, which
-	// NewJointIdleTimeoutConn expects.
-	/*
-		// When idle port forward traffic rules are in place, wrap each end of the
-		// port forward in peer JointIdleTimeoutConns which triggers an idle
-		// timeout only if both ends are (read) idle.
-		if sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 {
-			fwdConn, fwdChannel = psiphon.NewJointIdleTimeoutConn(
-				fwdConn, fwdChannel,
-				time.Duration(sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond)
-		}
-	*/
+	// When idle port forward traffic rules are in place, wrap fwdConn
+	// in an IdleTimeoutConn configured to reset idle on writes as well
+	// as read. This ensures the port forward idle timeout only happens
+	// when both upstream and downstream directions are are idle.
+
+	if sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 {
+		fwdConn = psiphon.NewIdleTimeoutConn(
+			fwdConn,
+			time.Duration(sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond,
+			true)
+	}
 
 	// relay channel to forwarded connection
 	// TODO: relay errors to fwdChannel.Stderr()?