ソースを参照

Merged to master.

Miro Kuratczyk 9 年 前
コミット
5bdf3171ce

+ 21 - 0
AndroidLibrary/psi/psi.go

@@ -39,6 +39,7 @@ type PsiphonProvider interface {
 	GetSecondaryDnsServer() string
 }
 
+var controllerMutex sync.Mutex
 var controller *psiphon.Controller
 var shutdownBroadcast chan struct{}
 var controllerWaitGroup *sync.WaitGroup
@@ -48,6 +49,9 @@ func Start(
 	provider PsiphonProvider,
 	useDeviceBinder bool) error {
 
+	controllerMutex.Lock()
+	defer controllerMutex.Unlock()
+
 	if controller != nil {
 		return fmt.Errorf("already started")
 	}
@@ -106,6 +110,10 @@ func Start(
 }
 
 func Stop() {
+
+	controllerMutex.Lock()
+	defer controllerMutex.Unlock()
+
 	if controller != nil {
 		close(shutdownBroadcast)
 		controllerWaitGroup.Wait()
@@ -114,3 +122,16 @@ func Stop() {
 		controllerWaitGroup = nil
 	}
 }
+
+// This is a passthrough to Controller.SetClientVerificationPayload.
+// Note: should only be called after Start() and before Stop(); otherwise,
+// will silently take no action.
+func SetClientVerificationPayload(clientVerificationPayload string) {
+
+	controllerMutex.Lock()
+	defer controllerMutex.Unlock()
+
+	if controller != nil {
+		controller.SetClientVerificationPayload(clientVerificationPayload)
+	}
+}

+ 3 - 1
SampleApps/Psibot/app/src/main/java/ca/psiphon/PsiphonTunnel.java

@@ -348,7 +348,9 @@ public class PsiphonTunnel extends Psi.PsiphonProvider.Stub {
         // The Psiphon library won't be able to use its current working directory
         // and the standard temporary directories do not exist.
         json.put("DataStoreDirectory", context.getFilesDir());
-        json.put("DataStoreTempDirectory", context.getCacheDir());
+
+        File remoteServerListDownload = new File(context.getFilesDir(), "remote_server_list");
+        json.put("RemoteServerListDownloadFilename", remoteServerListDownload.getPath());
 
         // Note: onConnecting/onConnected logic assumes 1 tunnel connection
         json.put("TunnelPoolSize", 1);

+ 3 - 1
SampleApps/TunneledWebView/app/src/main/java/ca/psiphon/PsiphonTunnel.java

@@ -197,7 +197,9 @@ public class PsiphonTunnel extends Psi.PsiphonProvider.Stub {
         // The Psiphon library won't be able to use its current working directory
         // and the standard temporary directories do not exist.
         json.put("DataStoreDirectory", context.getFilesDir());
-        json.put("DataStoreTempDirectory", context.getCacheDir());
+
+        File remoteServerListDownload = new File(context.getFilesDir(), "remote_server_list");
+        json.put("RemoteServerListDownloadFilename", remoteServerListDownload.getPath());
 
         // Note: onConnecting/onConnected logic assumes 1 tunnel connection
         json.put("TunnelPoolSize", 1);

+ 42 - 41
psiphon/config.go

@@ -31,47 +31,48 @@ import (
 // TODO: allow all params to be configured
 
 const (
-	LEGACY_DATA_STORE_FILENAME                     = "psiphon.db"
-	DATA_STORE_FILENAME                            = "psiphon.boltdb"
-	CONNECTION_WORKER_POOL_SIZE                    = 10
-	TUNNEL_POOL_SIZE                               = 1
-	TUNNEL_CONNECT_TIMEOUT_SECONDS                 = 20
-	TUNNEL_OPERATE_SHUTDOWN_TIMEOUT                = 1 * time.Second
-	TUNNEL_PORT_FORWARD_DIAL_TIMEOUT_SECONDS       = 10
-	TUNNEL_SSH_KEEP_ALIVE_PAYLOAD_MAX_BYTES        = 256
-	TUNNEL_SSH_KEEP_ALIVE_PERIOD_MIN               = 60 * time.Second
-	TUNNEL_SSH_KEEP_ALIVE_PERIOD_MAX               = 120 * time.Second
-	TUNNEL_SSH_KEEP_ALIVE_PERIODIC_TIMEOUT_SECONDS = 30
-	TUNNEL_SSH_KEEP_ALIVE_PERIODIC_INACTIVE_PERIOD = 10 * time.Second
-	TUNNEL_SSH_KEEP_ALIVE_PROBE_TIMEOUT_SECONDS    = 5
-	TUNNEL_SSH_KEEP_ALIVE_PROBE_INACTIVE_PERIOD    = 10 * time.Second
-	ESTABLISH_TUNNEL_TIMEOUT_SECONDS               = 300
-	ESTABLISH_TUNNEL_WORK_TIME                     = 60 * time.Second
-	ESTABLISH_TUNNEL_PAUSE_PERIOD_SECONDS          = 5
-	ESTABLISH_TUNNEL_SERVER_AFFINITY_GRACE_PERIOD  = 1 * time.Second
-	HTTP_PROXY_ORIGIN_SERVER_TIMEOUT_SECONDS       = 15
-	HTTP_PROXY_MAX_IDLE_CONNECTIONS_PER_HOST       = 50
-	FETCH_REMOTE_SERVER_LIST_TIMEOUT_SECONDS       = 30
-	FETCH_REMOTE_SERVER_LIST_RETRY_PERIOD_SECONDS  = 30
-	FETCH_REMOTE_SERVER_LIST_STALE_PERIOD          = 6 * time.Hour
-	PSIPHON_API_CLIENT_SESSION_ID_LENGTH           = 16
-	PSIPHON_API_SERVER_TIMEOUT_SECONDS             = 20
-	PSIPHON_API_SHUTDOWN_SERVER_TIMEOUT            = 1 * time.Second
-	PSIPHON_API_STATUS_REQUEST_PERIOD_MIN          = 5 * time.Minute
-	PSIPHON_API_STATUS_REQUEST_PERIOD_MAX          = 10 * time.Minute
-	PSIPHON_API_STATUS_REQUEST_SHORT_PERIOD_MIN    = 5 * time.Second
-	PSIPHON_API_STATUS_REQUEST_SHORT_PERIOD_MAX    = 10 * time.Second
-	PSIPHON_API_STATUS_REQUEST_PADDING_MAX_BYTES   = 256
-	PSIPHON_API_CONNECTED_REQUEST_PERIOD           = 24 * time.Hour
-	PSIPHON_API_CONNECTED_REQUEST_RETRY_PERIOD     = 5 * time.Second
-	PSIPHON_API_TUNNEL_STATS_MAX_COUNT             = 1000
-	FETCH_ROUTES_TIMEOUT_SECONDS                   = 60
-	DOWNLOAD_UPGRADE_TIMEOUT                       = 15 * time.Minute
-	DOWNLOAD_UPGRADE_RETRY_PERIOD_SECONDS          = 30
-	DOWNLOAD_UPGRADE_STALE_PERIOD                  = 6 * time.Hour
-	IMPAIRED_PROTOCOL_CLASSIFICATION_DURATION      = 2 * time.Minute
-	IMPAIRED_PROTOCOL_CLASSIFICATION_THRESHOLD     = 3
-	TOTAL_BYTES_TRANSFERRED_NOTICE_PERIOD          = 5 * time.Minute
+	LEGACY_DATA_STORE_FILENAME                           = "psiphon.db"
+	DATA_STORE_FILENAME                                  = "psiphon.boltdb"
+	CONNECTION_WORKER_POOL_SIZE                          = 10
+	TUNNEL_POOL_SIZE                                     = 1
+	TUNNEL_CONNECT_TIMEOUT_SECONDS                       = 20
+	TUNNEL_OPERATE_SHUTDOWN_TIMEOUT                      = 1 * time.Second
+	TUNNEL_PORT_FORWARD_DIAL_TIMEOUT_SECONDS             = 10
+	TUNNEL_SSH_KEEP_ALIVE_PAYLOAD_MAX_BYTES              = 256
+	TUNNEL_SSH_KEEP_ALIVE_PERIOD_MIN                     = 60 * time.Second
+	TUNNEL_SSH_KEEP_ALIVE_PERIOD_MAX                     = 120 * time.Second
+	TUNNEL_SSH_KEEP_ALIVE_PERIODIC_TIMEOUT_SECONDS       = 30
+	TUNNEL_SSH_KEEP_ALIVE_PERIODIC_INACTIVE_PERIOD       = 10 * time.Second
+	TUNNEL_SSH_KEEP_ALIVE_PROBE_TIMEOUT_SECONDS          = 5
+	TUNNEL_SSH_KEEP_ALIVE_PROBE_INACTIVE_PERIOD          = 10 * time.Second
+	ESTABLISH_TUNNEL_TIMEOUT_SECONDS                     = 300
+	ESTABLISH_TUNNEL_WORK_TIME                           = 60 * time.Second
+	ESTABLISH_TUNNEL_PAUSE_PERIOD_SECONDS                = 5
+	ESTABLISH_TUNNEL_SERVER_AFFINITY_GRACE_PERIOD        = 1 * time.Second
+	HTTP_PROXY_ORIGIN_SERVER_TIMEOUT_SECONDS             = 15
+	HTTP_PROXY_MAX_IDLE_CONNECTIONS_PER_HOST             = 50
+	FETCH_REMOTE_SERVER_LIST_TIMEOUT_SECONDS             = 30
+	FETCH_REMOTE_SERVER_LIST_RETRY_PERIOD_SECONDS        = 30
+	FETCH_REMOTE_SERVER_LIST_STALE_PERIOD                = 6 * time.Hour
+	PSIPHON_API_CLIENT_SESSION_ID_LENGTH                 = 16
+	PSIPHON_API_SERVER_TIMEOUT_SECONDS                   = 20
+	PSIPHON_API_SHUTDOWN_SERVER_TIMEOUT                  = 1 * time.Second
+	PSIPHON_API_STATUS_REQUEST_PERIOD_MIN                = 5 * time.Minute
+	PSIPHON_API_STATUS_REQUEST_PERIOD_MAX                = 10 * time.Minute
+	PSIPHON_API_STATUS_REQUEST_SHORT_PERIOD_MIN          = 5 * time.Second
+	PSIPHON_API_STATUS_REQUEST_SHORT_PERIOD_MAX          = 10 * time.Second
+	PSIPHON_API_STATUS_REQUEST_PADDING_MAX_BYTES         = 256
+	PSIPHON_API_CONNECTED_REQUEST_PERIOD                 = 24 * time.Hour
+	PSIPHON_API_CONNECTED_REQUEST_RETRY_PERIOD           = 5 * time.Second
+	PSIPHON_API_TUNNEL_STATS_MAX_COUNT                   = 1000
+	PSIPHON_API_CLIENT_VERIFICATION_REQUEST_RETRY_PERIOD = 5 * time.Second
+	FETCH_ROUTES_TIMEOUT_SECONDS                         = 60
+	DOWNLOAD_UPGRADE_TIMEOUT                             = 15 * time.Minute
+	DOWNLOAD_UPGRADE_RETRY_PERIOD_SECONDS                = 30
+	DOWNLOAD_UPGRADE_STALE_PERIOD                        = 6 * time.Hour
+	IMPAIRED_PROTOCOL_CLASSIFICATION_DURATION            = 2 * time.Minute
+	IMPAIRED_PROTOCOL_CLASSIFICATION_THRESHOLD           = 3
+	TOTAL_BYTES_TRANSFERRED_NOTICE_PERIOD                = 5 * time.Minute
 )
 
 // To distinguish omitted timeout params from explicit 0 value timeout

+ 52 - 1
psiphon/controller.go

@@ -60,6 +60,7 @@ type Controller struct {
 	impairedProtocolClassification map[string]int
 	signalReportConnected          chan struct{}
 	serverAffinityDoneBroadcast    chan struct{}
+	newClientVerificationPayload   chan string
 }
 
 type candidateServerEntry struct {
@@ -107,7 +108,7 @@ func NewController(config *Config) (controller *Controller, err error) {
 		sessionId: sessionId,
 		// componentFailureSignal receives a signal from a component (including socks and
 		// http local proxies) if they unexpectedly fail. Senders should not block.
-		// A buffer allows at least one stop signal to be sent before there is a receiver.
+		// Buffer allows at least one stop signal to be sent before there is a receiver.
 		componentFailureSignal: make(chan struct{}, 1),
 		shutdownBroadcast:      make(chan struct{}),
 		runWaitGroup:           new(sync.WaitGroup),
@@ -129,6 +130,9 @@ func NewController(config *Config) (controller *Controller, err error) {
 		signalFetchRemoteServerList: make(chan struct{}),
 		signalDownloadUpgrade:       make(chan string),
 		signalReportConnected:       make(chan struct{}),
+		// Buffer allows SetClientVerificationPayload to submit one new payload
+		// without blocking or dropping it.
+		newClientVerificationPayload: make(chan string, 1),
 	}
 
 	controller.splitTunnelClassifier = NewSplitTunnelClassifier(config, controller)
@@ -242,6 +246,31 @@ func (controller *Controller) SignalComponentFailure() {
 	}
 }
 
+// SetClientVerificationPayload sets the client verification payload
+// that is to be sent in client verification requests to all established
+// tunnels. Calling this function both sets the payload to be used for
+// all future tunnels as wells as triggering requests with this payload
+// for all currently established tunneled.
+//
+// Client verification is used to verify that the client is a
+// valid Psiphon client, which will determine how the server treats
+// the client traffic. The proof-of-validity is platform-specific
+// and the payload is opaque to this function but assumed to be JSON.
+//
+// Since, in some cases, verification payload cannot be determined until
+// after tunnel-core starts, the payload cannot be simply specified in
+// the Config.
+//
+// SetClientVerificationPayload will not block enqueuing a new verification
+// payload. One new payload can be enqueued, after which additional payloads
+// will be dropped if a payload is still enqueued.
+func (controller *Controller) SetClientVerificationPayload(clientVerificationPayload string) {
+	select {
+	case controller.newClientVerificationPayload <- clientVerificationPayload:
+	default:
+	}
+}
+
 // remoteServerListFetcher fetches an out-of-band list of server entries
 // for more tunnel candidates. It fetches when signalled, with retries
 // on failure.
@@ -498,6 +527,8 @@ downloadLoop:
 func (controller *Controller) runTunnels() {
 	defer controller.runWaitGroup.Done()
 
+	var clientVerificationPayload string
+
 	// Start running
 
 	controller.startEstablishing()
@@ -555,6 +586,10 @@ loop:
 				break
 			}
 
+			if clientVerificationPayload != "" {
+				establishedTunnel.SetClientVerificationPayload(clientVerificationPayload)
+			}
+
 			NoticeActiveTunnel(establishedTunnel.serverEntry.IpAddress, establishedTunnel.protocol)
 
 			if tunnelCount == 1 {
@@ -597,6 +632,9 @@ loop:
 				controller.stopEstablishing()
 			}
 
+		case clientVerificationPayload = <-controller.newClientVerificationPayload:
+			controller.setClientVerificationPayloadForActiveTunnels(clientVerificationPayload)
+
 		case <-controller.shutdownBroadcast:
 			break loop
 		}
@@ -817,6 +855,19 @@ func (controller *Controller) isActiveTunnelServerEntry(serverEntry *ServerEntry
 	return false
 }
 
+// setClientVerificationPayloadForActiveTunnels triggers the client verification
+// request for all active tunnels.
+func (controller *Controller) setClientVerificationPayloadForActiveTunnels(
+	clientVerificationPayload string) {
+
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+
+	for _, activeTunnel := range controller.tunnels {
+		activeTunnel.SetClientVerificationPayload(clientVerificationPayload)
+	}
+}
+
 // Dial selects an active tunnel and establishes a port forward
 // connection through the selected tunnel. Failure to connect is considered
 // a port foward failure, for the purpose of monitoring tunnel health.

+ 46 - 0
psiphon/net.go

@@ -35,6 +35,7 @@ import (
 	"time"
 
 	"github.com/Psiphon-Inc/dns"
+	"github.com/Psiphon-Inc/ratelimit"
 )
 
 const DNS_PORT = 53
@@ -617,3 +618,48 @@ func (conn *IdleTimeoutConn) Write(buffer []byte) (int, error) {
 	}
 	return n, err
 }
+
+// ThrottledConn wraps a net.Conn with read and write rate limiters.
+// Rates are specified as bytes per second. The underlying rate limiter
+// uses the token bucket algorithm to calculate delay times for read
+// and write operations. Specify limit values of 0 set no limit.
+type ThrottledConn struct {
+	net.Conn
+	reader io.Reader
+	writer io.Writer
+}
+
+func NewThrottledConn(
+	conn net.Conn,
+	limitReadBytesPerSecond, limitWriteBytesPerSecond int64) *ThrottledConn {
+
+	var reader io.Reader
+	if limitReadBytesPerSecond == 0 {
+		reader = conn
+	} else {
+		reader = ratelimit.Reader(conn,
+			ratelimit.NewBucketWithRate(
+				float64(limitReadBytesPerSecond), limitReadBytesPerSecond))
+	}
+	var writer io.Writer
+	if limitWriteBytesPerSecond == 0 {
+		writer = conn
+	} else {
+		writer = ratelimit.Writer(conn,
+			ratelimit.NewBucketWithRate(
+				float64(limitWriteBytesPerSecond), limitWriteBytesPerSecond))
+	}
+	return &ThrottledConn{
+		Conn:   conn,
+		reader: reader,
+		writer: writer,
+	}
+}
+
+func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
+	return conn.reader.Read(buffer)
+}
+
+func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
+	return conn.writer.Write(buffer)
+}

+ 4 - 0
psiphon/notice.go

@@ -337,6 +337,10 @@ func NoticeRemoteServerListDownloaded(filename string) {
 	outputNotice("RemoteServerListDownloaded", false, false, "filename", filename)
 }
 
+func NoticeClientVerificationRequestCompleted(ipAddress string) {
+	outputNotice("NoticeClientVerificationRequestCompleted", true, false, "ipAddress", ipAddress)
+}
+
 type repetitiveNoticeState struct {
 	message string
 	repeats int

+ 74 - 15
psiphon/server/config.go

@@ -29,6 +29,8 @@ import (
 	"errors"
 	"fmt"
 	"math/big"
+	"net"
+	"strconv"
 	"strings"
 	"time"
 
@@ -55,7 +57,6 @@ const (
 	DEFAULT_SSH_SERVER_PORT                = 2222
 	SSH_HANDSHAKE_TIMEOUT                  = 30 * time.Second
 	SSH_CONNECTION_READ_DEADLINE           = 5 * time.Minute
-	SSH_THROTTLED_PORT_FORWARD_MAX_COPY    = 32 * 1024
 	SSH_OBFUSCATED_KEY_BYTE_LENGTH         = 32
 	DEFAULT_OBFUSCATED_SSH_SERVER_PORT     = 3333
 	REDIS_POOL_MAX_IDLE                    = 50
@@ -168,21 +169,34 @@ type Config struct {
 	// client IP address. The key for each regional traffic rule entry
 	// is one or more space delimited ISO 3166-1 alpha-2 country codes.
 	RegionalTrafficRules map[string]TrafficRules
+
+	// DNSServerAddress specifies the network address of a DNS server
+	// to which DNS UDP packets will be forwarded to. When set, any
+	// tunneled DNS UDP packets will be re-routed to this destination.
+	DNSServerAddress string
+
+	// UdpgwServerAddress specifies the network address of a udpgw
+	// server which clients may be port forwarding to. When specified,
+	// these TCP port forwards are intercepted and handled directly
+	// by this server, which parses the SSH channel using the udpgw
+	// protocol.
+	UdpgwServerAddress string
 }
 
-// TrafficRules specify the limits placed on SSH client port forward
-// traffic.
+// TrafficRules specify the limits placed on client traffic.
 type TrafficRules struct {
 
-	// ThrottleUpstreamSleepMilliseconds is the period to sleep
-	// between sending each chunk of client->destination traffic.
-	// The default, 0, is no sleep.
-	ThrottleUpstreamSleepMilliseconds int
+	// LimitDownstreamBytesPerSecond specifies a rate limit for
+	// downstream data transfer between a single client and the
+	// server.
+	// The default, 0, is no rate limit.
+	LimitDownstreamBytesPerSecond int
 
-	// ThrottleDownstreamSleepMilliseconds is the period to sleep
-	// between sending each chunk of destination->client traffic.
-	// The default, 0, is no sleep.
-	ThrottleDownstreamSleepMilliseconds int
+	// LimitUpstreamBytesPerSecond specifies a rate limit for
+	// upstream data transfer between a single client and the
+	// server.
+	// The default, 0, is no rate limit.
+	LimitUpstreamBytesPerSecond int
 
 	// IdlePortForwardTimeoutMilliseconds is the timeout period
 	// after which idle (no bytes flowing in either direction)
@@ -190,10 +204,35 @@ type TrafficRules struct {
 	// The default, 0, is no idle timeout.
 	IdlePortForwardTimeoutMilliseconds int
 
-	// MaxClientPortForwardCount is the maximum number of port
+	// MaxTCPPortForwardCount is the maximum number of TCP port
+	// forwards each client may have open concurrently.
+	// The default, 0, is no maximum.
+	MaxTCPPortForwardCount int
+
+	// MaxUDPPortForwardCount is the maximum number of UDP port
 	// forwards each client may have open concurrently.
 	// The default, 0, is no maximum.
-	MaxClientPortForwardCount int
+	MaxUDPPortForwardCount int
+
+	// AllowTCPPorts specifies a whitelist of TCP ports that
+	// are permitted for port forwarding. When set, only ports
+	// in the list are accessible to clients.
+	AllowTCPPorts []int
+
+	// AllowUDPPorts specifies a whitelist of UDP ports that
+	// are permitted for port forwarding. When set, only ports
+	// in the list are accessible to clients.
+	AllowUDPPorts []int
+
+	// DenyTCPPorts specifies a blacklist of TCP ports that
+	// are not permitted for port forwarding. When set, the
+	// ports in the list are inaccessible to clients.
+	DenyTCPPorts []int
+
+	// DenyUDPPorts specifies a blacklist of UDP ports that
+	// are not permitted for port forwarding. When set, the
+	// ports in the list are inaccessible to clients.
+	DenyUDPPorts []int
 }
 
 // RunWebServer indicates whether to run a web server component.
@@ -266,7 +305,7 @@ func LoadConfig(configJSONs [][]byte) (*Config, error) {
 		config.WebServerPrivateKey == "") {
 
 		return nil, errors.New(
-			"web server requires WebServerSecret, WebServerCertificate, WebServerPrivateKey")
+			"Web server requires WebServerSecret, WebServerCertificate, WebServerPrivateKey")
 	}
 
 	if config.SSHServerPort > 0 && (config.SSHPrivateKey == "" || config.SSHServerVersion == "" ||
@@ -283,6 +322,26 @@ func LoadConfig(configJSONs [][]byte) (*Config, error) {
 			"Obfuscated SSH server requires SSHPrivateKey, SSHServerVersion, SSHUserName, SSHPassword, ObfuscatedSSHKey")
 	}
 
+	validateNetworkAddress := func(address string) error {
+		_, portStr, err := net.SplitHostPort(config.DNSServerAddress)
+		if err == nil {
+			_, err = strconv.Atoi(portStr)
+		}
+		return err
+	}
+
+	if config.DNSServerAddress != "" {
+		if err := validateNetworkAddress(config.DNSServerAddress); err != nil {
+			return nil, fmt.Errorf("DNSServerAddress is invalid: %s", err)
+		}
+	}
+
+	if config.UdpgwServerAddress != "" {
+		if err := validateNetworkAddress(config.UdpgwServerAddress); err != nil {
+			return nil, fmt.Errorf("UdpgwServerAddress is invalid: %s", err)
+		}
+	}
+
 	return &config, nil
 }
 
@@ -467,7 +526,7 @@ func generateWebServerCertificate() (string, string, error) {
 	notAfter := notBefore.Add(WEB_SERVER_CERTIFICATE_VALIDITY_PERIOD)
 
 	// TODO: psi_ops_install sets serial number to 0?
-	// TOSO: psi_ops_install sets RSA exponent to 3, digest type to 'sha1', and version to 2?
+	// TODO: psi_ops_install sets RSA exponent to 3, digest type to 'sha1', and version to 2?
 
 	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
 	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)

+ 154 - 83
psiphon/server/sshService.go

@@ -211,16 +211,20 @@ func (sshServer *sshServer) stopClient(client *sshClient) {
 	client.Lock()
 	log.WithContextFields(
 		LogFields{
-			"startTime":                      client.startTime,
-			"duration":                       time.Now().Sub(client.startTime),
-			"psiphonSessionID":               client.psiphonSessionID,
-			"country":                        client.geoIPData.Country,
-			"city":                           client.geoIPData.City,
-			"ISP":                            client.geoIPData.ISP,
-			"bytesUp":                        client.bytesUp,
-			"bytesDown":                      client.bytesDown,
-			"portForwardCount":               client.portForwardCount,
-			"peakConcurrentPortForwardCount": client.peakConcurrentPortForwardCount,
+			"startTime":                         client.startTime,
+			"duration":                          time.Now().Sub(client.startTime),
+			"psiphonSessionID":                  client.psiphonSessionID,
+			"country":                           client.geoIPData.Country,
+			"city":                              client.geoIPData.City,
+			"ISP":                               client.geoIPData.ISP,
+			"bytesUpTCP":                        client.tcpTrafficState.bytesUp,
+			"bytesDownTCP":                      client.tcpTrafficState.bytesDown,
+			"portForwardCountTCP":               client.tcpTrafficState.portForwardCount,
+			"peakConcurrentPortForwardCountTCP": client.tcpTrafficState.peakConcurrentPortForwardCount,
+			"bytesUpUDP":                        client.udpTrafficState.bytesUp,
+			"bytesDownUDP":                      client.udpTrafficState.bytesDown,
+			"portForwardCountUDP":               client.udpTrafficState.portForwardCount,
+			"peakConcurrentPortForwardCountUDP": client.udpTrafficState.peakConcurrentPortForwardCount,
 		}).Info("tunnel closed")
 	client.Unlock()
 }
@@ -239,12 +243,16 @@ func (sshServer *sshServer) stopClients() {
 
 func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 
+	geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr()))
+
 	sshClient := &sshClient{
-		sshServer: sshServer,
-		startTime: time.Now(),
-		geoIPData: GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr())),
+		sshServer:       sshServer,
+		startTime:       time.Now(),
+		geoIPData:       geoIPData,
+		trafficRules:    sshServer.config.GetTrafficRules(geoIPData.Country),
+		tcpTrafficState: &trafficState{},
+		udpTrafficState: &trafficState{},
 	}
-	sshClient.trafficRules = sshServer.config.GetTrafficRules(sshClient.geoIPData.Country)
 
 	// Wrap the base TCP connection with an IdleTimeoutConn which will terminate
 	// the connection if no data is received before the deadline. This timeout is
@@ -252,7 +260,16 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 	// use the connection or send SSH keep alive requests to keep the connection
 	// active.
 
-	conn := psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE, false)
+	var conn net.Conn
+
+	conn = psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE, false)
+
+	// Further wrap the connection in a rate limiting ThrottledConn.
+
+	conn = psiphon.NewThrottledConn(
+		conn,
+		int64(sshClient.trafficRules.LimitDownstreamBytesPerSecond),
+		int64(sshClient.trafficRules.LimitUpstreamBytesPerSecond))
 
 	// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
 	// respect shutdownBroadcast and implement a specific handshake timeout.
@@ -334,12 +351,18 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 
 type sshClient struct {
 	sync.Mutex
-	sshServer                      *sshServer
-	sshConn                        ssh.Conn
-	startTime                      time.Time
-	geoIPData                      GeoIPData
-	trafficRules                   TrafficRules
-	psiphonSessionID               string
+	sshServer        *sshServer
+	sshConn          ssh.Conn
+	startTime        time.Time
+	geoIPData        GeoIPData
+	psiphonSessionID string
+	udpChannel       ssh.Channel
+	trafficRules     TrafficRules
+	tcpTrafficState  *trafficState
+	udpTrafficState  *trafficState
+}
+
+type trafficState struct {
 	bytesUp                        int64
 	bytesDown                      int64
 	portForwardCount               int64
@@ -355,20 +378,8 @@ func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
 			return
 		}
 
-		if sshClient.trafficRules.MaxClientPortForwardCount > 0 {
-			sshClient.Lock()
-			limitExceeded := sshClient.portForwardCount >= int64(sshClient.trafficRules.MaxClientPortForwardCount)
-			sshClient.Unlock()
-
-			if limitExceeded {
-				sshClient.rejectNewChannel(
-					newChannel, ssh.ResourceShortage, "maximum port forward limit exceeded")
-				return
-			}
-		}
-
 		// process each port forward concurrently
-		go sshClient.handleNewDirectTcpipChannel(newChannel)
+		go sshClient.handleNewPortForwardChannel(newChannel)
 	}
 }
 
@@ -383,7 +394,7 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason s
 	newChannel.Reject(reason, message)
 }
 
-func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChannel) {
+func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
 
 	// http://tools.ietf.org/html/rfc4254#section-7.2
 	var directTcpipExtraData struct {
@@ -399,14 +410,109 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
 		return
 	}
 
-	targetAddr := fmt.Sprintf("%s:%d",
-		directTcpipExtraData.HostToConnect,
-		directTcpipExtraData.PortToConnect)
+	// Intercept TCP port forwards to a specified udpgw server and handle directly.
+	// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
+	isUDPChannel := sshClient.sshServer.config.UdpgwServerAddress != "" &&
+		sshClient.sshServer.config.UdpgwServerAddress ==
+			fmt.Sprintf("%s:%d",
+				directTcpipExtraData.HostToConnect,
+				directTcpipExtraData.PortToConnect)
+
+	if isUDPChannel {
+		sshClient.handleUDPChannel(newChannel)
+	} else {
+		sshClient.handleTCPChannel(
+			directTcpipExtraData.HostToConnect, int(directTcpipExtraData.PortToConnect), newChannel)
+	}
+}
+
+func (sshClient *sshClient) isPortForwardPermitted(
+	port int, allowPorts []int, denyPorts []int) bool {
+
+	// TODO: faster lookup?
+	if allowPorts != nil {
+		for _, allowPort := range allowPorts {
+			if port == allowPort {
+				return true
+			}
+		}
+		return false
+	}
+	if denyPorts != nil {
+		for _, denyPort := range denyPorts {
+			if port == denyPort {
+				return false
+			}
+		}
+	}
+	return true
+}
+
+func (sshClient *sshClient) isPortForwardLimitExceeded(
+	state *trafficState, maxPortForwardCount int) bool {
+
+	limitExceeded := false
+	if maxPortForwardCount > 0 {
+		sshClient.Lock()
+		limitExceeded = state.portForwardCount >= int64(maxPortForwardCount)
+		sshClient.Unlock()
+	}
+	return limitExceeded
+}
+
+func (sshClient *sshClient) establishedPortForward(
+	state *trafficState) {
+
+	sshClient.Lock()
+	state.portForwardCount += 1
+	state.concurrentPortForwardCount += 1
+	if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
+		state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
+	}
+	sshClient.Unlock()
+}
+
+func (sshClient *sshClient) closedPortForward(
+	state *trafficState, bytesUp, bytesDown int64) {
+
+	sshClient.Lock()
+	state.concurrentPortForwardCount -= 1
+	state.bytesUp += bytesUp
+	state.bytesDown += bytesDown
+	sshClient.Unlock()
+}
+
+func (sshClient *sshClient) handleTCPChannel(
+	hostToConnect string,
+	portToConnect int,
+	newChannel ssh.NewChannel) {
+
+	if !sshClient.isPortForwardPermitted(
+		portToConnect,
+		sshClient.trafficRules.AllowTCPPorts,
+		sshClient.trafficRules.DenyTCPPorts) {
+
+		sshClient.rejectNewChannel(
+			newChannel, ssh.Prohibited, "port forward not permitted")
+		return
+	}
+
+	// TODO: close LRU connection (after successful Dial) instead of rejecting new connection?
+	if sshClient.isPortForwardLimitExceeded(
+		sshClient.tcpTrafficState,
+		sshClient.trafficRules.MaxTCPPortForwardCount) {
+
+		sshClient.rejectNewChannel(
+			newChannel, ssh.Prohibited, "maximum port forward limit exceeded")
+		return
+	}
+
+	targetAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
 
 	log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing")
 
+	// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
 	// TODO: port forward dial timeout
-	// TODO: report ssh.ResourceShortage when appropriate
 	// TODO: IPv6 support
 	fwdConn, err := net.Dial("tcp4", targetAddr)
 	if err != nil {
@@ -420,21 +526,13 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
 		log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
 		return
 	}
+	go ssh.DiscardRequests(requests)
+	defer fwdChannel.Close()
 
-	sshClient.Lock()
-	sshClient.portForwardCount += 1
-	sshClient.concurrentPortForwardCount += 1
-	if sshClient.concurrentPortForwardCount > sshClient.peakConcurrentPortForwardCount {
-		sshClient.peakConcurrentPortForwardCount = sshClient.concurrentPortForwardCount
-	}
-	sshClient.Unlock()
+	sshClient.establishedPortForward(sshClient.tcpTrafficState)
 
 	log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
 
-	go ssh.DiscardRequests(requests)
-
-	defer fwdChannel.Close()
-
 	// When idle port forward traffic rules are in place, wrap fwdConn
 	// in an IdleTimeoutConn configured to reset idle on writes as well
 	// as read. This ensures the port forward idle timeout only happens
@@ -449,6 +547,7 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
 
 	// relay channel to forwarded connection
 	// TODO: relay errors to fwdChannel.Stderr()?
+	// TODO: use a low-memory io.Copy?
 
 	var bytesUp, bytesDown int64
 
@@ -457,51 +556,23 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
 	go func() {
 		defer relayWaitGroup.Done()
 		var err error
-		bytesUp, err = copyWithThrottle(
-			fwdConn, fwdChannel, sshClient.trafficRules.ThrottleUpstreamSleepMilliseconds)
+		bytesUp, err = io.Copy(fwdConn, fwdChannel)
 		if err != nil {
-			log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed")
+			log.WithContextFields(LogFields{"error": err}).Warning("upstream TCP relay failed")
 		}
 	}()
-	bytesDown, err = copyWithThrottle(
-		fwdChannel, fwdConn, sshClient.trafficRules.ThrottleDownstreamSleepMilliseconds)
+	bytesDown, err = io.Copy(fwdChannel, fwdConn)
 	if err != nil {
-		log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed")
+		log.WithContextFields(LogFields{"error": err}).Warning("downstream TCP relay failed")
 	}
 	fwdChannel.CloseWrite()
 	relayWaitGroup.Wait()
 
-	sshClient.Lock()
-	sshClient.concurrentPortForwardCount -= 1
-	sshClient.bytesUp += bytesUp
-	sshClient.bytesDown += bytesDown
-	sshClient.Unlock()
+	sshClient.closedPortForward(sshClient.tcpTrafficState, bytesUp, bytesDown)
 
 	log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
 }
 
-func copyWithThrottle(dst io.Writer, src io.Reader, throttleSleepMilliseconds int) (int64, error) {
-	// TODO: use a low-memory io.Copy?
-	if throttleSleepMilliseconds <= 0 {
-		// No throttle
-		return io.Copy(dst, src)
-	}
-	var totalBytes int64
-	for {
-		bytes, err := io.CopyN(dst, src, SSH_THROTTLED_PORT_FORWARD_MAX_COPY)
-		totalBytes += bytes
-		if err == io.EOF {
-			err = nil
-			break
-		}
-		if err != nil {
-			return totalBytes, psiphon.ContextError(err)
-		}
-		time.Sleep(time.Duration(throttleSleepMilliseconds) * time.Millisecond)
-	}
-	return totalBytes, nil
-}
-
 func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
 	var sshPasswordPayload struct {
 		SessionId   string `json:"SessionId"`

+ 413 - 0
psiphon/server/udpChannel.go

@@ -0,0 +1,413 @@
+/*
+ * Copyright (c) 2016, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package server
+
+import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"math"
+	"net"
+	"strconv"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"golang.org/x/crypto/ssh"
+)
+
+// setUDPChannel sets the single UDP channel for this sshClient.
+// Each sshClient may have only one concurrent UDP channel. Each
+// UDP channel multiplexes many UDP port forwards via the udpgw
+// protocol. Any existing UDP channel is closed.
+func (sshClient *sshClient) setUDPChannel(channel ssh.Channel) {
+	sshClient.Lock()
+	if sshClient.udpChannel != nil {
+		sshClient.udpChannel.Close()
+	}
+	sshClient.udpChannel = channel
+	sshClient.Unlock()
+}
+
+// handleUDPChannel implements UDP port forwarding. A single UDP
+// SSH channel follows the udpgw protocol, which multiplexes many
+// UDP port forwards.
+//
+// The udpgw protocol and original server implementation:
+// Copyright (c) 2009, Ambroz Bizjak <ambrop7@gmail.com>
+// https://github.com/ambrop72/badvpn
+//
+func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
+
+	// Accept this channel immediately. This channel will replace any
+	// previously existing UDP channel for this client.
+
+	fwdChannel, requests, err := newChannel.Accept()
+	if err != nil {
+		log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
+		return
+	}
+	go ssh.DiscardRequests(requests)
+	defer fwdChannel.Close()
+
+	sshClient.setUDPChannel(fwdChannel)
+
+	// In a loop, read udpgw messages from the client to this channel. Each message is
+	// a UDP packet to send upstream either via a new port forward, or on an existing
+	// port forward.
+	//
+	// A goroutine is run to read downstream packets for each UDP port forward. All read
+	// packets are encapsulated in udpgw protocol and sent down the channel to the client.
+	//
+	// When the client disconnects or the server shuts down, the channel will close and
+	// readUdpgwMessage will exit with EOF.
+
+	type udpPortForward struct {
+		connID       uint16
+		conn         *net.UDPConn
+		lastActivity int64
+		bytesUp      int64
+		bytesDown    int64
+	}
+
+	var portForwardsMutex sync.Mutex
+	portForwards := make(map[uint16]*udpPortForward)
+	relayWaitGroup := new(sync.WaitGroup)
+	buffer := make([]byte, udpgwProtocolMaxMessageSize)
+
+	for {
+		// Note: udpProtocolMessage.packet points to the resuable
+		// memory in "buffer". Each readUdpgwMessage call will overwrite
+		// the last udpProtocolMessage.packet.
+		udpProtocolMessage, err := readUdpgwMessage(
+			sshClient.sshServer.config, fwdChannel, buffer)
+		if err != nil {
+			if err != io.EOF {
+				log.WithContextFields(LogFields{"error": err}).Warning("readUpdgwMessage failed")
+			}
+			break
+		}
+
+		portForwardsMutex.Lock()
+		portForward := portForwards[udpProtocolMessage.connID]
+		portForwardsMutex.Unlock()
+
+		if portForward != nil && udpProtocolMessage.discardExistingConn {
+			// The port forward's goroutine will complete cleanup, including
+			// tallying stats and calling sshClient.closedPortForward.
+			// portForward.conn.Close() will signal this shutdown.
+			// TODO: wait for goroutine to exit before proceeding?
+			portForward.conn.Close()
+			portForward = nil
+		}
+
+		if portForward == nil {
+
+			if !sshClient.isPortForwardPermitted(
+				udpProtocolMessage.portToConnect,
+				sshClient.trafficRules.AllowUDPPorts,
+				sshClient.trafficRules.DenyUDPPorts) {
+				// The udpgw protocol has no error response, so
+				// we just discard the message and read another.
+				continue
+			}
+
+			if sshClient.isPortForwardLimitExceeded(
+				sshClient.tcpTrafficState,
+				sshClient.trafficRules.MaxUDPPortForwardCount) {
+
+				// When the UDP port forward limit is exceeded, we
+				// select the least recently used (red from or written
+				// to) port forward and discard it.
+
+				// TODO: use "container/list" and avoid a linear scan?
+				portForwardsMutex.Lock()
+				oldestActivity := int64(math.MaxInt64)
+				var oldestPortForward *udpPortForward
+				for _, nextPortForward := range portForwards {
+					if nextPortForward.lastActivity < oldestActivity {
+						oldestPortForward = nextPortForward
+					}
+				}
+				if oldestPortForward != nil {
+					// *** comment: let goro call closePortForward
+					oldestPortForward.conn.Close()
+				}
+				portForwardsMutex.Unlock()
+			}
+
+			// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
+			// TODO: IPv6 support
+			updConn, err := net.Dial(
+				"udp4",
+				fmt.Sprintf("%s:%d", udpProtocolMessage.hostToConnect, udpProtocolMessage.portToConnect))
+			if err != nil {
+				log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed")
+				continue
+			}
+
+			portForward := &udpPortForward{
+				connID:       udpProtocolMessage.connID,
+				conn:         updConn.(*net.UDPConn),
+				lastActivity: time.Now().UnixNano(),
+				bytesUp:      0,
+				bytesDown:    0,
+			}
+			portForwardsMutex.Lock()
+			portForwards[portForward.connID] = portForward
+			portForwardsMutex.Unlock()
+
+			// TODO: timeout inactive UDP port forwards
+
+			sshClient.establishedPortForward(sshClient.udpTrafficState)
+
+			relayWaitGroup.Add(1)
+			go func(portForward *udpPortForward) {
+				defer relayWaitGroup.Done()
+
+				// Downstream UDP packets are read into the reusable memory
+				// in "buffer" starting at the offset udpgwProtocolHeaderSize,
+				// leaving enough space to write the udpgw header into the
+				// same buffer and use for writing to the ssh channel.
+				//
+				// Note: there is one downstream buffer per UDP port forward,
+				// while for upstream there is one buffer per client.
+				// TODO: is the buffer size larger than necessary?
+				buffer := make([]byte, udpgwProtocolMaxMessageSize)
+				packetBuffer := buffer[udpgwProtocolHeaderSize:udpgwProtocolMaxMessageSize]
+				for {
+					// TODO: if read buffer is too small, excess bytes are discarded?
+					packetSize, _, err := portForward.conn.ReadFrom(packetBuffer)
+					if packetSize > udpgwProtocolMaxPayloadSize {
+						err = fmt.Errorf("unexpected packet size: %d", packetSize)
+					}
+					if err != nil {
+						if err != io.EOF {
+							log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
+						}
+						break
+					}
+
+					writeUdpgwHeader(buffer, uint16(packetSize), portForward.connID)
+
+					_, err = fwdChannel.Write(buffer[0 : udpgwProtocolHeaderSize+packetSize])
+					if err != nil {
+						// Close the channel, which will interrupt the main loop.
+						fwdChannel.Close()
+						log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
+						break
+					}
+
+					atomic.StoreInt64(&portForward.lastActivity, time.Now().UnixNano())
+					atomic.AddInt64(&portForward.bytesDown, int64(packetSize))
+				}
+
+				portForwardsMutex.Lock()
+				delete(portForwards, portForward.connID)
+				portForwardsMutex.Unlock()
+
+				portForward.conn.Close()
+
+				bytesUp := atomic.LoadInt64(&portForward.bytesUp)
+				bytesDown := atomic.LoadInt64(&portForward.bytesDown)
+				sshClient.closedPortForward(sshClient.udpTrafficState, bytesUp, bytesDown)
+
+			}(portForward)
+		}
+
+		// Note: assumes UDP writes won't block (https://golang.org/pkg/net/#UDPConn.WriteToUDP)
+		_, err = portForward.conn.WriteTo(udpProtocolMessage.packet, nil)
+		if err != nil {
+			log.WithContextFields(LogFields{"error": err}).Warning("upstream UDP relay failed")
+			// The port forward's goroutine will complete cleanup
+			portForward.conn.Close()
+		}
+		atomic.StoreInt64(&portForward.lastActivity, time.Now().UnixNano())
+		atomic.AddInt64(&portForward.bytesUp, int64(len(udpProtocolMessage.packet)))
+	}
+
+	// Cleanup all UDP port forward workers when exiting
+
+	portForwardsMutex.Lock()
+	for _, portForward := range portForwards {
+		// The port forward's goroutine will complete cleanup
+		portForward.conn.Close()
+	}
+	portForwardsMutex.Unlock()
+
+	relayWaitGroup.Wait()
+}
+
+// TODO: express and/or calculate udpgwProtocolMaxPayloadSize as function of MTU?
+const (
+	udpgwProtocolFlagKeepalive = 1 << 0
+	udpgwProtocolFlagRebind    = 1 << 1
+	udpgwProtocolFlagDNS       = 1 << 2
+	udpgwProtocolFlagIPv6      = 1 << 3
+
+	udpgwProtocolHeaderSize     = 3
+	udpgwProtocolIPv4AddrSize   = 6
+	udpgwProtocolIPv6AddrSize   = 18
+	udpgwProtocolMaxPayloadSize = 32768
+	udpgwProtocolMaxMessageSize = udpgwProtocolHeaderSize +
+		udpgwProtocolIPv6AddrSize +
+		udpgwProtocolMaxPayloadSize
+)
+
+type udpgwHeader struct {
+	Size   uint16
+	Flags  uint8
+	ConnID uint16
+}
+
+type udpgwAddrIPv4 struct {
+	IP   uint32
+	Port uint16
+}
+
+type udpgwAddrIPv6 struct {
+	IP   [16]uint8
+	Port uint16
+}
+
+type udpProtocolMessage struct {
+	connID              uint16
+	discardExistingConn bool
+	hostToConnect       string
+	portToConnect       int
+	packet              []byte
+}
+
+func readUdpgwMessage(
+	config *Config, reader io.Reader, buffer []byte) (*udpProtocolMessage, error) {
+
+	for {
+		// Read udpgwHeader
+
+		_, err := io.ReadFull(reader, buffer[0:udpgwProtocolHeaderSize])
+		if err != nil {
+			return nil, psiphon.ContextError(err)
+		}
+
+		var header udpgwHeader
+		err = binary.Read(
+			bytes.NewReader(buffer[0:udpgwProtocolHeaderSize]), binary.BigEndian, &header)
+		if err != nil {
+			return nil, psiphon.ContextError(err)
+		}
+
+		if int(header.Size) < udpgwProtocolHeaderSize || int(header.Size) > len(buffer) {
+			return nil, psiphon.ContextError(errors.New("invalid udpgw message size"))
+		}
+
+		_, err = io.ReadFull(reader, buffer[udpgwProtocolHeaderSize:header.Size])
+		if err != nil {
+			return nil, psiphon.ContextError(err)
+		}
+
+		// Ignore udpgw keep-alive messages -- read another message
+
+		if header.Flags&udpgwProtocolFlagKeepalive == udpgwProtocolFlagKeepalive {
+			continue
+		}
+
+		// Read udpgwAddrIPv4 or udpgwAddrIPv6
+
+		var hostToConnect string
+		var portToConnect int
+		var packetOffset int
+
+		if header.Flags&udpgwProtocolFlagIPv6 == udpgwProtocolFlagIPv6 {
+
+			var addr udpgwAddrIPv6
+			err = binary.Read(
+				bytes.NewReader(
+					buffer[udpgwProtocolHeaderSize:udpgwProtocolHeaderSize+udpgwProtocolIPv6AddrSize]),
+				binary.BigEndian, &addr)
+			if err != nil {
+				return nil, psiphon.ContextError(err)
+			}
+
+			ip := make(net.IP, 16)
+			copy(ip, addr.IP[:])
+
+			hostToConnect = ip.String()
+			portToConnect = int(addr.Port)
+			packetOffset = udpgwProtocolHeaderSize + udpgwProtocolIPv6AddrSize
+
+		} else {
+
+			var addr udpgwAddrIPv4
+			err = binary.Read(
+				bytes.NewReader(
+					buffer[udpgwProtocolHeaderSize:udpgwProtocolHeaderSize+udpgwProtocolIPv4AddrSize]),
+				binary.BigEndian, &addr)
+
+			ip := make(net.IP, 4)
+			binary.BigEndian.PutUint32(ip, addr.IP)
+
+			hostToConnect = net.IP(ip).String()
+			portToConnect = int(addr.Port)
+			packetOffset = udpgwProtocolHeaderSize + udpgwProtocolIPv4AddrSize
+		}
+
+		// Assemble message
+		// Note: udpProtocolMessage.packet references memory in the input buffer
+
+		udpProtocolMessage := &udpProtocolMessage{
+			connID:              header.ConnID,
+			discardExistingConn: header.Flags&udpgwProtocolFlagRebind == udpgwProtocolFlagRebind,
+			hostToConnect:       hostToConnect,
+			portToConnect:       portToConnect,
+			packet:              buffer[packetOffset : int(header.Size)-packetOffset],
+		}
+
+		// Transparent DNS forwarding
+
+		if (header.Flags&udpgwProtocolFlagDNS == udpgwProtocolFlagDNS) &&
+			config.DNSServerAddress != "" {
+
+			// Note: DNSServerAddress SplitHostPort is checked in LoadConfig
+			host, portStr, _ := net.SplitHostPort(config.DNSServerAddress)
+			port, _ := strconv.Atoi(portStr)
+			udpProtocolMessage.hostToConnect = host
+			udpProtocolMessage.portToConnect = port
+		}
+
+		return udpProtocolMessage, nil
+	}
+}
+
+func writeUdpgwHeader(
+	buffer []byte, packetSize uint16, connID uint16) {
+	// TODO: write directly into buffer
+	header := make([]byte, 0, udpgwProtocolHeaderSize)
+	binary.Write(
+		bytes.NewBuffer(header),
+		binary.BigEndian,
+		&udpgwHeader{
+			Size:   udpgwProtocolHeaderSize + packetSize,
+			Flags:  0,
+			ConnID: connID})
+	copy(buffer[0:udpgwProtocolHeaderSize], header)
+}

+ 14 - 0
psiphon/serverApi.go

@@ -553,6 +553,20 @@ func RecordTunnelStats(
 	return StoreTunnelStats(tunnelStatsJson)
 }
 
+// DoClientVerificationRequest performs the client_verification API
+// request. This request is used to verify that the client is a
+// valid Psiphon client, which will determine how the server treats
+// the client traffic. The proof-of-validity is platform-specific
+// and the payload is opaque to this function but assumed to be JSON.
+func (serverContext *ServerContext) DoClientVerificationRequest(
+	verificationPayload string) error {
+
+	return serverContext.doPostRequest(
+		buildRequestUrl(serverContext.baseRequestUrl, "client_verification"),
+		"application/json",
+		bytes.NewReader([]byte(verificationPayload)))
+}
+
 // doGetRequest makes a tunneled HTTPS request and returns the response body.
 func (serverContext *ServerContext) doGetRequest(
 	requestUrl string) (responseBody []byte, err error) {

+ 89 - 16
psiphon/tunnel.go

@@ -63,22 +63,23 @@ type TunnelOwner interface {
 // tunnel includes a network connection to the specified server
 // and an SSH session built on top of that transport.
 type Tunnel struct {
-	mutex                    *sync.Mutex
-	config                   *Config
-	untunneledDialConfig     *DialConfig
-	isDiscarded              bool
-	isClosed                 bool
-	serverEntry              *ServerEntry
-	serverContext            *ServerContext
-	protocol                 string
-	conn                     net.Conn
-	sshClient                *ssh.Client
-	operateWaitGroup         *sync.WaitGroup
-	shutdownOperateBroadcast chan struct{}
-	signalPortForwardFailure chan struct{}
-	totalPortForwardFailures int
-	startTime                time.Time
-	meekStats                *MeekStats
+	mutex                        *sync.Mutex
+	config                       *Config
+	untunneledDialConfig         *DialConfig
+	isDiscarded                  bool
+	isClosed                     bool
+	serverEntry                  *ServerEntry
+	serverContext                *ServerContext
+	protocol                     string
+	conn                         net.Conn
+	sshClient                    *ssh.Client
+	operateWaitGroup             *sync.WaitGroup
+	shutdownOperateBroadcast     chan struct{}
+	signalPortForwardFailure     chan struct{}
+	totalPortForwardFailures     int
+	startTime                    time.Time
+	meekStats                    *MeekStats
+	newClientVerificationPayload chan string
 }
 
 // EstablishTunnel first makes a network transport connection to the
@@ -135,6 +136,9 @@ func EstablishTunnel(
 		// not listening. Senders should not block.
 		signalPortForwardFailure: make(chan struct{}, 1),
 		meekStats:                meekStats,
+		// Buffer allows SetClientVerificationPayload to submit one new payload
+		// without blocking or dropping it.
+		newClientVerificationPayload: make(chan string, 1),
 	}
 
 	// Create a new Psiphon API server context for this tunnel. This includes
@@ -262,6 +266,17 @@ func (tunnel *Tunnel) SignalComponentFailure() {
 	tunnel.Close(false)
 }
 
+// SetClientVerificationPayload triggers a client verification request, for this
+// tunnel, with the specified verifiction payload. If the tunnel is not yet established,
+// the payload/request is enqueued. If a payload/request is already eneueued, the
+// new payload is dropped.
+func (tunnel *Tunnel) SetClientVerificationPayload(clientVerificationPayload string) {
+	select {
+	case tunnel.newClientVerificationPayload <- clientVerificationPayload:
+	default:
+	}
+}
+
 // TunneledConn implements net.Conn and wraps a port foward connection.
 // It is used to hook into Read and Write to observe I/O errors and
 // report these errors back to the tunnel monitor as port forward failures.
@@ -767,6 +782,39 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 		}
 	}()
 
+	requestsWaitGroup.Add(1)
+	signalStopClientVerificationRequests := make(chan struct{})
+	go func() {
+		defer requestsWaitGroup.Done()
+
+		clientVerificationPayload := ""
+		for {
+			// TODO: use reflect.SelectCase?
+			if clientVerificationPayload == "" {
+				select {
+				case clientVerificationPayload = <-tunnel.newClientVerificationPayload:
+				case <-signalStopClientVerificationRequests:
+					return
+				}
+			} else {
+				// When clientVerificationPayload is not "", the request for that
+				// payload so retry after a delay. Will use a new payload instead
+				// if that arrives in the meantime.
+				timeout := time.After(PSIPHON_API_CLIENT_VERIFICATION_REQUEST_RETRY_PERIOD)
+				select {
+				case <-timeout:
+				case clientVerificationPayload = <-tunnel.newClientVerificationPayload:
+				case <-signalStopClientVerificationRequests:
+					return
+				}
+			}
+			if sendClientVerification(tunnel, clientVerificationPayload) {
+				clientVerificationPayload = ""
+			}
+
+		}
+	}()
+
 	shutdown := false
 	var err error
 	for !shutdown && err == nil {
@@ -833,6 +881,7 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 
 	close(signalSshKeepAlive)
 	close(signalStatusRequest)
+	close(signalStopClientVerificationRequests)
 	requestsWaitGroup.Wait()
 
 	// Capture bytes transferred since the last noticeBytesTransferredTicker tick
@@ -972,3 +1021,27 @@ func sendUntunneledStats(tunnel *Tunnel, isShutdown bool) {
 		NoticeAlert("TryUntunneledStatusRequest failed for %s: %s", tunnel.serverEntry.IpAddress, err)
 	}
 }
+
+// sendClientVerification is a helper for sending a client verification request
+// to the server.
+func sendClientVerification(tunnel *Tunnel, clientVerificationPayload string) bool {
+
+	// Tunnel does not have a serverContext when DisableApi is set
+	if tunnel.serverContext == nil {
+		return true
+	}
+
+	// Skip when tunnel is discarded
+	if tunnel.IsDiscarded() {
+		return true
+	}
+
+	err := tunnel.serverContext.DoClientVerificationRequest(clientVerificationPayload)
+	if err != nil {
+		NoticeAlert("DoClientVerificationRequest failed for %s: %s", tunnel.serverEntry.IpAddress, err)
+	} else {
+		NoticeClientVerificationRequestCompleted(tunnel.serverEntry.IpAddress)
+	}
+
+	return err == nil
+}