Browse Source

Merge pull request #251 from rod-hynes/master

New traffic rules functionality + misc.
Rod Hynes 9 years ago
parent
commit
3c7738fcba

+ 1 - 1
ConsoleClient/Dockerfile

@@ -22,7 +22,7 @@ RUN apt-get update -y && apt-get install -y --no-install-recommends \
   && rm -rf /var/lib/apt/lists/*
 
 # Install Go.
-ENV GOVERSION=go1.7 GOROOT=/usr/local/go GOPATH=/go PATH=$PATH:/usr/local/go/bin:/go/bin CGO_ENABLED=1
+ENV GOVERSION=go1.7.1 GOROOT=/usr/local/go GOPATH=/go PATH=$PATH:/usr/local/go/bin:/go/bin CGO_ENABLED=1
 
 RUN curl -L https://storage.googleapis.com/golang/$GOVERSION.linux-amd64.tar.gz -o /tmp/go.tar.gz \
    && tar -C /usr/local -xzf /tmp/go.tar.gz \

+ 1 - 1
MobileLibrary/Android/Dockerfile

@@ -19,7 +19,7 @@ RUN apt-get update -y && apt-get install -y --no-install-recommends \
   && rm -rf /var/lib/apt/lists/*
 
 # Install Go.
-ENV GOVERSION=go1.7 GOROOT=/usr/local/go GOPATH=/go PATH=$PATH:/usr/local/go/bin:/go/bin CGO_ENABLED=1
+ENV GOVERSION=go1.7.1 GOROOT=/usr/local/go GOPATH=/go PATH=$PATH:/usr/local/go/bin:/go/bin CGO_ENABLED=1
 
 RUN curl -L https://storage.googleapis.com/golang/$GOVERSION.linux-amd64.tar.gz -o /tmp/go.tar.gz \
   && tar -C /usr/local -xzf /tmp/go.tar.gz \

+ 1 - 1
Server/Dockerfile-binary-builder

@@ -1,6 +1,6 @@
 FROM alpine:latest
 
-ENV GOLANG_VERSION 1.7
+ENV GOLANG_VERSION 1.7.1
 ENV GOLANG_SRC_URL https://golang.org/dl/go$GOLANG_VERSION.src.tar.gz
 
 RUN set -ex \

+ 17 - 0
psiphon/common/protocol.go

@@ -41,6 +41,9 @@ const (
 	PSIPHON_API_CLIENT_VERIFICATION_REQUEST_NAME = "psiphon-client-verification"
 
 	PSIPHON_API_CLIENT_SESSION_ID_LENGTH = 16
+
+	PSIPHON_SSH_API_PROTOCOL = "ssh"
+	PSIPHON_WEB_API_PROTOCOL = "web"
 )
 
 var SupportedTunnelProtocols = []string{
@@ -76,3 +79,17 @@ func TunnelProtocolUsesMeekHTTPS(protocol string) bool {
 	return protocol == TUNNEL_PROTOCOL_FRONTED_MEEK ||
 		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS
 }
+
+type HandshakeResponse struct {
+	Homepages            []string            `json:"homepages"`
+	UpgradeClientVersion string              `json:"upgrade_client_version"`
+	PageViewRegexes      []map[string]string `json:"page_view_regexes"`
+	HttpsRequestRegexes  []map[string]string `json:"https_request_regexes"`
+	EncodedServerList    []string            `json:"encoded_server_list"`
+	ClientRegion         string              `json:"client_region"`
+	ServerTimestamp      string              `json:"server_timestamp"`
+}
+
+type ConnectedResponse struct {
+	ConnectedTimestamp string `json:"connected_timestamp"`
+}

+ 123 - 57
psiphon/common/throttled.go

@@ -20,8 +20,10 @@
 package common
 
 import (
+	"errors"
 	"io"
 	"net"
+	"sync"
 	"sync/atomic"
 
 	"github.com/Psiphon-Inc/ratelimit"
@@ -30,23 +32,27 @@ import (
 // RateLimits specify the rate limits for a ThrottledConn.
 type RateLimits struct {
 
-	// DownstreamUnlimitedBytes specifies the number of downstream
-	// bytes to transfer, approximately, before starting rate
-	// limiting.
-	DownstreamUnlimitedBytes int64
+	// ReadUnthrottledBytes specifies the number of bytes to
+	// read, approximately, before starting rate limiting.
+	ReadUnthrottledBytes int64
 
-	// DownstreamBytesPerSecond specifies a rate limit for downstream
+	// ReadBytesPerSecond specifies a rate limit for read
 	// data transfer. The default, 0, is no limit.
-	DownstreamBytesPerSecond int64
+	ReadBytesPerSecond int64
 
-	// UpstreamUnlimitedBytes specifies the number of upstream
-	// bytes to transfer, approximately, before starting rate
-	// limiting.
-	UpstreamUnlimitedBytes int64
+	// WriteUnthrottledBytes specifies the number of bytes to
+	// write, approximately, before starting rate limiting.
+	WriteUnthrottledBytes int64
 
-	// UpstreamBytesPerSecond specifies a rate limit for upstream
+	// WriteBytesPerSecond specifies a rate limit for write
 	// data transfer. The default, 0, is no limit.
-	UpstreamBytesPerSecond int64
+	WriteBytesPerSecond int64
+
+	// CloseAfterExhausted indicates that the underlying
+	// net.Conn should be closed once either the read or
+	// write unthrottled bytes have been exhausted. In this
+	// case, throttling is never applied.
+	CloseAfterExhausted bool
 }
 
 // ThrottledConn wraps a net.Conn with read and write rate limiters.
@@ -60,76 +66,136 @@ type ThrottledConn struct {
 	// Note: 64-bit ints used with atomic operations are at placed
 	// at the start of struct to ensure 64-bit alignment.
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
-	unlimitedReadBytes  int64
-	unlimitedWriteBytes int64
-	limitingReads       int32
-	limitingWrites      int32
-	limitedReader       io.Reader
-	limitedWriter       io.Writer
+	readUnthrottledBytes  int64
+	readBytesPerSecond    int64
+	writeUnthrottledBytes int64
+	writeBytesPerSecond   int64
+	closeAfterExhausted   int32
+	readLock              sync.Mutex
+	throttledReader       io.Reader
+	writeLock             sync.Mutex
+	throttledWriter       io.Writer
 	net.Conn
 }
 
 // NewThrottledConn initializes a new ThrottledConn.
 func NewThrottledConn(conn net.Conn, limits RateLimits) *ThrottledConn {
+	throttledConn := &ThrottledConn{Conn: conn}
+	throttledConn.SetLimits(limits)
+	return throttledConn
+}
 
-	// When no limit is specified, the rate limited reader/writer
-	// is simply the base reader/writer.
-
-	var reader io.Reader
-	if limits.DownstreamBytesPerSecond == 0 {
-		reader = conn
-	} else {
-		reader = ratelimit.Reader(conn,
-			ratelimit.NewBucketWithRate(
-				float64(limits.DownstreamBytesPerSecond),
-				limits.DownstreamBytesPerSecond))
+// SetLimits modifies the rate limits of an existing
+// ThrottledConn. It is safe to call SetLimits while
+// other goroutines are calling Read/Write. This function
+// will not block, and the new rate limits will be
+// applied within Read/Write, but not necessarily until
+// some futher I/O at previous rates.
+func (conn *ThrottledConn) SetLimits(limits RateLimits) {
+
+	// Using atomic instead of mutex to avoid blocking
+	// this function on throttled I/O in an ongoing
+	// read or write. Precise synchronized application
+	// of the rate limit values is not required.
+
+	// Negative rates are invalid and -1 is a special
+	// value to used to signal throttling initialized
+	// state. Silently normalize negative values to 0.
+	rate := limits.ReadBytesPerSecond
+	if rate < 0 {
+		rate = 0
 	}
+	atomic.StoreInt64(&conn.readBytesPerSecond, rate)
+	atomic.StoreInt64(&conn.readUnthrottledBytes, limits.ReadUnthrottledBytes)
 
-	var writer io.Writer
-	if limits.UpstreamBytesPerSecond == 0 {
-		writer = conn
-	} else {
-		writer = ratelimit.Writer(conn,
-			ratelimit.NewBucketWithRate(
-				float64(limits.UpstreamBytesPerSecond),
-				limits.UpstreamBytesPerSecond))
+	rate = limits.WriteBytesPerSecond
+	if rate < 0 {
+		rate = 0
 	}
+	atomic.StoreInt64(&conn.writeBytesPerSecond, rate)
+	atomic.StoreInt64(&conn.writeUnthrottledBytes, limits.WriteUnthrottledBytes)
 
-	return &ThrottledConn{
-		Conn:                conn,
-		unlimitedReadBytes:  limits.DownstreamUnlimitedBytes,
-		limitingReads:       0,
-		limitedReader:       reader,
-		unlimitedWriteBytes: limits.UpstreamUnlimitedBytes,
-		limitingWrites:      0,
-		limitedWriter:       writer,
+	closeAfterExhausted := int32(0)
+	if limits.CloseAfterExhausted {
+		closeAfterExhausted = 1
 	}
+	atomic.StoreInt32(&conn.closeAfterExhausted, closeAfterExhausted)
 }
 
 func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
 
-	// Use the base reader until the unlimited count is exhausted.
-	if atomic.LoadInt32(&conn.limitingReads) == 0 {
-		if atomic.AddInt64(&conn.unlimitedReadBytes, -int64(len(buffer))) <= 0 {
-			atomic.StoreInt32(&conn.limitingReads, 1)
+	// A mutex is used to ensure conformance with net.Conn
+	// concurrency semantics. The atomic.SwapInt64 and
+	// subsequent assignment of throttledReader could be
+	// a race condition with concurrent reads.
+	conn.readLock.Lock()
+	defer conn.readLock.Unlock()
+
+	// Use the base conn until the unthrottled count is
+	// exhausted. This is only an approximate enforcement
+	// since this read, or concurrent reads, could exceed
+	// the remaining count.
+	if atomic.LoadInt64(&conn.readUnthrottledBytes) > 0 {
+		n, err := conn.Conn.Read(buffer)
+		atomic.AddInt64(&conn.readUnthrottledBytes, -int64(n))
+		return n, err
+	}
+
+	if atomic.LoadInt32(&conn.closeAfterExhausted) == 1 {
+		conn.Conn.Close()
+		return 0, errors.New("throttled conn exhausted")
+	}
+
+	rate := atomic.SwapInt64(&conn.readBytesPerSecond, -1)
+
+	if rate != -1 {
+		// SetLimits has been called and a new rate limiter
+		// must be initialized. When no limit is specified,
+		// the reader/writer is simply the base conn.
+		// No state is retained from the previous rate limiter,
+		// so a pending I/O throttle sleep may be skipped when
+		// the old and new rate are similar.
+		if rate == 0 {
+			conn.throttledReader = conn.Conn
 		} else {
-			return conn.Read(buffer)
+			conn.throttledReader = ratelimit.Reader(
+				conn.Conn,
+				ratelimit.NewBucketWithRate(float64(rate), rate))
 		}
 	}
 
-	return conn.limitedReader.Read(buffer)
+	return conn.throttledReader.Read(buffer)
 }
 
 func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
 
-	// Use the base writer until the unlimited count is exhausted.
-	if atomic.LoadInt32(&conn.limitingWrites) == 0 {
-		if atomic.AddInt64(&conn.unlimitedWriteBytes, -int64(len(buffer))) <= 0 {
-			atomic.StoreInt32(&conn.limitingWrites, 1)
+	// See comments in Read.
+
+	conn.writeLock.Lock()
+	defer conn.writeLock.Unlock()
+
+	if atomic.LoadInt64(&conn.writeUnthrottledBytes) > 0 {
+		n, err := conn.Conn.Write(buffer)
+		atomic.AddInt64(&conn.writeUnthrottledBytes, -int64(n))
+		return n, err
+	}
+
+	if atomic.LoadInt32(&conn.closeAfterExhausted) == 1 {
+		conn.Conn.Close()
+		return 0, errors.New("throttled conn exhausted")
+	}
+
+	rate := atomic.SwapInt64(&conn.writeBytesPerSecond, -1)
+
+	if rate != -1 {
+		if rate == 0 {
+			conn.throttledWriter = conn.Conn
 		} else {
-			return conn.Write(buffer)
+			conn.throttledWriter = ratelimit.Writer(
+				conn.Conn,
+				ratelimit.NewBucketWithRate(float64(rate), rate))
 		}
 	}
 
-	return conn.limitedWriter.Write(buffer)
+	return conn.throttledWriter.Write(buffer)
 }

+ 29 - 22
psiphon/common/throttled_test.go

@@ -40,40 +40,47 @@ const (
 func TestThrottledConn(t *testing.T) {
 
 	run(t, RateLimits{
-		DownstreamUnlimitedBytes: 0,
-		DownstreamBytesPerSecond: 0,
-		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   0,
+		ReadUnthrottledBytes:  0,
+		ReadBytesPerSecond:    0,
+		WriteUnthrottledBytes: 0,
+		WriteBytesPerSecond:   0,
 	})
 
 	run(t, RateLimits{
-		DownstreamUnlimitedBytes: 0,
-		DownstreamBytesPerSecond: 5 * 1024 * 1024,
-		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   5 * 1024 * 1024,
+		ReadUnthrottledBytes:  0,
+		ReadBytesPerSecond:    5 * 1024 * 1024,
+		WriteUnthrottledBytes: 0,
+		WriteBytesPerSecond:   5 * 1024 * 1024,
 	})
 
 	run(t, RateLimits{
-		DownstreamUnlimitedBytes: 0,
-		DownstreamBytesPerSecond: 2 * 1024 * 1024,
-		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   2 * 1024 * 1024,
+		ReadUnthrottledBytes:  0,
+		ReadBytesPerSecond:    5 * 1024 * 1024,
+		WriteUnthrottledBytes: 0,
+		WriteBytesPerSecond:   1024 * 1024,
 	})
 
 	run(t, RateLimits{
-		DownstreamUnlimitedBytes: 0,
-		DownstreamBytesPerSecond: 1024 * 1024,
-		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   1024 * 1024,
+		ReadUnthrottledBytes:  0,
+		ReadBytesPerSecond:    2 * 1024 * 1024,
+		WriteUnthrottledBytes: 0,
+		WriteBytesPerSecond:   2 * 1024 * 1024,
+	})
+
+	run(t, RateLimits{
+		ReadUnthrottledBytes:  0,
+		ReadBytesPerSecond:    1024 * 1024,
+		WriteUnthrottledBytes: 0,
+		WriteBytesPerSecond:   1024 * 1024,
 	})
 
 	// This test takes > 1 min to run, so disabled for now
 	/*
 		run(t, RateLimits{
-			DownstreamUnlimitedBytes: 0,
-			DownstreamBytesPerSecond: 1024 * 1024 / 8,
-			UpstreamUnlimitedBytes:   0,
-			UpstreamBytesPerSecond:   1024 * 1024 / 8,
+			ReadUnthrottledBytes: 0,
+			ReadBytesPerSecond: 1024 * 1024 / 8,
+			WriteUnthrottledBytes:   0,
+			WriteBytesPerSecond:   1024 * 1024 / 8,
 		})
 	*/
 }
@@ -136,7 +143,7 @@ func run(t *testing.T, rateLimits RateLimits) {
 
 	// Test: elapsed upload time must reflect rate limit
 
-	checkElapsedTime(t, testDataSize, rateLimits.UpstreamBytesPerSecond, monotime.Since(startTime))
+	checkElapsedTime(t, testDataSize, rateLimits.WriteBytesPerSecond, monotime.Since(startTime))
 
 	startTime = monotime.Now()
 
@@ -150,7 +157,7 @@ func run(t *testing.T, rateLimits RateLimits) {
 
 	// Test: elapsed download time must reflect rate limit
 
-	checkElapsedTime(t, testDataSize, rateLimits.DownstreamBytesPerSecond, monotime.Since(startTime))
+	checkElapsedTime(t, testDataSize, rateLimits.ReadBytesPerSecond, monotime.Since(startTime))
 }
 
 func checkElapsedTime(t *testing.T, dataSize int, rateLimit int64, duration time.Duration) {

+ 14 - 0
psiphon/config.go

@@ -234,6 +234,12 @@ type Config struct {
 	// status, etc. This is used for special case temporary tunnels (Windows VPN mode).
 	DisableApi bool
 
+	// TargetApiProtocol specifies whether to force use of "ssh" or "web" API protocol.
+	// When blank, the default, the optimal API protocol is used. Note that this
+	// capability check is not applied before the "CandidateServers" count is emitted.
+	// This parameter is intended for testing and debugging only.
+	TargetApiProtocol string
+
 	// DisableRemoteServerListFetcher disables fetching remote server lists. This is
 	// used for special case temporary tunnels.
 	DisableRemoteServerListFetcher bool
@@ -470,6 +476,14 @@ func LoadConfig(configJson []byte) (*Config, error) {
 			errors.New("HostNameTransformer interface must be set at runtime"))
 	}
 
+	if !common.Contains(
+		[]string{"", common.PSIPHON_SSH_API_PROTOCOL, common.PSIPHON_WEB_API_PROTOCOL},
+		config.TargetApiProtocol) {
+
+		return nil, common.ContextError(
+			errors.New("invalid TargetApiProtocol"))
+	}
+
 	if config.UpgradeDownloadUrl != "" &&
 		(config.UpgradeDownloadClientVersionHeader == "" || config.UpgradeDownloadFilename == "") {
 		return nil, common.ContextError(errors.New(

+ 5 - 0
psiphon/controller.go

@@ -1045,6 +1045,11 @@ loop:
 				break
 			}
 
+			if controller.config.TargetApiProtocol == common.PSIPHON_SSH_API_PROTOCOL &&
+				!serverEntry.SupportsSSHAPIRequests() {
+				continue
+			}
+
 			// Disable impaired protocols. This is only done for the
 			// first iteration of the ESTABLISH_TUNNEL_WORK_TIME
 			// loop since (a) one iteration should be sufficient to

+ 72 - 37
psiphon/server/api.go

@@ -72,13 +72,19 @@ func sshAPIRequestHandler(
 			fmt.Errorf("invalid payload for request name: %s: %s", name, err))
 	}
 
-	return dispatchAPIRequestHandler(support, geoIPData, name, params)
+	return dispatchAPIRequestHandler(
+		support,
+		common.PSIPHON_SSH_API_PROTOCOL,
+		geoIPData,
+		name,
+		params)
 }
 
 // dispatchAPIRequestHandler is the common dispatch point for both
 // web and SSH API requests.
 func dispatchAPIRequestHandler(
 	support *SupportServices,
+	apiProtocol string,
 	geoIPData GeoIPData,
 	name string,
 	params requestJSONObject) (response []byte, reterr error) {
@@ -97,7 +103,7 @@ func dispatchAPIRequestHandler(
 
 	switch name {
 	case common.PSIPHON_API_HANDSHAKE_REQUEST_NAME:
-		return handshakeAPIRequestHandler(support, geoIPData, params)
+		return handshakeAPIRequestHandler(support, apiProtocol, geoIPData, params)
 	case common.PSIPHON_API_CONNECTED_REQUEST_NAME:
 		return connectedAPIRequestHandler(support, geoIPData, params)
 	case common.PSIPHON_API_STATUS_REQUEST_NAME:
@@ -115,6 +121,7 @@ func dispatchAPIRequestHandler(
 // stats to record, etc.
 func handshakeAPIRequestHandler(
 	support *SupportServices,
+	apiProtocol string,
 	geoIPData GeoIPData,
 	params requestJSONObject) ([]byte, error) {
 
@@ -133,40 +140,42 @@ func handshakeAPIRequestHandler(
 			params,
 			baseRequestParams))
 
-	// TODO: share struct definition with psiphon/serverApi.go?
-	var handshakeResponse struct {
-		Homepages            []string            `json:"homepages"`
-		UpgradeClientVersion string              `json:"upgrade_client_version"`
-		PageViewRegexes      []map[string]string `json:"page_view_regexes"`
-		HttpsRequestRegexes  []map[string]string `json:"https_request_regexes"`
-		EncodedServerList    []string            `json:"encoded_server_list"`
-		ClientRegion         string              `json:"client_region"`
-		ServerTimestamp      string              `json:"server_timestamp"`
-	}
+	// Note: ignoring param format errors as params have been validated
 
-	// Ignoring errors as params are validated
+	sessionID, _ := getStringRequestParam(params, "client_session_id")
 	sponsorID, _ := getStringRequestParam(params, "sponsor_id")
 	clientVersion, _ := getStringRequestParam(params, "client_version")
 	clientPlatform, _ := getStringRequestParam(params, "client_platform")
-	clientRegion := geoIPData.Country
-
-	// Note: no guarantee that PsinetDatabase won't reload between calls
-
-	handshakeResponse.Homepages = support.PsinetDatabase.GetHomepages(
-		sponsorID, clientRegion, isMobileClientPlatform(clientPlatform))
-
-	handshakeResponse.UpgradeClientVersion = support.PsinetDatabase.GetUpgradeClientVersion(
-		clientVersion, normalizeClientPlatform(clientPlatform))
-
-	handshakeResponse.HttpsRequestRegexes = support.PsinetDatabase.GetHttpsRequestRegexes(
-		sponsorID)
-
-	handshakeResponse.EncodedServerList = support.PsinetDatabase.DiscoverServers(
-		geoIPData.DiscoveryValue)
-
-	handshakeResponse.ClientRegion = clientRegion
+	isMobile := isMobileClientPlatform(clientPlatform)
+	normalizedPlatform := normalizeClientPlatform(clientPlatform)
+
+	// Flag the SSH client as having completed its handshake. This
+	// may reselect traffic rules and starts allowing port forwards.
+
+	// TODO: in the case of SSH API requests, the actual sshClient could
+	// be passed in and used here. The session ID lookup is only strictly
+	// necessary to support web API requests.
+	err = support.TunnelServer.SetClientHandshakeState(
+		sessionID,
+		handshakeState{
+			completed:   true,
+			apiProtocol: apiProtocol,
+			apiParams:   copyBaseRequestParams(params),
+		})
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
 
-	handshakeResponse.ServerTimestamp = common.GetCurrentTimestamp()
+	// Note: no guarantee that PsinetDatabase won't reload between database calls
+	db := support.PsinetDatabase
+	handshakeResponse := common.HandshakeResponse{
+		Homepages:            db.GetHomepages(sponsorID, geoIPData.Country, isMobile),
+		UpgradeClientVersion: db.GetUpgradeClientVersion(clientVersion, normalizedPlatform),
+		HttpsRequestRegexes:  db.GetHttpsRequestRegexes(sponsorID),
+		EncodedServerList:    db.DiscoverServers(geoIPData.DiscoveryValue),
+		ClientRegion:         geoIPData.Country,
+		ServerTimestamp:      common.GetCurrentTimestamp(),
+	}
 
 	responsePayload, err := json.Marshal(handshakeResponse)
 	if err != nil {
@@ -205,13 +214,10 @@ func connectedAPIRequestHandler(
 			params,
 			connectedRequestParams))
 
-	var connectedResponse struct {
-		ConnectedTimestamp string `json:"connected_timestamp"`
+	connectedResponse := common.ConnectedResponse{
+		ConnectedTimestamp: common.TruncateTimestampToHour(common.GetCurrentTimestamp()),
 	}
 
-	connectedResponse.ConnectedTimestamp =
-		common.TruncateTimestampToHour(common.GetCurrentTimestamp())
-
 	responsePayload, err := json.Marshal(connectedResponse)
 	if err != nil {
 		return nil, common.ContextError(err)
@@ -446,7 +452,7 @@ var baseRequestParams = []requestParamSpec{
 	requestParamSpec{"client_session_id", isHexDigits, requestParamOptional | requestParamNotLogged},
 	requestParamSpec{"propagation_channel_id", isHexDigits, 0},
 	requestParamSpec{"sponsor_id", isHexDigits, 0},
-	requestParamSpec{"client_version", isDigits, 0},
+	requestParamSpec{"client_version", isIntString, 0},
 	requestParamSpec{"client_platform", isClientPlatform, 0},
 	requestParamSpec{"relay_protocol", isRelayProtocol, 0},
 	requestParamSpec{"tunnel_whole_device", isBooleanFlag, requestParamOptional},
@@ -491,6 +497,26 @@ func validateRequestParams(
 	return nil
 }
 
+// copyBaseRequestParams makes a copy of the params which
+// includes only the baseRequestParams.
+func copyBaseRequestParams(params requestJSONObject) requestJSONObject {
+
+	// Note: not a deep copy; assumes baseRequestParams values
+	// are all scalar types (int, string, etc.)
+
+	paramsCopy := make(requestJSONObject)
+	for _, baseParam := range baseRequestParams {
+		value := params[baseParam.name]
+		if value == nil {
+			continue
+		}
+
+		paramsCopy[baseParam.name] = value
+	}
+
+	return paramsCopy
+}
+
 func validateStringRequestParam(
 	support *SupportServices,
 	expectedParam requestParamSpec,
@@ -549,6 +575,10 @@ func getRequestLogFields(
 	logFields["client_city"] = strings.Replace(geoIPData.City, " ", "_", -1)
 	logFields["client_isp"] = strings.Replace(geoIPData.ISP, " ", "_", -1)
 
+	if params == nil {
+		return logFields
+	}
+
 	for _, expectedParam := range expectedParams {
 
 		if expectedParam.flags&requestParamNotLogged != 0 {
@@ -730,6 +760,11 @@ func isDigits(_ *SupportServices, value string) bool {
 	})
 }
 
+func isIntString(_ *SupportServices, value string) bool {
+	_, err := strconv.Atoi(value)
+	return err == nil
+}
+
 func isClientPlatform(_ *SupportServices, value string) bool {
 	return -1 == strings.IndexFunc(value, func(c rune) bool {
 		// Note: stricter than psi_web's Python string.whitespace

+ 70 - 55
psiphon/server/config.go

@@ -102,6 +102,20 @@ type Config struct {
 	// authenticate itself to clients.
 	WebServerPrivateKey string
 
+	// WebServerPortForwardAddress specifies the expected network
+	// address ("<host>:<port>") specified in a client's port forward
+	// HostToConnect and PortToConnect when the client is making a
+	// tunneled connection to the web server. This address is always
+	// exempted from validation against SSH_DISALLOWED_PORT_FORWARD_HOSTS
+	// and AllowTCPPorts/DenyTCPPorts.
+	WebServerPortForwardAddress string
+
+	// WebServerPortForwardRedirectAddress specifies an alternate
+	// destination address to be substituted and dialed instead of
+	// the original destination when the port forward destination is
+	// WebServerPortForwardAddress.
+	WebServerPortForwardRedirectAddress string
+
 	// TunnelProtocolPorts specifies which tunnel protocols to run
 	// and which ports to listen on for each protocol. Valid tunnel
 	// protocols include: "SSH", "OSSH", "UNFRONTED-MEEK-OSSH",
@@ -186,21 +200,6 @@ type Config struct {
 	// "nameserver" entry.
 	DNSResolverIPAddress string
 
-	// TCPPortForwardRedirects is a mapping from client port forward
-	// destination to an alternate destination address. When the client's
-	// port forward HostToConnect and PortToConnect matches a redirect,
-	// the redirect is substituted and dialed instead of the original
-	// destination.
-	//
-	// The redirect is applied after the original destination is
-	// validated against SSH_DISALLOWED_PORT_FORWARD_HOSTS and
-	// AllowTCPPorts/DenyTCPPorts. So the redirect may map to any
-	// otherwise prohibited destination.
-	//
-	// The redirect is applied after UDPInterceptUdpgwServerAddress is
-	// checked. So the redirect address will not be intercepted.
-	TCPPortForwardRedirects map[string]string
-
 	// LoadMonitorPeriodSeconds indicates how frequently to log server
 	// load information (number of connected clients per tunnel protocol,
 	// number of running goroutines, amount of memory allocated, etc.)
@@ -233,7 +232,7 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 	}
 
 	if config.ServerIPAddress == "" {
-		return nil, errors.New("ServerIPAddress is missing from config file")
+		return nil, errors.New("ServerIPAddress is required")
 	}
 
 	if config.WebServerPort > 0 && (config.WebServerSecret == "" || config.WebServerCertificate == "" ||
@@ -243,6 +242,24 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 			"Web server requires WebServerSecret, WebServerCertificate, WebServerPrivateKey")
 	}
 
+	if config.WebServerPortForwardAddress != "" {
+		if err := validateNetworkAddress(config.WebServerPortForwardAddress, false); err != nil {
+			return nil, errors.New("WebServerPortForwardAddress is invalid")
+		}
+	}
+
+	if config.WebServerPortForwardRedirectAddress != "" {
+
+		if config.WebServerPortForwardAddress == "" {
+			return nil, errors.New(
+				"WebServerPortForwardRedirectAddress requires WebServerPortForwardAddress")
+		}
+
+		if err := validateNetworkAddress(config.WebServerPortForwardRedirectAddress, false); err != nil {
+			return nil, errors.New("WebServerPortForwardRedirectAddress is invalid")
+		}
+	}
+
 	for tunnelProtocol, _ := range config.TunnelProtocolPorts {
 		if !common.Contains(common.SupportedTunnelProtocols, tunnelProtocol) {
 			return nil, fmt.Errorf("Unsupported tunnel protocol: %s", tunnelProtocol)
@@ -280,24 +297,6 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 		}
 	}
 
-	validateNetworkAddress := func(address string, requireIPaddress bool) error {
-		host, portStr, err := net.SplitHostPort(address)
-		if err != nil {
-			return err
-		}
-		if requireIPaddress && net.ParseIP(host) == nil {
-			return errors.New("host must be an IP address")
-		}
-		port, err := strconv.Atoi(portStr)
-		if err != nil {
-			return err
-		}
-		if port < 0 || port > 65535 {
-			return errors.New("invalid port")
-		}
-		return nil
-	}
-
 	if config.UDPInterceptUdpgwServerAddress != "" {
 		if err := validateNetworkAddress(config.UDPInterceptUdpgwServerAddress, true); err != nil {
 			return nil, fmt.Errorf("UDPInterceptUdpgwServerAddress is invalid: %s", err)
@@ -310,20 +309,27 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 		}
 	}
 
-	if config.TCPPortForwardRedirects != nil {
-		for destination, redirect := range config.TCPPortForwardRedirects {
-			if err := validateNetworkAddress(destination, false); err != nil {
-				return nil, fmt.Errorf("TCPPortForwardRedirects destination %s is invalid: %s", destination, err)
-			}
-			if err := validateNetworkAddress(redirect, false); err != nil {
-				return nil, fmt.Errorf("TCPPortForwardRedirects redirect %s is invalid: %s", redirect, err)
-			}
-		}
-	}
-
 	return &config, nil
 }
 
+func validateNetworkAddress(address string, requireIPaddress bool) error {
+	host, portStr, err := net.SplitHostPort(address)
+	if err != nil {
+		return err
+	}
+	if requireIPaddress && net.ParseIP(host) == nil {
+		return errors.New("host must be an IP address")
+	}
+	port, err := strconv.Atoi(portStr)
+	if err != nil {
+		return err
+	}
+	if port < 0 || port > 65535 {
+		return errors.New("invalid port")
+	}
+	return nil
+}
+
 // GenerateConfigParams specifies customizations to be applied to
 // a generated server config.
 type GenerateConfigParams struct {
@@ -380,7 +386,8 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 
 	// Web server config
 
-	var webServerSecret, webServerCertificate, webServerPrivateKey string
+	var webServerSecret, webServerCertificate,
+		webServerPrivateKey, webServerPortForwardAddress string
 
 	if params.WebServerPort != 0 {
 		var err error
@@ -393,6 +400,9 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 		if err != nil {
 			return nil, nil, nil, common.ContextError(err)
 		}
+
+		webServerPortForwardAddress = net.JoinHostPort(
+			params.ServerIPAddress, strconv.Itoa(params.WebServerPort))
 	}
 
 	// SSH config
@@ -482,6 +492,7 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 		WebServerSecret:                webServerSecret,
 		WebServerCertificate:           webServerCertificate,
 		WebServerPrivateKey:            webServerPrivateKey,
+		WebServerPortForwardAddress:    webServerPortForwardAddress,
 		SSHPrivateKey:                  string(sshPrivateKey),
 		SSHServerVersion:               sshServerVersion,
 		SSHUserName:                    sshUserName,
@@ -504,18 +515,22 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 		return nil, nil, nil, common.ContextError(err)
 	}
 
+	intPtr := func(i int) *int {
+		return &i
+	}
+
 	trafficRulesSet := &TrafficRulesSet{
 		DefaultRules: TrafficRules{
-			DefaultLimits: common.RateLimits{
-				DownstreamUnlimitedBytes: 0,
-				DownstreamBytesPerSecond: 0,
-				UpstreamUnlimitedBytes:   0,
-				UpstreamBytesPerSecond:   0,
+			RateLimits: RateLimits{
+				ReadUnthrottledBytes:  new(int64),
+				ReadBytesPerSecond:    new(int64),
+				WriteUnthrottledBytes: new(int64),
+				WriteBytesPerSecond:   new(int64),
 			},
-			IdleTCPPortForwardTimeoutMilliseconds: 30000,
-			IdleUDPPortForwardTimeoutMilliseconds: 30000,
-			MaxTCPPortForwardCount:                1024,
-			MaxUDPPortForwardCount:                32,
+			IdleTCPPortForwardTimeoutMilliseconds: intPtr(30000),
+			IdleUDPPortForwardTimeoutMilliseconds: intPtr(30000),
+			MaxTCPPortForwardCount:                intPtr(1024),
+			MaxUDPPortForwardCount:                intPtr(32),
 			AllowTCPPorts:                         nil,
 			AllowUDPPorts:                         nil,
 			DenyTCPPorts:                          nil,

+ 125 - 20
psiphon/server/server_test.go

@@ -54,6 +54,7 @@ func TestSSH(t *testing.T) {
 			tunnelProtocol:       "SSH",
 			enableSSHAPIRequests: true,
 			doHotReload:          false,
+			denyTrafficRules:     false,
 		})
 }
 
@@ -63,6 +64,7 @@ func TestOSSH(t *testing.T) {
 			tunnelProtocol:       "OSSH",
 			enableSSHAPIRequests: true,
 			doHotReload:          false,
+			denyTrafficRules:     false,
 		})
 }
 
@@ -72,6 +74,7 @@ func TestUnfrontedMeek(t *testing.T) {
 			tunnelProtocol:       "UNFRONTED-MEEK-OSSH",
 			enableSSHAPIRequests: true,
 			doHotReload:          false,
+			denyTrafficRules:     false,
 		})
 }
 
@@ -81,6 +84,7 @@ func TestUnfrontedMeekHTTPS(t *testing.T) {
 			tunnelProtocol:       "UNFRONTED-MEEK-HTTPS-OSSH",
 			enableSSHAPIRequests: true,
 			doHotReload:          false,
+			denyTrafficRules:     false,
 		})
 }
 
@@ -90,6 +94,7 @@ func TestWebTransportAPIRequests(t *testing.T) {
 			tunnelProtocol:       "OSSH",
 			enableSSHAPIRequests: false,
 			doHotReload:          false,
+			denyTrafficRules:     false,
 		})
 }
 
@@ -99,6 +104,17 @@ func TestHotReload(t *testing.T) {
 			tunnelProtocol:       "OSSH",
 			enableSSHAPIRequests: true,
 			doHotReload:          true,
+			denyTrafficRules:     false,
+		})
+}
+
+func TestDenyTrafficRules(t *testing.T) {
+	runServer(t,
+		&runServerConfig{
+			tunnelProtocol:       "OSSH",
+			enableSSHAPIRequests: true,
+			doHotReload:          true,
+			denyTrafficRules:     true,
 		})
 }
 
@@ -106,6 +122,7 @@ type runServerConfig struct {
 	tunnelProtocol       string
 	enableSSHAPIRequests bool
 	doHotReload          bool
+	denyTrafficRules     bool
 }
 
 func sendNotificationReceived(c chan<- struct{}) {
@@ -162,11 +179,15 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	psinetFilename := "psinet.json"
 	sponsorID, expectedHomepageURL := pavePsinetDatabaseFile(t, psinetFilename)
 
+	// Pave traffic rules file which exercises handshake parameter filtering
+	trafficRulesFilename := "traffic_rules.json"
+	paveTrafficRulesFile(t, trafficRulesFilename, sponsorID, runConfig.denyTrafficRules)
+
 	var serverConfig interface{}
 	json.Unmarshal(serverConfigJSON, &serverConfig)
 	serverConfig.(map[string]interface{})["GeoIPDatabaseFilename"] = ""
 	serverConfig.(map[string]interface{})["PsinetDatabaseFilename"] = psinetFilename
-	serverConfig.(map[string]interface{})["TrafficRulesFilename"] = ""
+	serverConfig.(map[string]interface{})["TrafficRulesFilename"] = trafficRulesFilename
 	serverConfig.(map[string]interface{})["LogLevel"] = "debug"
 
 	// 1 second is the minimum period; should be small enough to emit a log during the
@@ -348,22 +369,43 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 	// Test: tunneled web site fetch
 
-	makeTunneledWebRequest(t, localHTTPProxyPort)
+	err = makeTunneledWebRequest(t, localHTTPProxyPort)
+
+	if err == nil {
+		if runConfig.denyTrafficRules {
+			t.Fatalf("unexpected tunneled web request success")
+		}
+	} else {
+		if !runConfig.denyTrafficRules {
+			t.Fatalf("tunneled web request failed: %s", err)
+		}
+	}
 
 	// Test: tunneled UDP packet
 
 	udpgwServerAddress := serverConfig.(map[string]interface{})["UDPInterceptUdpgwServerAddress"].(string)
-	makeTunneledDNSRequest(t, localSOCKSProxyPort, udpgwServerAddress)
+
+	err = makeTunneledDNSRequest(t, localSOCKSProxyPort, udpgwServerAddress)
+
+	if err == nil {
+		if runConfig.denyTrafficRules {
+			t.Fatalf("unexpected tunneled DNS request success")
+		}
+	} else {
+		if !runConfig.denyTrafficRules {
+			t.Fatalf("tunneled DNS request failed: %s", err)
+		}
+	}
 }
 
-func makeTunneledWebRequest(t *testing.T, localHTTPProxyPort int) {
+func makeTunneledWebRequest(t *testing.T, localHTTPProxyPort int) error {
 
 	testUrl := "https://psiphon.ca"
 	roundTripTimeout := 30 * time.Second
 
 	proxyUrl, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", localHTTPProxyPort))
 	if err != nil {
-		t.Fatalf("error initializing proxied HTTP request: %s", err)
+		return fmt.Errorf("error initializing proxied HTTP request: %s", err)
 	}
 
 	httpClient := &http.Client{
@@ -375,20 +417,22 @@ func makeTunneledWebRequest(t *testing.T, localHTTPProxyPort int) {
 
 	response, err := httpClient.Get(testUrl)
 	if err != nil {
-		t.Fatalf("error sending proxied HTTP request: %s", err)
+		return fmt.Errorf("error sending proxied HTTP request: %s", err)
 	}
 
 	_, err = ioutil.ReadAll(response.Body)
 	if err != nil {
-		t.Fatalf("error reading proxied HTTP response: %s", err)
+		return fmt.Errorf("error reading proxied HTTP response: %s", err)
 	}
 	response.Body.Close()
+
+	return nil
 }
 
-func makeTunneledDNSRequest(t *testing.T, localSOCKSProxyPort int, udpgwServerAddress string) {
+func makeTunneledDNSRequest(t *testing.T, localSOCKSProxyPort int, udpgwServerAddress string) error {
 
 	testHostname := "psiphon.ca"
-	timeout := 10 * time.Second
+	timeout := 5 * time.Second
 
 	localUDPProxyAddress, err := net.ResolveUDPAddr("udp", "127.0.0.1:7301")
 	if err != nil {
@@ -399,7 +443,8 @@ func makeTunneledDNSRequest(t *testing.T, localSOCKSProxyPort int, udpgwServerAd
 
 		serverUDPConn, err := net.ListenUDP("udp", localUDPProxyAddress)
 		if err != nil {
-			t.Fatalf("ListenUDP failed: %s", err)
+			t.Logf("ListenUDP failed: %s", err)
+			return
 		}
 		defer serverUDPConn.Close()
 
@@ -408,19 +453,22 @@ func makeTunneledDNSRequest(t *testing.T, localSOCKSProxyPort int, udpgwServerAd
 		packetSize, clientAddr, err := serverUDPConn.ReadFromUDP(
 			buffer[udpgwPreambleSize:len(buffer)])
 		if err != nil {
-			t.Fatalf("serverUDPConn.Read failed: %s", err)
+			t.Logf("serverUDPConn.Read failed: %s", err)
+			return
 		}
 
 		socksProxyAddress := fmt.Sprintf("127.0.0.1:%d", localSOCKSProxyPort)
 
 		dialer, err := proxy.SOCKS5("tcp", socksProxyAddress, nil, proxy.Direct)
 		if err != nil {
-			t.Fatalf("proxy.SOCKS5 failed: %s", err)
+			t.Logf("proxy.SOCKS5 failed: %s", err)
+			return
 		}
 
 		socksTCPConn, err := dialer.Dial("tcp", udpgwServerAddress)
 		if err != nil {
-			t.Fatalf("dialer.Dial failed: %s", err)
+			t.Logf("dialer.Dial failed: %s", err)
+			return
 		}
 		defer socksTCPConn.Close()
 
@@ -433,22 +481,26 @@ func makeTunneledDNSRequest(t *testing.T, localSOCKSProxyPort int, udpgwServerAd
 			uint16(packetSize),
 			buffer)
 		if err != nil {
-			t.Fatalf("writeUdpgwPreamble failed: %s", err)
+			t.Logf("writeUdpgwPreamble failed: %s", err)
+			return
 		}
 
 		_, err = socksTCPConn.Write(buffer[0 : udpgwPreambleSize+packetSize])
 		if err != nil {
-			t.Fatalf("socksTCPConn.Write failed: %s", err)
+			t.Logf("socksTCPConn.Write failed: %s", err)
+			return
 		}
 
 		updgwProtocolMessage, err := readUdpgwMessage(socksTCPConn, buffer)
 		if err != nil {
-			t.Fatalf("readUdpgwMessage failed: %s", err)
+			t.Logf("readUdpgwMessage failed: %s", err)
+			return
 		}
 
 		_, err = serverUDPConn.WriteToUDP(updgwProtocolMessage.packet, clientAddr)
 		if err != nil {
-			t.Fatalf("serverUDPConn.Write failed: %s", err)
+			t.Logf("serverUDPConn.Write failed: %s", err)
+			return
 		}
 	}()
 
@@ -457,7 +509,7 @@ func makeTunneledDNSRequest(t *testing.T, localSOCKSProxyPort int, udpgwServerAd
 
 	clientUDPConn, err := net.DialUDP("udp", nil, localUDPProxyAddress)
 	if err != nil {
-		t.Fatalf("DialUDP failed: %s", err)
+		return fmt.Errorf("DialUDP failed: %s", err)
 	}
 	defer clientUDPConn.Close()
 
@@ -466,8 +518,10 @@ func makeTunneledDNSRequest(t *testing.T, localSOCKSProxyPort int, udpgwServerAd
 
 	_, _, err = psiphon.ResolveIP(testHostname, clientUDPConn)
 	if err != nil {
-		t.Fatalf("ResolveIP failed: %s", err)
+		return fmt.Errorf("ResolveIP failed: %s", err)
 	}
+
+	return nil
 }
 
 func pavePsinetDatabaseFile(t *testing.T, psinetFilename string) (string, string) {
@@ -498,8 +552,59 @@ func pavePsinetDatabaseFile(t *testing.T, psinetFilename string) (string, string
 
 	err := ioutil.WriteFile(psinetFilename, []byte(psinetJSON), 0600)
 	if err != nil {
-		t.Fatalf("error paving psinet database: %s", err)
+		t.Fatalf("error paving psinet database file: %s", err)
 	}
 
 	return sponsorID, expectedHomepageURL
 }
+
+func paveTrafficRulesFile(t *testing.T, trafficRulesFilename, sponsorID string, deny bool) {
+
+	allowTCPPort := "443"
+	allowUDPPort := "53"
+
+	if deny {
+		allowTCPPort = "0"
+		allowUDPPort = "0"
+	}
+
+	trafficRulesJSONFormat := `
+    {
+        "DefaultRules" :  {
+            "RateLimits" : {
+                "ReadBytesPerSecond": 16384,
+                "WriteBytesPerSecond": 16384
+            },
+            "DenyTCPPorts" : [443],
+            "DenyUDPPorts" : [53]
+        },
+        "FilteredRules" : [
+            {
+                "Filter" : {
+                    "HandshakeParameters" : {
+                        "sponsor_id" : ["%s"]
+                    }
+                },
+                "Rules" : {
+                    "RateLimits" : {
+                        "ReadUnthrottledBytes": 132352,
+                        "WriteUnthrottledBytes": 132352
+                    },
+                    "AllowTCPPorts" : [%s],
+                    "DenyTCPPorts" : [],
+                    "AllowUDPPorts" : [%s],
+                    "DenyUDPPorts" : []
+                }
+            }
+        ]
+    }
+    `
+
+	trafficRulesJSON := fmt.Sprintf(
+		trafficRulesJSONFormat, sponsorID, allowTCPPort, allowUDPPort)
+
+	err := ioutil.WriteFile(trafficRulesFilename, []byte(trafficRulesJSON), 0600)
+	if err != nil {
+		t.Fatalf("error paving traffic rules file: %s", err)
+	}
+}

+ 6 - 0
psiphon/server/services.go

@@ -70,6 +70,8 @@ func RunServices(configJSON []byte) error {
 		return common.ContextError(err)
 	}
 
+	supportServices.TunnelServer = tunnelServer
+
 	if config.RunLoadMonitor() {
 		waitGroup.Add(1)
 		go func() {
@@ -130,6 +132,9 @@ loop:
 		select {
 		case <-reloadSupportServicesSignal:
 			supportServices.Reload()
+			// Reset traffic rules for established clients to reflect reloaded config
+			// TODO: only update when traffic rules config has changed
+			tunnelServer.ResetAllClientTrafficRules()
 		case <-logServerLoadSignal:
 			logServerLoad(tunnelServer)
 		case <-systemStopSignal:
@@ -187,6 +192,7 @@ type SupportServices struct {
 	PsinetDatabase  *psinet.Database
 	GeoIPService    *GeoIPService
 	DNSResolver     *DNSResolver
+	TunnelServer    *TunnelServer
 }
 
 // NewSupportServices initializes a new SupportServices.

+ 231 - 42
psiphon/server/trafficRules.go

@@ -22,7 +22,6 @@ package server
 import (
 	"encoding/json"
 	"io/ioutil"
-	"strings"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
@@ -30,55 +29,79 @@ import (
 // TrafficRulesSet represents the various traffic rules to
 // apply to Psiphon client tunnels. The Reload function supports
 // hot reloading of rules data while the server is running.
+//
+// For a given client, the traffic rules are determined by starting
+// with DefaultRules, then finding the first (if any)
+// FilteredTrafficRules match and overriding the defaults with fields
+// set in the selected FilteredTrafficRules.
 type TrafficRulesSet struct {
 	common.ReloadableFile
 
-	// DefaultRules specifies the traffic rules to be used when no
-	// regional-specific rules are set or apply to a particular
-	// client.
+	// DefaultRules are the base values to use as defaults for all
+	// clients.
 	DefaultRules TrafficRules
 
-	// RegionalRules specifies the traffic rules for particular client
-	// regions (countries) as determined by GeoIP lookup of the client
-	// IP address. The key for each regional traffic rule entry is one
-	// or more space delimited ISO 3166-1 alpha-2 country codes.
-	RegionalRules map[string]TrafficRules
+	// FilteredTrafficRules is an ordered list of filter/rules pairs.
+	// For each client, the first matching Filter in FilteredTrafficRules
+	// determines the additional Rules that are selected and applied
+	// on top of DefaultRules.
+	FilteredRules []struct {
+		Filter TrafficRulesFilter
+		Rules  TrafficRules
+	}
+}
+
+// TrafficRulesFilter defines a filter to match against client attributes.
+type TrafficRulesFilter struct {
+
+	// Protocols is a list of client tunnel protocols that must be in use
+	// to match this filter. When omitted or empty, any protocol matches.
+	Protocols []string
+
+	// Regions is a list of client GeoIP countries that the client must
+	// reolve to to match this filter. When omitted or empty, any client
+	// region matches.
+	Regions []string
+
+	// APIProtocol specifies whether the client must use the SSH
+	// API protocol (when "ssh") or the web API protocol (when "web").
+	// When omitted or blank, any API protocol matches.
+	APIProtocol string
+
+	// HandshakeParameters specifies handshake API parameter names and
+	// a list of values, one of which must be specified to match this
+	// filter. Only scalar string API parameters may be filtered.
+	HandshakeParameters map[string][]string
 }
 
 // TrafficRules specify the limits placed on client traffic.
 type TrafficRules struct {
-	// DefaultLimits are the rate limits to be applied when
-	// no protocol-specific rates are set.
-	DefaultLimits common.RateLimits
 
-	// ProtocolLimits specifies the rate limits for particular
-	// tunnel protocols. The key for each rate limit entry is one
-	// or more space delimited Psiphon tunnel protocol names. Valid
-	// tunnel protocols includes the same list as for
-	// TunnelProtocolPorts.
-	ProtocolLimits map[string]common.RateLimits
+	// RateLimits specifies data transfer rate limits for the
+	// client traffic.
+	RateLimits RateLimits
 
 	// IdleTCPPortForwardTimeoutMilliseconds is the timeout period
 	// after which idle (no bytes flowing in either direction)
 	// client TCP port forwards are preemptively closed.
 	// The default, 0, is no idle timeout.
-	IdleTCPPortForwardTimeoutMilliseconds int
+	IdleTCPPortForwardTimeoutMilliseconds *int
 
 	// IdleUDPPortForwardTimeoutMilliseconds is the timeout period
 	// after which idle (no bytes flowing in either direction)
 	// client UDP port forwards are preemptively closed.
 	// The default, 0, is no idle timeout.
-	IdleUDPPortForwardTimeoutMilliseconds int
+	IdleUDPPortForwardTimeoutMilliseconds *int
 
 	// MaxTCPPortForwardCount is the maximum number of TCP port
 	// forwards each client may have open concurrently.
 	// The default, 0, is no maximum.
-	MaxTCPPortForwardCount int
+	MaxTCPPortForwardCount *int
 
 	// MaxUDPPortForwardCount is the maximum number of UDP port
 	// forwards each client may have open concurrently.
 	// The default, 0, is no maximum.
-	MaxUDPPortForwardCount int
+	MaxUDPPortForwardCount *int
 
 	// AllowTCPPorts specifies a whitelist of TCP ports that
 	// are permitted for port forwarding. When set, only ports
@@ -101,6 +124,29 @@ type TrafficRules struct {
 	DenyUDPPorts []int
 }
 
+// RateLimits is a clone of common.RateLimits with pointers
+// to fields to enable distinguishing between zero values and
+// omitted values in JSON serialized traffic rules.
+// See common.RateLimits for field descriptions.
+type RateLimits struct {
+	ReadUnthrottledBytes  *int64
+	ReadBytesPerSecond    *int64
+	WriteUnthrottledBytes *int64
+	WriteBytesPerSecond   *int64
+	CloseAfterExhausted   *bool
+}
+
+// CommonRateLimits converts a RateLimits to a common.RateLimits.
+func (rateLimits *RateLimits) CommonRateLimits() common.RateLimits {
+	return common.RateLimits{
+		ReadUnthrottledBytes:  *rateLimits.ReadUnthrottledBytes,
+		ReadBytesPerSecond:    *rateLimits.ReadBytesPerSecond,
+		WriteUnthrottledBytes: *rateLimits.WriteUnthrottledBytes,
+		WriteBytesPerSecond:   *rateLimits.WriteBytesPerSecond,
+		CloseAfterExhausted:   *rateLimits.CloseAfterExhausted,
+	}
+}
+
 // NewTrafficRulesSet initializes a TrafficRulesSet with
 // the rules data in the specified config file.
 func NewTrafficRulesSet(filename string) (*TrafficRulesSet, error) {
@@ -133,35 +179,178 @@ func NewTrafficRulesSet(filename string) (*TrafficRulesSet, error) {
 	return set, nil
 }
 
-// GetTrafficRules looks up the traffic rules for the specified country. If there
-// are no regional TrafficRules for the country, default TrafficRules are returned.
-func (set *TrafficRulesSet) GetTrafficRules(clientCountryCode string) TrafficRules {
+// GetTrafficRules determines the traffic rules for a client based on its attributes.
+// For the return value TrafficRules, all pointer and slice fields are initialized,
+// so nil checks are not required. The caller must not modify the returned TrafficRules.
+func (set *TrafficRulesSet) GetTrafficRules(
+	tunnelProtocol string, geoIPData GeoIPData, state handshakeState) TrafficRules {
+
 	set.ReloadableFile.RLock()
 	defer set.ReloadableFile.RUnlock()
 
+	// Start with a copy of the DefaultRules, and then select the first
+	// matches Rules from FilteredTrafficRules, taking only the explicitly
+	// specified fields from that Rules.
+	//
+	// Notes:
+	// - Scalar pointers are used in TrafficRules and RateLimits to distinguish between
+	//   omitted fields (in serialized JSON) and default values. For example, if a filtered
+	//   Rules specifies a field value of 0, this will override the default; but if the
+	//   serialized filtered rule omits the field, the default is to be retained.
+	// - We use shallow copies and slices and scalar pointers are shared between the
+	//   return value TrafficRules, so callers must treat the return value as immutable.
+	//   This also means that these slices and pointers can remain referenced in memory even
+	//   after a hot reload.
+
+	trafficRules := set.DefaultRules
+
+	// Populate defaults for omitted DefaultRules fields
+
+	if trafficRules.RateLimits.ReadUnthrottledBytes == nil {
+		trafficRules.RateLimits.ReadUnthrottledBytes = new(int64)
+	}
+
+	if trafficRules.RateLimits.ReadBytesPerSecond == nil {
+		trafficRules.RateLimits.ReadBytesPerSecond = new(int64)
+	}
+
+	if trafficRules.RateLimits.WriteUnthrottledBytes == nil {
+		trafficRules.RateLimits.WriteUnthrottledBytes = new(int64)
+	}
+
+	if trafficRules.RateLimits.WriteBytesPerSecond == nil {
+		trafficRules.RateLimits.WriteBytesPerSecond = new(int64)
+	}
+
+	if trafficRules.RateLimits.CloseAfterExhausted == nil {
+		trafficRules.RateLimits.CloseAfterExhausted = new(bool)
+	}
+
+	if trafficRules.IdleTCPPortForwardTimeoutMilliseconds == nil {
+		trafficRules.IdleTCPPortForwardTimeoutMilliseconds = new(int)
+	}
+
+	if trafficRules.IdleUDPPortForwardTimeoutMilliseconds == nil {
+		trafficRules.IdleUDPPortForwardTimeoutMilliseconds = new(int)
+	}
+
+	if trafficRules.MaxTCPPortForwardCount == nil {
+		trafficRules.MaxTCPPortForwardCount = new(int)
+	}
+
+	if trafficRules.MaxUDPPortForwardCount == nil {
+		trafficRules.MaxUDPPortForwardCount = new(int)
+	}
+
+	if trafficRules.AllowTCPPorts == nil {
+		trafficRules.AllowTCPPorts = make([]int, 0)
+	}
+
+	if trafficRules.AllowUDPPorts == nil {
+		trafficRules.AllowUDPPorts = make([]int, 0)
+	}
+
+	if trafficRules.DenyTCPPorts == nil {
+		trafficRules.DenyTCPPorts = make([]int, 0)
+	}
+
+	if trafficRules.DenyUDPPorts == nil {
+		trafficRules.DenyUDPPorts = make([]int, 0)
+	}
+
 	// TODO: faster lookup?
-	for countryCodes, trafficRules := range set.RegionalRules {
-		for _, countryCode := range strings.Split(countryCodes, " ") {
-			if countryCode == clientCountryCode {
-				return trafficRules
+	for _, filteredRules := range set.FilteredRules {
+
+		if len(filteredRules.Filter.Protocols) > 0 {
+			if !common.Contains(filteredRules.Filter.Protocols, tunnelProtocol) {
+				continue
 			}
 		}
-	}
-	return set.DefaultRules
-}
 
-// GetRateLimits looks up the rate limits for the specified tunnel protocol.
-// If there are no specific RateLimits for the protocol, default RateLimits are
-// returned.
-func (rules *TrafficRules) GetRateLimits(clientTunnelProtocol string) common.RateLimits {
+		if len(filteredRules.Filter.Regions) > 0 {
+			if !common.Contains(filteredRules.Filter.Regions, geoIPData.Country) {
+				continue
+			}
+		}
 
-	// TODO: faster lookup?
-	for tunnelProtocols, rateLimits := range rules.ProtocolLimits {
-		for _, tunnelProtocol := range strings.Split(tunnelProtocols, " ") {
-			if tunnelProtocol == clientTunnelProtocol {
-				return rateLimits
+		if filteredRules.Filter.APIProtocol != "" {
+			if !state.completed {
+				continue
 			}
+			if state.apiProtocol != filteredRules.Filter.APIProtocol {
+				continue
+			}
+		}
+
+		if filteredRules.Filter.HandshakeParameters != nil {
+			if !state.completed {
+				continue
+			}
+
+			for name, values := range filteredRules.Filter.HandshakeParameters {
+				clientValue, err := getStringRequestParam(state.apiParams, name)
+				if err != nil || !common.Contains(values, clientValue) {
+					continue
+				}
+			}
+		}
+
+		// This is the first match. Override defaults using provided fields from selected rules, and return result.
+
+		if filteredRules.Rules.RateLimits.ReadUnthrottledBytes != nil {
+			trafficRules.RateLimits.ReadUnthrottledBytes = filteredRules.Rules.RateLimits.ReadUnthrottledBytes
+		}
+
+		if filteredRules.Rules.RateLimits.ReadBytesPerSecond != nil {
+			trafficRules.RateLimits.ReadBytesPerSecond = filteredRules.Rules.RateLimits.ReadBytesPerSecond
+		}
+
+		if filteredRules.Rules.RateLimits.WriteUnthrottledBytes != nil {
+			trafficRules.RateLimits.WriteUnthrottledBytes = filteredRules.Rules.RateLimits.WriteUnthrottledBytes
+		}
+
+		if filteredRules.Rules.RateLimits.WriteBytesPerSecond != nil {
+			trafficRules.RateLimits.WriteBytesPerSecond = filteredRules.Rules.RateLimits.WriteBytesPerSecond
 		}
+
+		if filteredRules.Rules.RateLimits.CloseAfterExhausted != nil {
+			trafficRules.RateLimits.CloseAfterExhausted = filteredRules.Rules.RateLimits.CloseAfterExhausted
+		}
+
+		if filteredRules.Rules.IdleTCPPortForwardTimeoutMilliseconds != nil {
+			trafficRules.IdleTCPPortForwardTimeoutMilliseconds = filteredRules.Rules.IdleTCPPortForwardTimeoutMilliseconds
+		}
+
+		if filteredRules.Rules.IdleUDPPortForwardTimeoutMilliseconds != nil {
+			trafficRules.IdleUDPPortForwardTimeoutMilliseconds = filteredRules.Rules.IdleUDPPortForwardTimeoutMilliseconds
+		}
+
+		if filteredRules.Rules.MaxTCPPortForwardCount != nil {
+			trafficRules.MaxTCPPortForwardCount = filteredRules.Rules.MaxTCPPortForwardCount
+		}
+
+		if filteredRules.Rules.MaxUDPPortForwardCount != nil {
+			trafficRules.MaxUDPPortForwardCount = filteredRules.Rules.MaxUDPPortForwardCount
+		}
+
+		if filteredRules.Rules.AllowTCPPorts != nil {
+			trafficRules.AllowTCPPorts = filteredRules.Rules.AllowTCPPorts
+		}
+
+		if filteredRules.Rules.AllowUDPPorts != nil {
+			trafficRules.AllowUDPPorts = filteredRules.Rules.AllowUDPPorts
+		}
+
+		if filteredRules.Rules.DenyTCPPorts != nil {
+			trafficRules.DenyTCPPorts = filteredRules.Rules.DenyTCPPorts
+		}
+
+		if filteredRules.Rules.DenyUDPPorts != nil {
+			trafficRules.DenyUDPPorts = filteredRules.Rules.DenyUDPPorts
+		}
+
+		break
 	}
-	return rules.DefaultLimits
+
+	return trafficRules
 }

+ 274 - 106
psiphon/server/tunnelServer.go

@@ -81,14 +81,6 @@ func NewTunnelServer(
 	}, nil
 }
 
-// GetLoadStats returns load stats for the tunnel server. The stats are
-// broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
-// include current connected client count, total number of current port
-// forwards.
-func (server *TunnelServer) GetLoadStats() map[string]map[string]int64 {
-	return server.sshServer.getLoadStats()
-}
-
 // Run runs the tunnel server; this function blocks while running a selection of
 // listeners that handle connection using various obfuscation protocols.
 //
@@ -192,17 +184,40 @@ func (server *TunnelServer) Run() error {
 	return err
 }
 
-type sshClientID uint64
+// GetLoadStats returns load stats for the tunnel server. The stats are
+// broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
+// include current connected client count, total number of current port
+// forwards.
+func (server *TunnelServer) GetLoadStats() map[string]map[string]int64 {
+	return server.sshServer.getLoadStats()
+}
+
+// ResetAllClientTrafficRules resets all established client traffic rules
+// to use the latest server config and client state.
+func (server *TunnelServer) ResetAllClientTrafficRules() {
+	server.sshServer.resetAllClientTrafficRules()
+}
+
+// SetClientHandshakeState sets the handshake state -- that it completed and
+// what paramaters were passed -- in sshClient. This state is used for allowing
+// port forwards and for future traffic rule selection. SetClientHandshakeState
+// also triggers an immediate traffic rule re-selection, as the rules selected
+// upon tunnel establishment may no longer apply now that handshake values are
+// set.
+func (server *TunnelServer) SetClientHandshakeState(
+	sessionID string, state handshakeState) error {
+
+	return server.sshServer.setClientHandshakeState(sessionID, state)
+}
 
 type sshServer struct {
 	support              *SupportServices
 	shutdownBroadcast    <-chan struct{}
 	sshHostKey           ssh.Signer
-	nextClientID         sshClientID
 	clientsMutex         sync.Mutex
 	stoppingClients      bool
 	acceptedClientCounts map[string]int64
-	clients              map[sshClientID]*sshClient
+	clients              map[string]*sshClient
 }
 
 func newSSHServer(
@@ -224,9 +239,8 @@ func newSSHServer(
 		support:              support,
 		shutdownBroadcast:    shutdownBroadcast,
 		sshHostKey:           signer,
-		nextClientID:         1,
 		acceptedClientCounts: make(map[string]int64),
-		clients:              make(map[sshClientID]*sshClient),
+		clients:              make(map[string]*sshClient),
 	}, nil
 }
 
@@ -321,28 +335,38 @@ func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol string) {
 // An established client has completed its SSH handshake and has a ssh.Conn. Registration is
 // for tracking the number of fully established clients and for maintaining a list of running
 // clients (for stopping at shutdown time).
-func (sshServer *sshServer) registerEstablishedClient(client *sshClient) (sshClientID, bool) {
+func (sshServer *sshServer) registerEstablishedClient(client *sshClient) bool {
 
 	sshServer.clientsMutex.Lock()
 	defer sshServer.clientsMutex.Unlock()
 
 	if sshServer.stoppingClients {
-		return 0, false
+		return false
 	}
 
-	clientID := sshServer.nextClientID
-	sshServer.nextClientID += 1
+	// In the case of a duplicate client sessionID, the previous client is closed.
+	// - Well-behaved clients generate pick a random sessionID that should be
+	//   unique (won't accidentally conflict) and hard to guess (can't be targetted
+	//   by a malicious client).
+	// - Clients reuse the same sessionID when a tunnel is unexpectedly disconnected
+	//   and resestablished. In this case, when the same server is selected, this logic
+	//   will be hit; closing the old, dangling client is desirable.
+	// - Multi-tunnel clients should not normally use one server for multiple tunnels.
+	existingClient := sshServer.clients[client.sessionID]
+	if existingClient != nil {
+		existingClient.stop()
+	}
 
-	sshServer.clients[clientID] = client
+	sshServer.clients[client.sessionID] = client
 
-	return clientID, true
+	return true
 }
 
-func (sshServer *sshServer) unregisterEstablishedClient(clientID sshClientID) {
+func (sshServer *sshServer) unregisterEstablishedClient(sessionID string) {
 
 	sshServer.clientsMutex.Lock()
-	client := sshServer.clients[clientID]
-	delete(sshServer.clients, clientID)
+	client := sshServer.clients[sessionID]
+	delete(sshServer.clients, sessionID)
 	sshServer.clientsMutex.Unlock()
 
 	if client != nil {
@@ -400,12 +424,47 @@ func (sshServer *sshServer) getLoadStats() map[string]map[string]int64 {
 	return loadStats
 }
 
+func (sshServer *sshServer) resetAllClientTrafficRules() {
+
+	sshServer.clientsMutex.Lock()
+	clients := make(map[string]*sshClient)
+	for sessionID, client := range sshServer.clients {
+		clients[sessionID] = client
+	}
+	sshServer.clientsMutex.Unlock()
+
+	for _, client := range clients {
+		client.setTrafficRules()
+	}
+}
+
+func (sshServer *sshServer) setClientHandshakeState(
+	sessionID string, state handshakeState) error {
+
+	sshServer.clientsMutex.Lock()
+	client := sshServer.clients[sessionID]
+	sshServer.clientsMutex.Unlock()
+
+	if client == nil {
+		return common.ContextError(errors.New("unknown session ID"))
+	}
+
+	err := client.setHandshakeState(state)
+	if err != nil {
+		return common.ContextError(err)
+	}
+
+	client.setTrafficRules()
+
+	return nil
+}
+
 func (sshServer *sshServer) stopClients() {
 
 	sshServer.clientsMutex.Lock()
 	sshServer.stoppingClients = true
 	clients := sshServer.clients
-	sshServer.clients = make(map[sshClientID]*sshClient)
+	sshServer.clients = make(map[string]*sshClient)
 	sshServer.clientsMutex.Unlock()
 
 	for _, client := range clients {
@@ -421,13 +480,10 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 	geoIPData := sshServer.support.GeoIPService.Lookup(
 		common.IPAddressFromAddr(clientConn.RemoteAddr()))
 
-	// TODO: apply reload of TrafficRulesSet to existing clients
+	sshClient := newSshClient(sshServer, tunnelProtocol, geoIPData)
 
-	sshClient := newSshClient(
-		sshServer,
-		tunnelProtocol,
-		geoIPData,
-		sshServer.support.TrafficRulesSet.GetTrafficRules(geoIPData.Country))
+	// Set initial traffic rules, pre-handshake, based on currently known info.
+	sshClient.setTrafficRules()
 
 	// Wrap the base client connection with an ActivityMonitoredConn which will
 	// terminate the connection if no data is received before the deadline. This
@@ -450,8 +506,8 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 
 	// Further wrap the connection in a rate limiting ThrottledConn.
 
-	clientConn = common.NewThrottledConn(
-		clientConn, sshClient.trafficRules.GetRateLimits(tunnelProtocol))
+	throttledConn := common.NewThrottledConn(clientConn, sshClient.rateLimits())
+	clientConn = throttledConn
 
 	// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
 	// respect shutdownBroadcast and implement a specific handshake timeout.
@@ -529,15 +585,15 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 	sshClient.Lock()
 	sshClient.sshConn = result.sshConn
 	sshClient.activityConn = activityConn
+	sshClient.throttledConn = throttledConn
 	sshClient.Unlock()
 
-	clientID, ok := sshServer.registerEstablishedClient(sshClient)
-	if !ok {
+	if !sshServer.registerEstablishedClient(sshClient) {
 		clientConn.Close()
 		log.WithContext().Warning("register failed")
 		return
 	}
-	defer sshServer.unregisterEstablishedClient(clientID)
+	defer sshServer.unregisterEstablishedClient(sshClient.sessionID)
 
 	sshClient.runClient(result.channels, result.requests)
 
@@ -551,12 +607,14 @@ type sshClient struct {
 	tunnelProtocol          string
 	sshConn                 ssh.Conn
 	activityConn            *common.ActivityMonitoredConn
+	throttledConn           *common.ThrottledConn
 	geoIPData               GeoIPData
-	psiphonSessionID        string
+	sessionID               string
+	handshakeState          handshakeState
 	udpChannel              ssh.Channel
 	trafficRules            TrafficRules
-	tcpTrafficState         *trafficState
-	udpTrafficState         *trafficState
+	tcpTrafficState         trafficState
+	udpTrafficState         trafficState
 	channelHandlerWaitGroup *sync.WaitGroup
 	tcpPortForwardLRU       *common.LRUConns
 	stopBroadcast           chan struct{}
@@ -573,15 +631,18 @@ type trafficState struct {
 	totalPortForwardCount          int64
 }
 
+type handshakeState struct {
+	completed   bool
+	apiProtocol string
+	apiParams   requestJSONObject
+}
+
 func newSshClient(
-	sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData, trafficRules TrafficRules) *sshClient {
+	sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
 	return &sshClient{
 		sshServer:               sshServer,
 		tunnelProtocol:          tunnelProtocol,
 		geoIPData:               geoIPData,
-		trafficRules:            trafficRules,
-		tcpTrafficState:         &trafficState{},
-		udpTrafficState:         &trafficState{},
 		channelHandlerWaitGroup: new(sync.WaitGroup),
 		tcpPortForwardLRU:       common.NewLRUConns(),
 		stopBroadcast:           make(chan struct{}),
@@ -590,6 +651,9 @@ func newSshClient(
 
 func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
 
+	expectedSessionIDLength := 2 * common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
+	expectedSSHPasswordLength := 2 * SSH_PASSWORD_BYTE_LENGTH
+
 	var sshPasswordPayload struct {
 		SessionId   string `json:"SessionId"`
 		SshPassword string `json:"SshPassword"`
@@ -601,15 +665,16 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 		// send the hex encoded session ID prepended to the SSH password.
 		// Note: there's an even older case where clients don't send any session ID,
 		// but that's no longer supported.
-		if len(password) == 2*common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH+2*SSH_PASSWORD_BYTE_LENGTH {
-			sshPasswordPayload.SessionId = string(password[0 : 2*common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH])
-			sshPasswordPayload.SshPassword = string(password[2*common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH : len(password)])
+		if len(password) == expectedSessionIDLength+expectedSSHPasswordLength {
+			sshPasswordPayload.SessionId = string(password[0:expectedSessionIDLength])
+			sshPasswordPayload.SshPassword = string(password[expectedSSHPasswordLength:len(password)])
 		} else {
 			return nil, common.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
 		}
 	}
 
-	if !isHexDigits(sshClient.sshServer.support, sshPasswordPayload.SessionId) {
+	if !isHexDigits(sshClient.sshServer.support, sshPasswordPayload.SessionId) ||
+		len(sshPasswordPayload.SessionId) != expectedSessionIDLength {
 		return nil, common.ContextError(fmt.Errorf("invalid session ID for %q", conn.User()))
 	}
 
@@ -623,17 +688,18 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 		return nil, common.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
 	}
 
-	psiphonSessionID := sshPasswordPayload.SessionId
+	sessionID := sshPasswordPayload.SessionId
 
 	sshClient.Lock()
-	sshClient.psiphonSessionID = psiphonSessionID
+	sshClient.sessionID = sessionID
 	geoIPData := sshClient.geoIPData
 	sshClient.Unlock()
 
 	// Store the GeoIP data associated with the session ID. This makes the GeoIP data
-	// available to the web server for web transport Psiphon API requests.
-	sshClient.sshServer.support.GeoIPService.SetSessionCache(
-		psiphonSessionID, geoIPData)
+	// available to the web server for web transport Psiphon API requests. To allow for
+	// post-tunnel final status requests, the lifetime of cached GeoIP records exceeds
+	// the lifetime of the sshClient, and that's why this distinct session cache exists.
+	sshClient.sshServer.support.GeoIPService.SetSessionCache(sessionID, geoIPData)
 
 	return nil, nil
 }
@@ -693,24 +759,30 @@ func (sshClient *sshClient) stop() {
 	// request with an EOF flag set.)
 
 	sshClient.Lock()
-	log.WithContextFields(
-		LogFields{
-			"startTime":                         sshClient.activityConn.GetStartTime(),
-			"duration":                          sshClient.activityConn.GetActiveDuration(),
-			"psiphonSessionID":                  sshClient.psiphonSessionID,
-			"country":                           sshClient.geoIPData.Country,
-			"city":                              sshClient.geoIPData.City,
-			"ISP":                               sshClient.geoIPData.ISP,
-			"bytesUpTCP":                        sshClient.tcpTrafficState.bytesUp,
-			"bytesDownTCP":                      sshClient.tcpTrafficState.bytesDown,
-			"peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
-			"totalPortForwardCountTCP":          sshClient.tcpTrafficState.totalPortForwardCount,
-			"bytesUpUDP":                        sshClient.udpTrafficState.bytesUp,
-			"bytesDownUDP":                      sshClient.udpTrafficState.bytesDown,
-			"peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
-			"totalPortForwardCountUDP":          sshClient.udpTrafficState.totalPortForwardCount,
-		}).Info("tunnel closed")
+
+	logFields := getRequestLogFields(
+		sshClient.sshServer.support,
+		"server_tunnel",
+		sshClient.geoIPData,
+		sshClient.handshakeState.apiParams,
+		baseRequestParams)
+
+	// TODO: match legacy log field naming convention?
+	logFields["HandshakeCompleted"] = sshClient.handshakeState.completed
+	logFields["startTime"] = sshClient.activityConn.GetStartTime()
+	logFields["Duration"] = sshClient.activityConn.GetActiveDuration()
+	logFields["BytesUpTCP"] = sshClient.tcpTrafficState.bytesUp
+	logFields["BytesDownTCP"] = sshClient.tcpTrafficState.bytesDown
+	logFields["PeakConcurrentPortForwardCountTCP"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
+	logFields["TotalPortForwardCountTCP"] = sshClient.tcpTrafficState.totalPortForwardCount
+	logFields["BytesUpUDP"] = sshClient.udpTrafficState.bytesUp
+	logFields["BytesDownUDP"] = sshClient.udpTrafficState.bytesDown
+	logFields["PeakConcurrentPortForwardCountUDP"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
+	logFields["TotalPortForwardCountUDP"] = sshClient.udpTrafficState.totalPortForwardCount
+
 	sshClient.Unlock()
+
+	log.LogRawFieldsWithTimestamp(logFields)
 }
 
 // runClient handles/dispatches new channel and new requests from the client.
@@ -812,13 +884,87 @@ func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChanne
 	}
 }
 
+// setHandshakeState records that a client has completed a handshake API request.
+// Some parameters from the handshake request may be used in future traffic rule
+// selection. Port forwards are disallowed until a handshake is complete. The
+// handshake parameters are included in the session summary log recorded in
+// sshClient.stop().
+func (sshClient *sshClient) setHandshakeState(state handshakeState) error {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	// Client must only perform one handshake
+	if sshClient.handshakeState.completed {
+		return common.ContextError(errors.New("handshake already completed"))
+	}
+
+	sshClient.handshakeState = state
+
+	return nil
+}
+
+// setTrafficRules resets the client's traffic rules based on the latest server config
+// and client state. As sshClient.trafficRules may be reset by a concurrent goroutine,
+// trafficRules must only be accessed within the sshClient mutex.
+func (sshClient *sshClient) setTrafficRules() {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
+		sshClient.tunnelProtocol, sshClient.geoIPData, sshClient.handshakeState)
+}
+
+func (sshClient *sshClient) rateLimits() common.RateLimits {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	return sshClient.trafficRules.RateLimits.CommonRateLimits()
+}
+
+func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	return time.Duration(*sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds) * time.Millisecond
+}
+
+func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
+}
+
+const (
+	portForwardTypeTCP = iota
+	portForwardTypeUDP
+)
+
 func (sshClient *sshClient) isPortForwardPermitted(
-	host string, port int, allowPorts []int, denyPorts []int) bool {
+	portForwardType int, host string, port int) bool {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	if !sshClient.handshakeState.completed {
+		return false
+	}
 
 	if common.Contains(SSH_DISALLOWED_PORT_FORWARD_HOSTS, host) {
 		return false
 	}
 
+	var allowPorts, denyPorts []int
+	if portForwardType == portForwardTypeTCP {
+		allowPorts = sshClient.trafficRules.AllowTCPPorts
+		denyPorts = sshClient.trafficRules.AllowTCPPorts
+	} else {
+		allowPorts = sshClient.trafficRules.AllowUDPPorts
+		denyPorts = sshClient.trafficRules.AllowUDPPorts
+
+	}
+
 	// TODO: faster lookup?
 	if len(allowPorts) > 0 {
 		for _, allowPort := range allowPorts {
@@ -841,37 +987,63 @@ func (sshClient *sshClient) isPortForwardPermitted(
 }
 
 func (sshClient *sshClient) isPortForwardLimitExceeded(
-	state *trafficState, maxPortForwardCount int) bool {
+	portForwardType int) (int, bool) {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	var maxPortForwardCount int
+	var state *trafficState
+	if portForwardType == portForwardTypeTCP {
+		maxPortForwardCount = *sshClient.trafficRules.MaxTCPPortForwardCount
+		state = &sshClient.tcpTrafficState
+	} else {
+		maxPortForwardCount = *sshClient.trafficRules.MaxUDPPortForwardCount
+		state = &sshClient.udpTrafficState
+	}
 
-	limitExceeded := false
-	if maxPortForwardCount > 0 {
-		sshClient.Lock()
-		limitExceeded = state.concurrentPortForwardCount >= int64(maxPortForwardCount)
-		sshClient.Unlock()
+	if maxPortForwardCount > 0 && state.concurrentPortForwardCount >= int64(maxPortForwardCount) {
+		return maxPortForwardCount, true
 	}
-	return limitExceeded
+	return maxPortForwardCount, false
 }
 
 func (sshClient *sshClient) openedPortForward(
-	state *trafficState) {
+	portForwardType int) {
 
 	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	var state *trafficState
+	if portForwardType == portForwardTypeTCP {
+		state = &sshClient.tcpTrafficState
+	} else {
+		state = &sshClient.udpTrafficState
+	}
+
 	state.concurrentPortForwardCount += 1
 	if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
 		state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
 	}
 	state.totalPortForwardCount += 1
-	sshClient.Unlock()
 }
 
 func (sshClient *sshClient) closedPortForward(
-	state *trafficState, bytesUp, bytesDown int64) {
+	portForwardType int, bytesUp, bytesDown int64) {
 
 	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	var state *trafficState
+	if portForwardType == portForwardTypeTCP {
+		state = &sshClient.tcpTrafficState
+	} else {
+		state = &sshClient.udpTrafficState
+	}
+
 	state.concurrentPortForwardCount -= 1
 	state.bytesUp += bytesUp
 	state.bytesDown += bytesDown
-	sshClient.Unlock()
 }
 
 func (sshClient *sshClient) handleTCPChannel(
@@ -879,37 +1051,35 @@ func (sshClient *sshClient) handleTCPChannel(
 	portToConnect int,
 	newChannel ssh.NewChannel) {
 
-	if !sshClient.isPortForwardPermitted(
-		hostToConnect,
-		portToConnect,
-		sshClient.trafficRules.AllowTCPPorts,
-		sshClient.trafficRules.DenyTCPPorts) {
+	isWebServerPortForward := false
+	config := sshClient.sshServer.support.Config
+	if config.WebServerPortForwardAddress != "" {
+		destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
+		if destination == config.WebServerPortForwardAddress {
+			isWebServerPortForward = true
+			if config.WebServerPortForwardRedirectAddress != "" {
+				// Note: redirect format is validated when config is loaded
+				host, portStr, _ := net.SplitHostPort(config.WebServerPortForwardRedirectAddress)
+				port, _ := strconv.Atoi(portStr)
+				hostToConnect = host
+				portToConnect = port
+			}
+		}
+	}
+
+	if !isWebServerPortForward && !sshClient.isPortForwardPermitted(
+		portForwardTypeTCP, hostToConnect, portToConnect) {
 
 		sshClient.rejectNewChannel(
 			newChannel, ssh.Prohibited, "port forward not permitted")
 		return
 	}
 
-	// Note: redirects are applied *after* isPortForwardPermitted allows the original destination
-	if sshClient.sshServer.support.Config.TCPPortForwardRedirects != nil {
-		destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
-		if redirect, ok := sshClient.sshServer.support.Config.TCPPortForwardRedirects[destination]; ok {
-			// Note: redirect format is validated when config is loaded
-			host, portStr, _ := net.SplitHostPort(redirect)
-			port, _ := strconv.Atoi(portStr)
-			hostToConnect = host
-			portToConnect = port
-			log.WithContextFields(LogFields{"destination": destination, "redirect": redirect}).Debug("port forward redirect")
-		}
-	}
-
 	var bytesUp, bytesDown int64
-	sshClient.openedPortForward(sshClient.tcpTrafficState)
+	sshClient.openedPortForward(portForwardTypeTCP)
 	defer func() {
 		sshClient.closedPortForward(
-			sshClient.tcpTrafficState,
-			atomic.LoadInt64(&bytesUp),
-			atomic.LoadInt64(&bytesDown))
+			portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
 	}()
 
 	// TOCTOU note: important to increment the port forward count (via
@@ -918,9 +1088,7 @@ func (sshClient *sshClient) handleTCPChannel(
 	// by initiating many port forwards concurrently.
 	// TODO: close LRU connection (after successful Dial) instead of
 	// rejecting new connection?
-	if sshClient.isPortForwardLimitExceeded(
-		sshClient.tcpTrafficState,
-		sshClient.trafficRules.MaxTCPPortForwardCount) {
+	if maxCount, exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
 
 		// Close the oldest TCP port forward. CloseOldest() closes
 		// the conn and the port forward's goroutine will complete
@@ -952,7 +1120,7 @@ func (sshClient *sshClient) handleTCPChannel(
 
 		log.WithContextFields(
 			LogFields{
-				"maxCount": sshClient.trafficRules.MaxTCPPortForwardCount,
+				"maxCount": maxCount,
 			}).Debug("closed LRU TCP port forward")
 	}
 
@@ -1015,7 +1183,7 @@ func (sshClient *sshClient) handleTCPChannel(
 
 	fwdConn, err = common.NewActivityMonitoredConn(
 		fwdConn,
-		time.Duration(sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds)*time.Millisecond,
+		sshClient.idleTCPPortForwardTimeout(),
 		true,
 		lruEntry)
 	if result.err != nil {

+ 8 - 15
psiphon/server/udp.go

@@ -28,7 +28,6 @@ import (
 	"runtime/debug"
 	"sync"
 	"sync/atomic"
-	"time"
 
 	"github.com/Psiphon-Inc/crypto/ssh"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
@@ -162,23 +161,18 @@ func (mux *udpPortForwardMultiplexer) run() {
 			}
 
 			if !mux.sshClient.isPortForwardPermitted(
-				dialIP.String(),
-				int(message.remotePort),
-				mux.sshClient.trafficRules.AllowUDPPorts,
-				mux.sshClient.trafficRules.DenyUDPPorts) {
+				portForwardTypeUDP, dialIP.String(), int(message.remotePort)) {
 				// The udpgw protocol has no error response, so
 				// we just discard the message and read another.
 				continue
 			}
 
-			mux.sshClient.openedPortForward(mux.sshClient.udpTrafficState)
+			mux.sshClient.openedPortForward(portForwardTypeUDP)
 			// Note: can't defer sshClient.closedPortForward() here
 
 			// TOCTOU note: important to increment the port forward count (via
 			// openPortForward) _before_ checking isPortForwardLimitExceeded
-			if mux.sshClient.isPortForwardLimitExceeded(
-				mux.sshClient.tcpTrafficState,
-				mux.sshClient.trafficRules.MaxUDPPortForwardCount) {
+			if maxCount, exceeded := mux.sshClient.isPortForwardLimitExceeded(portForwardTypeUDP); exceeded {
 
 				// Close the oldest UDP port forward. CloseOldest() closes
 				// the conn and the port forward's goroutine will complete
@@ -190,7 +184,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 
 				log.WithContextFields(
 					LogFields{
-						"maxCount": mux.sshClient.trafficRules.MaxUDPPortForwardCount,
+						"maxCount": maxCount,
 					}).Debug("closed LRU UDP port forward")
 			}
 
@@ -203,7 +197,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 			udpConn, err := net.DialUDP(
 				"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
 			if err != nil {
-				mux.sshClient.closedPortForward(mux.sshClient.udpTrafficState, 0, 0)
+				mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0)
 				log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed")
 				continue
 			}
@@ -217,12 +211,12 @@ func (mux *udpPortForwardMultiplexer) run() {
 
 			conn, err := common.NewActivityMonitoredConn(
 				udpConn,
-				time.Duration(mux.sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds)*time.Millisecond,
+				mux.sshClient.idleUDPPortForwardTimeout(),
 				true,
 				lruEntry)
 			if err != nil {
 				lruEntry.Remove()
-				mux.sshClient.closedPortForward(mux.sshClient.udpTrafficState, 0, 0)
+				mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0)
 				log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
 				continue
 			}
@@ -354,8 +348,7 @@ func (portForward *udpPortForward) relayDownstream() {
 
 	bytesUp := atomic.LoadInt64(&portForward.bytesUp)
 	bytesDown := atomic.LoadInt64(&portForward.bytesDown)
-	portForward.mux.sshClient.closedPortForward(
-		portForward.mux.sshClient.udpTrafficState, bytesUp, bytesDown)
+	portForward.mux.sshClient.closedPortForward(portForwardTypeUDP, bytesUp, bytesDown)
 
 	log.WithContextFields(
 		LogFields{

+ 7 - 2
psiphon/server/webServer.go

@@ -36,8 +36,9 @@ import (
 const WEB_SERVER_IO_TIMEOUT = 10 * time.Second
 
 type webServer struct {
-	support  *SupportServices
-	serveMux *http.ServeMux
+	support      *SupportServices
+	tunnelServer *TunnelServer
+	serveMux     *http.ServeMux
 }
 
 // RunWebServer runs a web server which supports tunneled and untunneled
@@ -233,6 +234,7 @@ func (webServer *webServer) handshakeHandler(w http.ResponseWriter, r *http.Requ
 	if err == nil {
 		responsePayload, err = dispatchAPIRequestHandler(
 			webServer.support,
+			common.PSIPHON_WEB_API_PROTOCOL,
 			webServer.lookupGeoIPData(params),
 			common.PSIPHON_API_HANDSHAKE_REQUEST_NAME,
 			params)
@@ -262,6 +264,7 @@ func (webServer *webServer) connectedHandler(w http.ResponseWriter, r *http.Requ
 	if err == nil {
 		responsePayload, err = dispatchAPIRequestHandler(
 			webServer.support,
+			common.PSIPHON_WEB_API_PROTOCOL,
 			webServer.lookupGeoIPData(params),
 			common.PSIPHON_API_CONNECTED_REQUEST_NAME,
 			params)
@@ -284,6 +287,7 @@ func (webServer *webServer) statusHandler(w http.ResponseWriter, r *http.Request
 	if err == nil {
 		_, err = dispatchAPIRequestHandler(
 			webServer.support,
+			common.PSIPHON_WEB_API_PROTOCOL,
 			webServer.lookupGeoIPData(params),
 			common.PSIPHON_API_STATUS_REQUEST_NAME,
 			params)
@@ -306,6 +310,7 @@ func (webServer *webServer) clientVerificationHandler(w http.ResponseWriter, r *
 	if err == nil {
 		responsePayload, err = dispatchAPIRequestHandler(
 			webServer.support,
+			common.PSIPHON_WEB_API_PROTOCOL,
 			webServer.lookupGeoIPData(params),
 			common.PSIPHON_API_CLIENT_VERIFICATION_REQUEST_NAME,
 			params)

+ 8 - 15
psiphon/serverApi.go

@@ -87,7 +87,9 @@ func NewServerContext(tunnel *Tunnel, sessionId string) (*ServerContext, error)
 	// For legacy servers, set up psiphonHttpsClient for
 	// accessing the Psiphon API via the web service.
 	var psiphonHttpsClient *http.Client
-	if !tunnel.serverEntry.SupportsSSHAPIRequests() {
+	if !tunnel.serverEntry.SupportsSSHAPIRequests() ||
+		tunnel.config.TargetApiProtocol == common.PSIPHON_WEB_API_PROTOCOL {
+
 		var err error
 		psiphonHttpsClient, err = makePsiphonHttpsClient(tunnel)
 		if err != nil {
@@ -167,18 +169,11 @@ func (serverContext *ServerContext) doHandshakeRequest() error {
 		}
 	}
 
-	// Note:
-	// - 'preemptive_reconnect_lifetime_milliseconds' is currently unused
+	// Legacy fields:
+	// - 'preemptive_reconnect_lifetime_milliseconds' is unused and ignored
 	// - 'ssh_session_id' is ignored; client session ID is used instead
-	var handshakeResponse struct {
-		Homepages            []string            `json:"homepages"`
-		UpgradeClientVersion string              `json:"upgrade_client_version"`
-		PageViewRegexes      []map[string]string `json:"page_view_regexes"`
-		HttpsRequestRegexes  []map[string]string `json:"https_request_regexes"`
-		EncodedServerList    []string            `json:"encoded_server_list"`
-		ClientRegion         string              `json:"client_region"`
-		ServerTimestamp      string              `json:"server_timestamp"`
-	}
+
+	var handshakeResponse common.HandshakeResponse
 	err := json.Unmarshal(response, &handshakeResponse)
 	if err != nil {
 		return common.ContextError(err)
@@ -292,9 +287,7 @@ func (serverContext *ServerContext) DoConnectedRequest() error {
 		}
 	}
 
-	var connectedResponse struct {
-		ConnectedTimestamp string `json:"connected_timestamp"`
-	}
+	var connectedResponse common.ConnectedResponse
 	err = json.Unmarshal(response, &connectedResponse)
 	if err != nil {
 		return common.ContextError(err)