فهرست منبع

Implement traffic rules
* Specified in config, with default rules and optional
per-region (country) rules
* Upstream and downstream throttling
* Maximum number of port forwards per client
* Idle timeout for port forwards (incomplete)

Rod Hynes 10 سال پیش
والد
کامیت
7adf194c15
4فایلهای تغییر یافته به همراه180 افزوده شده و 41 حذف شده
  1. 2 2
      psiphon/config.go
  2. 49 10
      psiphon/net.go
  3. 51 0
      psiphon/server/config.go
  4. 78 29
      psiphon/server/sshService.go

+ 2 - 2
psiphon/config.go

@@ -158,8 +158,8 @@ type Config struct {
 	TunnelProtocol string
 
 	// EstablishTunnelTimeoutSeconds specifies a time limit after which to halt
-	// the core tunnel controller if no tunnel has been established. By default,
-	// the controller will keep trying indefinitely.
+	// the core tunnel controller if no tunnel has been established. The default
+	// is ESTABLISH_TUNNEL_TIMEOUT_SECONDS.
 	EstablishTunnelTimeoutSeconds *int
 
 	// ListenInterface specifies which interface to listen on.  If no interface

+ 49 - 10
psiphon/net.go

@@ -391,25 +391,64 @@ func IPAddressFromAddr(addr net.Addr) string {
 	return ipAddress
 }
 
-// TimeoutTCPConn wraps a net.TCPConn and sets an initial ReadDeadline. The
+// IdleTimeoutConn wraps a net.Conn and sets an initial ReadDeadline. The
 // deadline is reset whenever data is received from the connection.
-type TimeoutTCPConn struct {
-	*net.TCPConn
+type IdleTimeoutConn struct {
+	net.Conn
 	deadline time.Duration
 }
 
-func NewTimeoutTCPConn(tcpConn *net.TCPConn, deadline time.Duration) *TimeoutTCPConn {
-	tcpConn.SetReadDeadline(time.Now().Add(deadline))
-	return &TimeoutTCPConn{
-		TCPConn:  tcpConn,
+func NewIdleTimeoutConn(conn net.Conn, deadline time.Duration) *IdleTimeoutConn {
+	conn.SetReadDeadline(time.Now().Add(deadline))
+	return &IdleTimeoutConn{
+		Conn:     conn,
 		deadline: deadline,
 	}
 }
 
-func (conn *TimeoutTCPConn) Read(buffer []byte) (int, error) {
-	n, err := conn.TCPConn.Read(buffer)
+func (conn *IdleTimeoutConn) Read(buffer []byte) (int, error) {
+	n, err := conn.Conn.Read(buffer)
 	if err == nil {
-		conn.TCPConn.SetReadDeadline(time.Now().Add(conn.deadline))
+		conn.Conn.SetReadDeadline(time.Now().Add(conn.deadline))
+	}
+	return n, err
+}
+
+// JointIdleTimeoutConn wraps a pair of net.Conns, implementing an idle
+// timeout using SetReadDeadline. The read deadline for both conns is
+// extended when either one complete a read.
+type JointIdleTimeoutConn struct {
+	net.Conn
+	deadline time.Duration
+	peer     net.Conn
+}
+
+func NewJointIdleTimeoutConn(
+	conn1, conn2 net.Conn, deadline time.Duration) (
+	*JointIdleTimeoutConn, *JointIdleTimeoutConn) {
+
+	conn1.SetReadDeadline(time.Now().Add(deadline))
+	joint1 := &JointIdleTimeoutConn{
+		Conn:     conn1,
+		deadline: deadline,
+		peer:     conn2,
+	}
+
+	conn2.SetReadDeadline(time.Now().Add(deadline))
+	joint2 := &JointIdleTimeoutConn{
+		Conn:     conn2,
+		deadline: deadline,
+		peer:     conn1,
+	}
+
+	return joint1, joint2
+}
+
+func (conn *JointIdleTimeoutConn) Read(buffer []byte) (int, error) {
+	n, err := conn.Conn.Read(buffer)
+	if err == nil {
+		conn.Conn.SetReadDeadline(time.Now().Add(conn.deadline))
+		conn.peer.SetReadDeadline(time.Now().Add(conn.deadline))
 	}
 	return n, err
 }

+ 51 - 0
psiphon/server/config.go

@@ -55,6 +55,7 @@ const (
 	DEFAULT_SSH_SERVER_PORT                = 2222
 	SSH_HANDSHAKE_TIMEOUT                  = 30 * time.Second
 	SSH_CONNECTION_READ_DEADLINE           = 5 * time.Minute
+	SSH_THROTTLED_PORT_FORWARD_MAX_COPY    = 32 * 1024
 	SSH_OBFUSCATED_KEY_BYTE_LENGTH         = 32
 	DEFAULT_OBFUSCATED_SSH_SERVER_PORT     = 3333
 	REDIS_POOL_MAX_IDLE                    = 50
@@ -146,6 +147,42 @@ type Config struct {
 	// RedisServerAddress is the TCP address of a redis server. When
 	// set, redis is used to store per-session GeoIP information.
 	RedisServerAddress string
+
+	// DefaultTrafficRules specifies the traffic rules to be used when
+	// no regional-specific rules are set.
+	DefaultTrafficRules TrafficRules
+
+	// RegionalTrafficRules specifies the traffic rules for particular
+	// client regions (countries) as determined by GeoIP lookup of the
+	// client IP address. The key for each regional traffic rule entry
+	// is one or more space delimited ISO 3166-1 alpha-2 country codes.
+	RegionalTrafficRules map[string]TrafficRules
+}
+
+// TrafficRules specify the limits placed on SSH client port forward
+// traffic.
+type TrafficRules struct {
+
+	// ThrottleUpstreamSleepMilliseconds is the period to sleep
+	// between sending each chunk of client->destination traffic.
+	// The default, 0, is no sleep.
+	ThrottleUpstreamSleepMilliseconds int
+
+	// ThrottleDownstreamSleepMilliseconds is the period to sleep
+	// between sending each chunk of destination->client traffic.
+	// The default, 0, is no sleep.
+	ThrottleDownstreamSleepMilliseconds int
+
+	// IdlePortForwardTimeoutMilliseconds is the timeout period
+	// after which idle (no bytes flowing in either direction)
+	// SSH client port forwards are preemptively closed.
+	// The default, 0, is no idle timeout.
+	IdlePortForwardTimeoutMilliseconds int
+
+	// MaxClientPortForwardCount is the maximum number of port
+	// forwards each client may have open concurrently.
+	// The default, 0, is no maximum.
+	MaxClientPortForwardCount int
 }
 
 // RunWebServer indicates whether to run a web server component.
@@ -169,6 +206,20 @@ func (config *Config) UseRedis() bool {
 	return config.RedisServerAddress != ""
 }
 
+// GetTrafficRules looks up the traffic rules for the specified country. If there
+// are no RegionalTrafficRules for the country, DefaultTrafficRules are returned.
+func (config *Config) GetTrafficRules(targetCountryCode string) TrafficRules {
+	// TODO: faster lookup?
+	for countryCodes, trafficRules := range config.RegionalTrafficRules {
+		for _, countryCode := range strings.Split(countryCodes, " ") {
+			if countryCode == targetCountryCode {
+				return trafficRules
+			}
+		}
+	}
+	return config.DefaultTrafficRules
+}
+
 // LoadConfig loads and validates a JSON encoded server config. If more than one
 // JSON config is specified, then all are loaded and values are merged together,
 // in order. Multiple configs allows for use cases like storing static, server-specific

+ 78 - 29
psiphon/server/sshService.go

@@ -211,16 +211,16 @@ func (sshServer *sshServer) stopClient(client *sshClient) {
 	client.Lock()
 	log.WithContextFields(
 		LogFields{
-			"startTime":                     client.startTime,
-			"duration":                      time.Now().Sub(client.startTime),
-			"psiphonSessionID":              client.psiphonSessionID,
-			"country":                       client.geoIPData.Country,
-			"city":                          client.geoIPData.City,
-			"ISP":                           client.geoIPData.ISP,
-			"bytesUp":                       client.bytesUp,
-			"bytesDown":                     client.bytesDown,
-			"portForwardCount":              client.portForwardCount,
-			"maxConcurrentPortForwardCount": client.maxConcurrentPortForwardCount,
+			"startTime":                      client.startTime,
+			"duration":                       time.Now().Sub(client.startTime),
+			"psiphonSessionID":               client.psiphonSessionID,
+			"country":                        client.geoIPData.Country,
+			"city":                           client.geoIPData.City,
+			"ISP":                            client.geoIPData.ISP,
+			"bytesUp":                        client.bytesUp,
+			"bytesDown":                      client.bytesDown,
+			"portForwardCount":               client.portForwardCount,
+			"peakConcurrentPortForwardCount": client.peakConcurrentPortForwardCount,
 		}).Info("tunnel closed")
 	client.Unlock()
 }
@@ -244,14 +244,15 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 		startTime: time.Now(),
 		geoIPData: GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr())),
 	}
+	sshClient.trafficRules = sshServer.config.GetTrafficRules(sshClient.geoIPData.Country)
 
-	// Wrap the base TCP connection in a TimeoutTCPConn which will terminate
+	// Wrap the base TCP connection with an IdleTimeoutConn which will terminate
 	// the connection if it's idle for too long. This timeout is in effect for
 	// the entire duration of the SSH connection. Clients must actively use
 	// the connection or send SSH keep alive requests to keep the connection
 	// active.
 
-	conn := psiphon.NewTimeoutTCPConn(tcpConn, SSH_CONNECTION_READ_DEADLINE)
+	conn := psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE)
 
 	// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
 	// respect shutdownBroadcast and implement a specific handshake timeout.
@@ -320,7 +321,7 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 
 	clientID, ok := sshServer.registerClient(sshClient)
 	if !ok {
-		tcpConn.Close()
+		conn.Close()
 		log.WithContext().Warning("register failed")
 		return
 	}
@@ -333,16 +334,17 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 
 type sshClient struct {
 	sync.Mutex
-	sshServer                     *sshServer
-	sshConn                       ssh.Conn
-	startTime                     time.Time
-	geoIPData                     GeoIPData
-	psiphonSessionID              string
-	bytesUp                       int64
-	bytesDown                     int64
-	portForwardCount              int64
-	concurrentPortForwardCount    int64
-	maxConcurrentPortForwardCount int64
+	sshServer                      *sshServer
+	sshConn                        ssh.Conn
+	startTime                      time.Time
+	geoIPData                      GeoIPData
+	trafficRules                   TrafficRules
+	psiphonSessionID               string
+	bytesUp                        int64
+	bytesDown                      int64
+	portForwardCount               int64
+	concurrentPortForwardCount     int64
+	peakConcurrentPortForwardCount int64
 }
 
 func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
@@ -353,6 +355,18 @@ func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
 			return
 		}
 
+		if sshClient.trafficRules.MaxClientPortForwardCount > 0 {
+			sshClient.Lock()
+			limitExceeded := sshClient.portForwardCount >= int64(sshClient.trafficRules.MaxClientPortForwardCount)
+			sshClient.Unlock()
+
+			if limitExceeded {
+				sshClient.rejectNewChannel(
+					newChannel, ssh.ResourceShortage, "maximum port forward limit exceeded")
+				return
+			}
+		}
+
 		// process each port forward concurrently
 		go sshClient.handleNewDirectTcpipChannel(newChannel)
 	}
@@ -410,8 +424,8 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
 	sshClient.Lock()
 	sshClient.portForwardCount += 1
 	sshClient.concurrentPortForwardCount += 1
-	if sshClient.concurrentPortForwardCount > sshClient.maxConcurrentPortForwardCount {
-		sshClient.maxConcurrentPortForwardCount = sshClient.concurrentPortForwardCount
+	if sshClient.concurrentPortForwardCount > sshClient.peakConcurrentPortForwardCount {
+		sshClient.peakConcurrentPortForwardCount = sshClient.concurrentPortForwardCount
 	}
 	sshClient.Unlock()
 
@@ -421,9 +435,20 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
 
 	defer fwdChannel.Close()
 
-	// relay channel to forwarded connection
+	// TODO: Fix -- fwdChannel is a ssh.Channel which is not a net.Conn, which
+	// NewJointIdleTimeoutConn expects.
+	/*
+		// When idle port forward traffic rules are in place, wrap each end of the
+		// port forward in peer JointIdleTimeoutConns which triggers an idle
+		// timeout only if both ends are (read) idle.
+		if sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 {
+			fwdConn, fwdChannel = psiphon.NewJointIdleTimeoutConn(
+				fwdConn, fwdChannel,
+				time.Duration(sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond)
+		}
+	*/
 
-	// TODO: use a low-memory io.Copy?
+	// relay channel to forwarded connection
 	// TODO: relay errors to fwdChannel.Stderr()?
 
 	var bytesUp, bytesDown int64
@@ -433,12 +458,14 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
 	go func() {
 		defer relayWaitGroup.Done()
 		var err error
-		bytesUp, err = io.Copy(fwdConn, fwdChannel)
+		bytesUp, err = copyWithThrottle(
+			fwdConn, fwdChannel, sshClient.trafficRules.ThrottleUpstreamSleepMilliseconds)
 		if err != nil {
 			log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed")
 		}
 	}()
-	bytesDown, err = io.Copy(fwdChannel, fwdConn)
+	bytesDown, err = copyWithThrottle(
+		fwdChannel, fwdConn, sshClient.trafficRules.ThrottleDownstreamSleepMilliseconds)
 	if err != nil {
 		log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed")
 	}
@@ -454,6 +481,28 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
 	log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
 }
 
+func copyWithThrottle(dst io.Writer, src io.Reader, throttleSleepMilliseconds int) (int64, error) {
+	// TODO: use a low-memory io.Copy?
+	if throttleSleepMilliseconds <= 0 {
+		// No throttle
+		return io.Copy(dst, src)
+	}
+	var totalBytes int64
+	for {
+		bytes, err := io.CopyN(dst, src, SSH_THROTTLED_PORT_FORWARD_MAX_COPY)
+		totalBytes += bytes
+		if err == io.EOF {
+			err = nil
+			break
+		}
+		if err != nil {
+			return totalBytes, psiphon.ContextError(err)
+		}
+		time.Sleep(time.Duration(throttleSleepMilliseconds) * time.Millisecond)
+	}
+	return totalBytes, nil
+}
+
 func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
 	var sshPasswordPayload struct {
 		SessionId   string `json:"SessionId"`