Browse Source

Common code refactoring

* Move towards server not importing client.
* Added "common" package, containing code
  used by both client and server or multiple
  packages (ex. Reloader is used in both
  "server" and "psinet").
  * Note: server still imports "psiphon" for
    server entry and obfuscator logic, which
    remain in "psiphon"/client until stand-alone
    unit tests provide equivalent coverage
    reporting.
* Add optional throttling capability to client.
  * moved ThrottledConn to "common".
* Remove upstreamproxy_test.go due to import
  cycle issue. Should be redone as a stand-alone
  unit test.
Rod Hynes 9 years ago
parent
commit
1b30d92d14
58 changed files with 1412 additions and 1442 deletions
  1. 1 2
      .travis.yml
  2. 3 2
      AndroidLibrary/psi/psi.go
  3. 3 2
      ConsoleClient/main.go
  4. 7 5
      psiphon/LookupIP.go
  5. 3 1
      psiphon/LookupIP_nobind.go
  6. 5 4
      psiphon/TCPConn.go
  7. 12 10
      psiphon/TCPConn_bind.go
  8. 3 1
      psiphon/TCPConn_nobind.go
  9. 85 0
      psiphon/common/net.go
  10. 78 0
      psiphon/common/protocol.go
  11. 146 0
      psiphon/common/reloader.go
  12. 28 0
      psiphon/common/reloader_test.go
  13. 132 0
      psiphon/common/throttled.go
  14. 28 0
      psiphon/common/throttled_test.go
  15. 178 0
      psiphon/common/utils.go
  16. 12 3
      psiphon/common/utils_test.go
  17. 21 13
      psiphon/config.go
  18. 11 9
      psiphon/controller.go
  19. 14 13
      psiphon/controller_test.go
  20. 34 33
      psiphon/dataStore.go
  21. 14 12
      psiphon/httpProxy.go
  22. 28 27
      psiphon/meekConn.go
  23. 6 5
      psiphon/migrateDataStore_windows.go
  24. 19 142
      psiphon/net.go
  25. 5 3
      psiphon/networkInterface.go
  26. 4 2
      psiphon/notice.go
  27. 26 24
      psiphon/obfuscatedSshConn.go
  28. 28 26
      psiphon/obfuscator.go
  29. 5 4
      psiphon/opensslConn.go
  30. 3 1
      psiphon/opensslConn_unsupported.go
  31. 8 6
      psiphon/package.go
  32. 14 12
      psiphon/remoteServerList.go
  33. 48 48
      psiphon/server/api.go
  34. 37 36
      psiphon/server/config.go
  35. 11 11
      psiphon/server/dns.go
  36. 7 7
      psiphon/server/geoip.go
  37. 5 5
      psiphon/server/log.go
  38. 23 22
      psiphon/server/meek.go
  39. 61 80
      psiphon/server/net.go
  40. 6 6
      psiphon/server/psinet/psinet.go
  41. 2 2
      psiphon/server/safetyNet.go
  42. 5 4
      psiphon/server/server_test.go
  43. 10 10
      psiphon/server/services.go
  44. 9 32
      psiphon/server/trafficRules.go
  45. 22 26
      psiphon/server/tunnelServer.go
  46. 7 7
      psiphon/server/udp.go
  47. 7 7
      psiphon/server/utils.go
  48. 8 8
      psiphon/server/webServer.go
  49. 53 54
      psiphon/serverApi.go
  50. 12 61
      psiphon/serverEntry.go
  51. 4 2
      psiphon/serverEntry_test.go
  52. 7 6
      psiphon/socksProxy.go
  53. 18 16
      psiphon/splitTunnel.go
  54. 9 7
      psiphon/tlsDialer.go
  55. 66 48
      psiphon/tunnel.go
  56. 7 5
      psiphon/upgradeDownload.go
  57. 0 290
      psiphon/upstreamproxy/upstreamproxy_test.go
  58. 4 280
      psiphon/utils.go

+ 1 - 2
.travis.yml

@@ -9,9 +9,8 @@ install:
 - go get -t -d -v ./... && go build -v ./...
 script:
 - cd psiphon
+- go test -v -covermode=count -coverprofile=common.coverprofile ./common
 - go test -v -covermode=count -coverprofile=transferstats.coverprofile ./transferstats
-# upstreamproxy test disabled until "import cycle not allowed in test" problem resolved
-#- go test -v -covermode=count -coverprofile=upstreamproxy.coverprofile ./upstreamproxy
 - go test -v -covermode=count -coverprofile=server.coverprofile ./server
 - go test -v -covermode=count -coverprofile=psiphon.coverprofile
 - $HOME/gopath/bin/gover

+ 3 - 2
AndroidLibrary/psi/psi.go

@@ -29,6 +29,7 @@ import (
 	"sync"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 type PsiphonProvider interface {
@@ -83,8 +84,8 @@ func Start(
 
 	serverEntries, err := psiphon.DecodeAndValidateServerEntryList(
 		embeddedServerEntryList,
-		psiphon.GetCurrentTimestamp(),
-		psiphon.SERVER_ENTRY_SOURCE_EMBEDDED)
+		common.GetCurrentTimestamp(),
+		common.SERVER_ENTRY_SOURCE_EMBEDDED)
 	if err != nil {
 		return fmt.Errorf("error decoding embedded server entry list: %s", err)
 	}

+ 3 - 2
ConsoleClient/main.go

@@ -29,6 +29,7 @@ import (
 	"sync"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 func main() {
@@ -139,8 +140,8 @@ func main() {
 			// TODO: stream embedded server list data? also, the cast makes an unnecessary copy of a large buffer?
 			serverEntries, err := psiphon.DecodeAndValidateServerEntryList(
 				string(serverEntryList),
-				psiphon.GetCurrentTimestamp(),
-				psiphon.SERVER_ENTRY_SOURCE_EMBEDDED)
+				common.GetCurrentTimestamp(),
+				common.SERVER_ENTRY_SOURCE_EMBEDDED)
 			if err != nil {
 				psiphon.NoticeError("error decoding embedded server entry list file: %s", err)
 				return

+ 7 - 5
psiphon/LookupIP.go

@@ -28,6 +28,8 @@ import (
 	"os"
 	"syscall"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // LookupIP resolves a hostname. When BindToDevice is not required, it
@@ -69,19 +71,19 @@ func bindLookupIP(host, dnsServer string, config *DialConfig) (addrs []net.IP, e
 
 	socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, 0)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	defer syscall.Close(socketFd)
 
 	err = config.DeviceBinder.BindToDevice(socketFd)
 	if err != nil {
-		return nil, ContextError(fmt.Errorf("BindToDevice failed: %s", err))
+		return nil, common.ContextError(fmt.Errorf("BindToDevice failed: %s", err))
 	}
 
 	// config.DnsServerGetter.GetDnsServers() must return IP addresses
 	ipAddr = net.ParseIP(dnsServer)
 	if ipAddr == nil {
-		return nil, ContextError(errors.New("invalid IP address"))
+		return nil, common.ContextError(errors.New("invalid IP address"))
 	}
 
 	// TODO: IPv6 support
@@ -91,7 +93,7 @@ func bindLookupIP(host, dnsServer string, config *DialConfig) (addrs []net.IP, e
 	// Note: no timeout or interrupt for this connect, as it's a datagram socket
 	err = syscall.Connect(socketFd, &sockAddr)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// Convert the syscall socket to a net.Conn, for use in the dns package
@@ -99,7 +101,7 @@ func bindLookupIP(host, dnsServer string, config *DialConfig) (addrs []net.IP, e
 	defer file.Close()
 	conn, err := net.FileConn(file)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// Set DNS query timeouts, using the ConnectTimeout from the overall Dial

+ 3 - 1
psiphon/LookupIP_nobind.go

@@ -24,13 +24,15 @@ package psiphon
 import (
 	"errors"
 	"net"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // LookupIP resolves a hostname. When BindToDevice is not required, it
 // simply uses net.LookupIP.
 func LookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
 	if config.DeviceBinder != nil {
-		return nil, ContextError(errors.New("LookupIP with DeviceBinder not supported on this platform"))
+		return nil, common.ContextError(errors.New("LookupIP with DeviceBinder not supported on this platform"))
 	}
 	return net.LookupIP(host)
 }

+ 5 - 4
psiphon/TCPConn.go

@@ -24,6 +24,7 @@ import (
 	"net"
 	"sync"
 
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/upstreamproxy"
 )
 
@@ -59,12 +60,12 @@ func makeTCPDialer(config *DialConfig) func(network, addr string) (net.Conn, err
 		}
 		conn, err := interruptibleTCPDial(addr, config)
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		// Note: when an upstream proxy is used, we don't know what IP address
 		// was resolved, by the proxy, for that destination.
 		if config.ResolvedIPCallback != nil && config.UpstreamProxyUrl == "" {
-			ipAddress := IPAddressFromAddr(conn.RemoteAddr())
+			ipAddress := common.IPAddressFromAddr(conn.RemoteAddr())
 			if ipAddress != "" {
 				config.ResolvedIPCallback(ipAddress)
 			}
@@ -92,7 +93,7 @@ func interruptibleTCPDial(addr string, config *DialConfig) (*TCPConn, error) {
 
 	// Enable interruption
 	if config.PendingConns != nil && !config.PendingConns.Add(conn) {
-		return nil, ContextError(errors.New("pending connections already closed"))
+		return nil, common.ContextError(errors.New("pending connections already closed"))
 	}
 
 	// Call the blocking Connect() in a goroutine. ConnectTimeout is handled
@@ -134,7 +135,7 @@ func interruptibleTCPDial(addr string, config *DialConfig) (*TCPConn, error) {
 	// Wait until Dial completes (or times out) or until interrupt
 	err := <-conn.dialResult
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return conn, nil

+ 12 - 10
psiphon/TCPConn_bind.go

@@ -29,6 +29,8 @@ import (
 	"strconv"
 	"syscall"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // tcpDial is the platform-specific part of interruptibleTCPDial
@@ -52,18 +54,18 @@ func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn,
 	// Get the remote IP and port, resolving a domain name if necessary
 	host, strPort, err := net.SplitHostPort(addr)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	port, err := strconv.Atoi(strPort)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	ipAddrs, err := LookupIP(host, config)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	if len(ipAddrs) < 1 {
-		return nil, ContextError(errors.New("no IP address"))
+		return nil, common.ContextError(errors.New("no IP address"))
 	}
 
 	// Select an IP at random from the list, so we're not always
@@ -71,9 +73,9 @@ func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn,
 	// TODO: retry all IPs until one connects? For now, this retry
 	// will happen on subsequent TCPDial calls, when a different IP
 	// is selected.
-	index, err := MakeSecureRandomInt(len(ipAddrs))
+	index, err := common.MakeSecureRandomInt(len(ipAddrs))
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// TODO: IPv6 support
@@ -83,7 +85,7 @@ func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn,
 	// Create a socket and bind to device, when configured to do so
 	socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	if config.DeviceBinder != nil {
@@ -94,7 +96,7 @@ func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn,
 		err = config.DeviceBinder.BindToDevice(socketFd)
 		if err != nil {
 			syscall.Close(socketFd)
-			return nil, ContextError(fmt.Errorf("BindToDevice failed: %s", err))
+			return nil, common.ContextError(fmt.Errorf("BindToDevice failed: %s", err))
 		}
 	}
 
@@ -102,7 +104,7 @@ func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn,
 	err = syscall.Connect(socketFd, &sockAddr)
 	if err != nil {
 		syscall.Close(socketFd)
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// Convert the socket fd to a net.Conn
@@ -110,7 +112,7 @@ func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn,
 	netConn, err := net.FileConn(file) // net.FileConn() dups socketFd
 	file.Close()                       // file.Close() closes socketFd
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return netConn, nil

+ 3 - 1
psiphon/TCPConn_nobind.go

@@ -24,13 +24,15 @@ package psiphon
 import (
 	"errors"
 	"net"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // tcpDial is the platform-specific part of interruptibleTCPDial
 func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn, error) {
 
 	if config.DeviceBinder != nil {
-		return nil, ContextError(errors.New("psiphon.interruptibleTCPDial with DeviceBinder not supported"))
+		return nil, common.ContextError(errors.New("psiphon.interruptibleTCPDial with DeviceBinder not supported"))
 	}
 
 	return net.DialTimeout("tcp", addr, config.ConnectTimeout)

+ 85 - 0
psiphon/common/net.go

@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2016, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+import (
+	"net"
+	"sync"
+)
+
+// Conns is a synchronized list of Conns that is used to coordinate
+// interrupting a set of goroutines establishing connections, or
+// close a set of open connections, etc.
+// Once the list is closed, no more items may be added to the
+// list (unless it is reset).
+type Conns struct {
+	mutex    sync.Mutex
+	isClosed bool
+	conns    map[net.Conn]bool
+}
+
+func (conns *Conns) Reset() {
+	conns.mutex.Lock()
+	defer conns.mutex.Unlock()
+	conns.isClosed = false
+	conns.conns = make(map[net.Conn]bool)
+}
+
+func (conns *Conns) Add(conn net.Conn) bool {
+	conns.mutex.Lock()
+	defer conns.mutex.Unlock()
+	if conns.isClosed {
+		return false
+	}
+	if conns.conns == nil {
+		conns.conns = make(map[net.Conn]bool)
+	}
+	conns.conns[conn] = true
+	return true
+}
+
+func (conns *Conns) Remove(conn net.Conn) {
+	conns.mutex.Lock()
+	defer conns.mutex.Unlock()
+	delete(conns.conns, conn)
+}
+
+func (conns *Conns) CloseAll() {
+	conns.mutex.Lock()
+	defer conns.mutex.Unlock()
+	conns.isClosed = true
+	for conn, _ := range conns.conns {
+		conn.Close()
+	}
+	conns.conns = make(map[net.Conn]bool)
+}
+
+// IPAddressFromAddr is a helper which extracts an IP address
+// from a net.Addr or returns "" if there is no IP address.
+func IPAddressFromAddr(addr net.Addr) string {
+	ipAddress := ""
+	if addr != nil {
+		host, _, err := net.SplitHostPort(addr.String())
+		if err == nil {
+			ipAddress = host
+		}
+	}
+	return ipAddress
+}

+ 78 - 0
psiphon/common/protocol.go

@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2016, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+const (
+	TUNNEL_PROTOCOL_SSH                  = "SSH"
+	TUNNEL_PROTOCOL_OBFUSCATED_SSH       = "OSSH"
+	TUNNEL_PROTOCOL_UNFRONTED_MEEK       = "UNFRONTED-MEEK-OSSH"
+	TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS = "UNFRONTED-MEEK-HTTPS-OSSH"
+	TUNNEL_PROTOCOL_FRONTED_MEEK         = "FRONTED-MEEK-OSSH"
+	TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP    = "FRONTED-MEEK-HTTP-OSSH"
+
+	SERVER_ENTRY_SOURCE_EMBEDDED  = "EMBEDDED"
+	SERVER_ENTRY_SOURCE_REMOTE    = "REMOTE"
+	SERVER_ENTRY_SOURCE_DISCOVERY = "DISCOVERY"
+	SERVER_ENTRY_SOURCE_TARGET    = "TARGET"
+
+	CAPABILITY_SSH_API_REQUESTS            = "ssh-api-requests"
+	CAPABILITY_UNTUNNELED_WEB_API_REQUESTS = "handshake"
+
+	PSIPHON_API_HANDSHAKE_REQUEST_NAME           = "psiphon-handshake"
+	PSIPHON_API_CONNECTED_REQUEST_NAME           = "psiphon-connected"
+	PSIPHON_API_STATUS_REQUEST_NAME              = "psiphon-status"
+	PSIPHON_API_CLIENT_VERIFICATION_REQUEST_NAME = "psiphon-client-verification"
+
+	PSIPHON_API_CLIENT_SESSION_ID_LENGTH = 16
+)
+
+var SupportedTunnelProtocols = []string{
+	TUNNEL_PROTOCOL_FRONTED_MEEK,
+	TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP,
+	TUNNEL_PROTOCOL_UNFRONTED_MEEK,
+	TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
+	TUNNEL_PROTOCOL_OBFUSCATED_SSH,
+	TUNNEL_PROTOCOL_SSH,
+}
+
+var SupportedServerEntrySources = []string{
+	SERVER_ENTRY_SOURCE_EMBEDDED,
+	SERVER_ENTRY_SOURCE_REMOTE,
+	SERVER_ENTRY_SOURCE_DISCOVERY,
+	SERVER_ENTRY_SOURCE_TARGET,
+}
+
+func TunnelProtocolUsesSSH(protocol string) bool {
+	return true
+}
+
+func TunnelProtocolUsesObfuscatedSSH(protocol string) bool {
+	return protocol != TUNNEL_PROTOCOL_SSH
+}
+
+func TunnelProtocolUsesMeekHTTP(protocol string) bool {
+	return protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK ||
+		protocol == TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP
+}
+
+func TunnelProtocolUsesMeekHTTPS(protocol string) bool {
+	return protocol == TUNNEL_PROTOCOL_FRONTED_MEEK ||
+		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS
+}

+ 146 - 0
psiphon/common/reloader.go

@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) 2016, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+import (
+	"os"
+	"sync"
+)
+
+// IsFileChanged uses os.Stat to check if the name, size, or last mod time of the
+// file has changed (which is a heuristic, but sufficiently robust for users of this
+// function). Returns nil if file has not changed; otherwise, returns a changed
+// os.FileInfo which may be used to check for subsequent changes.
+func IsFileChanged(path string, previousFileInfo os.FileInfo) (os.FileInfo, error) {
+
+	fileInfo, err := os.Stat(path)
+	if err != nil {
+		return nil, ContextError(err)
+	}
+
+	changed := previousFileInfo == nil ||
+		fileInfo.Name() != previousFileInfo.Name() ||
+		fileInfo.Size() != previousFileInfo.Size() ||
+		fileInfo.ModTime() != previousFileInfo.ModTime()
+
+	if !changed {
+		return nil, nil
+	}
+
+	return fileInfo, nil
+}
+
+// Reloader represents a read-only, in-memory reloadable data object. For example,
+// a JSON data file that is loaded into memory and accessed for read-only lookups;
+// and from time to time may be reloaded from the same file, updating the memory
+// copy.
+type Reloader interface {
+
+	// Reload reloads the data object. Reload returns a flag indicating if the
+	// reloadable target has changed and reloaded or remains unchanged. By
+	// convention, when reloading fails the Reloader should revert to its previous
+	// in-memory state.
+	Reload() (bool, error)
+
+	// WillReload indicates if the data object is capable of reloading.
+	WillReload() bool
+
+	// LogDescription returns a description to be used for logging
+	// events related to the Reloader.
+	LogDescription() string
+}
+
+// ReloadableFile is a file-backed Reloader. This type is intended to be embedded
+// in other types that add the actual reloadable data structures.
+//
+// ReloadableFile has a multi-reader mutex for synchronization. Its Reload() function
+// will obtain a write lock before reloading the data structures. Actually reloading
+// action is to be provided via the reloadAction callback (for example, read the contents
+// of the file and unmarshall the contents into data structures). All read access to
+// the data structures should be guarded by RLocks on the ReloadableFile mutex.
+//
+// reloadAction must ensure that data structures revert to their previous state when
+// a reload fails.
+//
+type ReloadableFile struct {
+	sync.RWMutex
+	fileName     string
+	fileInfo     os.FileInfo
+	reloadAction func(string) error
+}
+
+// NewReloadableFile initializes a new ReloadableFile
+func NewReloadableFile(
+	fileName string,
+	reloadAction func(string) error) ReloadableFile {
+
+	return ReloadableFile{
+		fileName:     fileName,
+		reloadAction: reloadAction,
+	}
+}
+
+// WillReload indicates whether the ReloadableFile is capable
+// of reloading.
+func (reloadable *ReloadableFile) WillReload() bool {
+	return reloadable.fileName != ""
+}
+
+// Reload checks if the underlying file has changed (using IsFileChanged semantics, which
+// are heuristics) and, when changed, invokes the reloadAction callback which should
+// reload, from the file, the in-memory data structures.
+// All data structure readers should be blocked by the ReloadableFile mutex.
+func (reloadable *ReloadableFile) Reload() (bool, error) {
+
+	if !reloadable.WillReload() {
+		return false, nil
+	}
+
+	// Check whether the file has changed _before_ blocking readers
+
+	reloadable.RLock()
+	changedFileInfo, err := IsFileChanged(reloadable.fileName, reloadable.fileInfo)
+	reloadable.RUnlock()
+	if err != nil {
+		return false, ContextError(err)
+	}
+
+	if changedFileInfo == nil {
+		return false, nil
+	}
+
+	// ...now block readers
+
+	reloadable.Lock()
+	defer reloadable.Unlock()
+
+	err = reloadable.reloadAction(reloadable.fileName)
+	if err != nil {
+		return false, ContextError(err)
+	}
+
+	reloadable.fileInfo = changedFileInfo
+
+	return true, nil
+}
+
+func (reloadable *ReloadableFile) LogDescription() string {
+	return reloadable.fileName
+}

+ 28 - 0
psiphon/common/reloader_test.go

@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2016, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+import (
+	"testing"
+)
+
+func TestReloader(t *testing.T) {
+	// TODO
+}

+ 132 - 0
psiphon/common/throttled.go

@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) 2016, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+import (
+	"io"
+	"net"
+	"sync/atomic"
+
+	"github.com/Psiphon-Inc/ratelimit"
+)
+
+// RateLimits specify the rate limits for a ThrottledConn.
+type RateLimits struct {
+
+	// DownstreamUnlimitedBytes specifies the number of downstream
+	// bytes to transfer, approximately, before starting rate
+	// limiting.
+	DownstreamUnlimitedBytes int64
+
+	// DownstreamBytesPerSecond specifies a rate limit for downstream
+	// data transfer. The default, 0, is no limit.
+	DownstreamBytesPerSecond int64
+
+	// UpstreamUnlimitedBytes specifies the number of upstream
+	// bytes to transfer, approximately, before starting rate
+	// limiting.
+	UpstreamUnlimitedBytes int64
+
+	// UpstreamBytesPerSecond specifies a rate limit for upstream
+	// data transfer. The default, 0, is no limit.
+	UpstreamBytesPerSecond int64
+}
+
+// ThrottledConn wraps a net.Conn with read and write rate limiters.
+// Rates are specified as bytes per second. Optional unlimited byte
+// counts allow for a number of bytes to read or write before
+// applying rate limiting. Specify limit values of 0 to set no rate
+// limit (unlimited counts are ignored in this case).
+// The underlying rate limiter uses the token bucket algorithm to
+// calculate delay times for read and write operations.
+type ThrottledConn struct {
+	net.Conn
+	unlimitedReadBytes  int64
+	limitingReads       int32
+	limitedReader       io.Reader
+	unlimitedWriteBytes int64
+	limitingWrites      int32
+	limitedWriter       io.Writer
+}
+
+// NewThrottledConn initializes a new ThrottledConn.
+func NewThrottledConn(conn net.Conn, limits RateLimits) *ThrottledConn {
+
+	// When no limit is specified, the rate limited reader/writer
+	// is simply the base reader/writer.
+
+	var reader io.Reader
+	if limits.DownstreamBytesPerSecond == 0 {
+		reader = conn
+	} else {
+		reader = ratelimit.Reader(conn,
+			ratelimit.NewBucketWithRate(
+				float64(limits.DownstreamBytesPerSecond),
+				limits.DownstreamBytesPerSecond))
+	}
+
+	var writer io.Writer
+	if limits.UpstreamBytesPerSecond == 0 {
+		writer = conn
+	} else {
+		writer = ratelimit.Writer(conn,
+			ratelimit.NewBucketWithRate(
+				float64(limits.UpstreamBytesPerSecond),
+				limits.UpstreamBytesPerSecond))
+	}
+
+	return &ThrottledConn{
+		Conn:                conn,
+		unlimitedReadBytes:  limits.DownstreamUnlimitedBytes,
+		limitingReads:       0,
+		limitedReader:       reader,
+		unlimitedWriteBytes: limits.UpstreamUnlimitedBytes,
+		limitingWrites:      0,
+		limitedWriter:       writer,
+	}
+}
+
+func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
+
+	// Use the base reader until the unlimited count is exhausted.
+	if atomic.LoadInt32(&conn.limitingReads) == 0 {
+		if atomic.AddInt64(&conn.unlimitedReadBytes, -int64(len(buffer))) <= 0 {
+			atomic.StoreInt32(&conn.limitingReads, 1)
+		} else {
+			return conn.Read(buffer)
+		}
+	}
+
+	return conn.limitedReader.Read(buffer)
+}
+
+func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
+
+	// Use the base writer until the unlimited count is exhausted.
+	if atomic.LoadInt32(&conn.limitingWrites) == 0 {
+		if atomic.AddInt64(&conn.unlimitedWriteBytes, -int64(len(buffer))) <= 0 {
+			atomic.StoreInt32(&conn.limitingWrites, 1)
+		} else {
+			return conn.Write(buffer)
+		}
+	}
+
+	return conn.limitedWriter.Write(buffer)
+}

+ 28 - 0
psiphon/common/throttled_test.go

@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2016, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+import (
+	"testing"
+)
+
+func TestThrottledConn(t *testing.T) {
+	// TODO
+}

+ 178 - 0
psiphon/common/utils.go

@@ -0,0 +1,178 @@
+/*
+ * Copyright (c) 2016, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+import (
+	"crypto/rand"
+	"encoding/base64"
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"math/big"
+	"runtime"
+	"strings"
+	"time"
+)
+
+// Contains is a helper function that returns true
+// if the target string is in the list.
+func Contains(list []string, target string) bool {
+	for _, listItem := range list {
+		if listItem == target {
+			return true
+		}
+	}
+	return false
+}
+
+// FlipCoin is a helper function that randomly
+// returns true or false. If the underlying random
+// number generator fails, FlipCoin still returns
+// a result.
+func FlipCoin() bool {
+	randomInt, _ := MakeSecureRandomInt(2)
+	return randomInt == 1
+}
+
+// MakeSecureRandomInt is a helper function that wraps
+// MakeSecureRandomInt64.
+func MakeSecureRandomInt(max int) (int, error) {
+	randomInt, err := MakeSecureRandomInt64(int64(max))
+	return int(randomInt), err
+}
+
+// MakeSecureRandomInt64 is a helper function that wraps
+// crypto/rand.Int, which returns a uniform random value in [0, max).
+func MakeSecureRandomInt64(max int64) (int64, error) {
+	randomInt, err := rand.Int(rand.Reader, big.NewInt(max))
+	if err != nil {
+		return 0, ContextError(err)
+	}
+	return randomInt.Int64(), nil
+}
+
+// MakeSecureRandomBytes is a helper function that wraps
+// crypto/rand.Read.
+func MakeSecureRandomBytes(length int) ([]byte, error) {
+	randomBytes := make([]byte, length)
+	n, err := rand.Read(randomBytes)
+	if err != nil {
+		return nil, ContextError(err)
+	}
+	if n != length {
+		return nil, ContextError(errors.New("insufficient random bytes"))
+	}
+	return randomBytes, nil
+}
+
+// MakeSecureRandomPadding selects a random padding length in the indicated
+// range and returns a random byte array of the selected length.
+// In the unlikely case where an underlying MakeRandom functions fails,
+// the padding is length 0.
+func MakeSecureRandomPadding(minLength, maxLength int) ([]byte, error) {
+	var padding []byte
+	paddingSize, err := MakeSecureRandomInt(maxLength - minLength)
+	if err != nil {
+		return nil, ContextError(err)
+	}
+	paddingSize += minLength
+	padding, err = MakeSecureRandomBytes(paddingSize)
+	if err != nil {
+		return nil, ContextError(err)
+	}
+	return padding, nil
+}
+
+// MakeRandomPeriod returns a random duration, within a given range.
+// In the unlikely case where an  underlying MakeRandom functions fails,
+// the period is the minimum.
+func MakeRandomPeriod(min, max time.Duration) (time.Duration, error) {
+	period, err := MakeSecureRandomInt64(max.Nanoseconds() - min.Nanoseconds())
+	if err != nil {
+		return 0, ContextError(err)
+	}
+	return min + time.Duration(period), nil
+}
+
+// MakeRandomStringHex returns a hex encoded random string.
+// byteLength specifies the pre-encoded data length.
+func MakeRandomStringHex(byteLength int) (string, error) {
+	bytes, err := MakeSecureRandomBytes(byteLength)
+	if err != nil {
+		return "", ContextError(err)
+	}
+	return hex.EncodeToString(bytes), nil
+}
+
+// MakeRandomStringBase64 returns a base64 encoded random string.
+// byteLength specifies the pre-encoded data length.
+func MakeRandomStringBase64(byteLength int) (string, error) {
+	bytes, err := MakeSecureRandomBytes(byteLength)
+	if err != nil {
+		return "", ContextError(err)
+	}
+	return base64.RawURLEncoding.EncodeToString(bytes), nil
+}
+
+// GetCurrentTimestamp returns the current time in UTC as
+// an RFC 3339 formatted string.
+func GetCurrentTimestamp() string {
+	return time.Now().UTC().Format(time.RFC3339)
+}
+
+// TruncateTimestampToHour truncates an RFC 3339 formatted string
+// to hour granularity. If the input is not a valid format, the
+// result is "".
+func TruncateTimestampToHour(timestamp string) string {
+	t, err := time.Parse(time.RFC3339, timestamp)
+	if err != nil {
+		return ""
+	}
+	return t.Truncate(1 * time.Hour).Format(time.RFC3339)
+}
+
+// getFunctionName is a helper that extracts a simple function name from
+// full name returned byruntime.Func.Name(). This is used to declutter
+// log messages containing function names.
+func getFunctionName(pc uintptr) string {
+	funcName := runtime.FuncForPC(pc).Name()
+	index := strings.LastIndex(funcName, "/")
+	if index != -1 {
+		funcName = funcName[index+1:]
+	}
+	return funcName
+}
+
+// GetParentContext returns the parent function name and source file
+// line number.
+func GetParentContext() string {
+	pc, _, line, _ := runtime.Caller(2)
+	return fmt.Sprintf("%s#%d", getFunctionName(pc), line)
+}
+
+// ContextError prefixes an error message with the current function
+// name and source file line number.
+func ContextError(err error) error {
+	if err == nil {
+		return nil
+	}
+	pc, _, line, _ := runtime.Caller(1)
+	return fmt.Errorf("%s#%d: %s", getFunctionName(pc), line, err)
+}

+ 12 - 3
psiphon/utils_test.go → psiphon/common/utils_test.go

@@ -17,7 +17,7 @@
  *
  */
 
-package psiphon
+package common
 
 import (
 	"testing"
@@ -28,7 +28,11 @@ func TestMakeRandomPeriod(t *testing.T) {
 	min := 1 * time.Nanosecond
 	max := 10000 * time.Nanosecond
 
-	res1 := MakeRandomPeriod(min, max)
+	res1, err := MakeRandomPeriod(min, max)
+
+	if err != nil {
+		t.Error("MakeRandomPeriod failed: %s", err)
+	}
 
 	if res1 < min {
 		t.Error("duration should not be less than min")
@@ -38,7 +42,12 @@ func TestMakeRandomPeriod(t *testing.T) {
 		t.Error("duration should not be more than max")
 	}
 
-	res2 := MakeRandomPeriod(min, max)
+	res2, err := MakeRandomPeriod(min, max)
+
+	if err != nil {
+		t.Error("MakeRandomPeriod failed: %s", err)
+	}
+
 	if res1 == res2 {
 		t.Error("duration should have randomness difference between calls")
 	}

+ 21 - 13
psiphon/config.go

@@ -27,6 +27,8 @@ import (
 	"os"
 	"strconv"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // TODO: allow all params to be configured
@@ -55,7 +57,6 @@ const (
 	FETCH_REMOTE_SERVER_LIST_TIMEOUT_SECONDS             = 30
 	FETCH_REMOTE_SERVER_LIST_RETRY_PERIOD_SECONDS        = 30
 	FETCH_REMOTE_SERVER_LIST_STALE_PERIOD                = 6 * time.Hour
-	PSIPHON_API_CLIENT_SESSION_ID_LENGTH                 = 16
 	PSIPHON_API_SERVER_TIMEOUT_SECONDS                   = 20
 	PSIPHON_API_SHUTDOWN_SERVER_TIMEOUT                  = 1 * time.Second
 	PSIPHON_API_STATUS_REQUEST_PERIOD_MIN                = 5 * time.Minute
@@ -382,6 +383,9 @@ type Config struct {
 	// and for asynchronous operations such as fetch remote server list to complete.
 	// If omitted, the default value is ESTABLISH_TUNNEL_PAUSE_PERIOD_SECONDS.
 	EstablishTunnelPausePeriodSeconds *int
+
+	// RateLimits specify throttling configuration for the tunnel.
+	RateLimits common.RateLimits
 }
 
 // LoadConfig parses and validates a JSON format Psiphon config JSON
@@ -390,7 +394,7 @@ func LoadConfig(configJson []byte) (*Config, error) {
 	var config Config
 	err := json.Unmarshal(configJson, &config)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// Do SetEmitDiagnosticNotices first, to ensure config file errors are emitted.
@@ -400,18 +404,18 @@ func LoadConfig(configJson []byte) (*Config, error) {
 
 	// These fields are required; the rest are optional
 	if config.PropagationChannelId == "" {
-		return nil, ContextError(
+		return nil, common.ContextError(
 			errors.New("propagation channel ID is missing from the configuration file"))
 	}
 	if config.SponsorId == "" {
-		return nil, ContextError(
+		return nil, common.ContextError(
 			errors.New("sponsor ID is missing from the configuration file"))
 	}
 
 	if config.DataStoreDirectory == "" {
 		config.DataStoreDirectory, err = os.Getwd()
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 	}
 
@@ -421,13 +425,13 @@ func LoadConfig(configJson []byte) (*Config, error) {
 
 	_, err = strconv.Atoi(config.ClientVersion)
 	if err != nil {
-		return nil, ContextError(
+		return nil, common.ContextError(
 			fmt.Errorf("invalid client version: %s", err))
 	}
 
 	if config.TunnelProtocol != "" {
-		if !Contains(SupportedTunnelProtocols, config.TunnelProtocol) {
-			return nil, ContextError(
+		if !common.Contains(common.SupportedTunnelProtocols, config.TunnelProtocol) {
+			return nil, common.ContextError(
 				errors.New("invalid tunnel protocol"))
 		}
 	}
@@ -446,24 +450,28 @@ func LoadConfig(configJson []byte) (*Config, error) {
 	}
 
 	if config.NetworkConnectivityChecker != nil {
-		return nil, ContextError(errors.New("NetworkConnectivityChecker interface must be set at runtime"))
+		return nil, common.ContextError(
+			errors.New("NetworkConnectivityChecker interface must be set at runtime"))
 	}
 
 	if config.DeviceBinder != nil {
-		return nil, ContextError(errors.New("DeviceBinder interface must be set at runtime"))
+		return nil, common.ContextError(
+			errors.New("DeviceBinder interface must be set at runtime"))
 	}
 
 	if config.DnsServerGetter != nil {
-		return nil, ContextError(errors.New("DnsServerGetter interface must be set at runtime"))
+		return nil, common.ContextError(
+			errors.New("DnsServerGetter interface must be set at runtime"))
 	}
 
 	if config.HostNameTransformer != nil {
-		return nil, ContextError(errors.New("HostNameTransformer interface must be set at runtime"))
+		return nil, common.ContextError(
+			errors.New("HostNameTransformer interface must be set at runtime"))
 	}
 
 	if config.UpgradeDownloadUrl != "" &&
 		(config.UpgradeDownloadClientVersionHeader == "" || config.UpgradeDownloadFilename == "") {
-		return nil, ContextError(errors.New(
+		return nil, common.ContextError(errors.New(
 			"UpgradeDownloadUrl requires UpgradeDownloadClientVersionHeader and UpgradeDownloadFilename"))
 	}
 

+ 11 - 9
psiphon/controller.go

@@ -29,6 +29,8 @@ import (
 	"net"
 	"sync"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // Controller is a tunnel lifecycle coordinator. It manages lists of servers to
@@ -51,8 +53,8 @@ type Controller struct {
 	establishWaitGroup             *sync.WaitGroup
 	stopEstablishingBroadcast      chan struct{}
 	candidateServerEntries         chan *candidateServerEntry
-	establishPendingConns          *Conns
-	untunneledPendingConns         *Conns
+	establishPendingConns          *common.Conns
+	untunneledPendingConns         *common.Conns
 	untunneledDialConfig           *DialConfig
 	splitTunnelClassifier          *SplitTunnelClassifier
 	signalFetchRemoteServerList    chan struct{}
@@ -83,7 +85,7 @@ func NewController(config *Config) (controller *Controller, err error) {
 	// used across all tunnels established by the controller.
 	sessionId, err := MakeSessionId()
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	NoticeSessionId(sessionId)
 
@@ -92,7 +94,7 @@ func NewController(config *Config) (controller *Controller, err error) {
 	// used to exclude these requests and connection from VPN routing.
 	// TODO: fetch remote server list and untunneled upgrade download should remove
 	// their completed conns from untunneledPendingConns.
-	untunneledPendingConns := new(Conns)
+	untunneledPendingConns := new(common.Conns)
 	untunneledDialConfig := &DialConfig{
 		UpstreamProxyUrl:              config.UpstreamProxyUrl,
 		UpstreamProxyCustomHeaders:    config.UpstreamProxyCustomHeaders,
@@ -121,7 +123,7 @@ func NewController(config *Config) (controller *Controller, err error) {
 		establishedOnce:                false,
 		startedConnectedReporter:       false,
 		isEstablishing:                 false,
-		establishPendingConns:          new(Conns),
+		establishPendingConns:          new(common.Conns),
 		untunneledPendingConns:         untunneledPendingConns,
 		untunneledDialConfig:           untunneledDialConfig,
 		impairedProtocolClassification: make(map[string]int),
@@ -677,7 +679,7 @@ func (controller *Controller) classifyImpairedProtocol(failedTunnel *Tunnel) {
 	} else {
 		controller.impairedProtocolClassification[failedTunnel.protocol] = 0
 	}
-	if len(controller.getImpairedProtocols()) == len(SupportedTunnelProtocols) {
+	if len(controller.getImpairedProtocols()) == len(common.SupportedTunnelProtocols) {
 		// Reset classification if all protocols are classified as impaired as
 		// the network situation (or attack) may not be protocol-specific.
 		// TODO: compare against count of distinct supported protocols for
@@ -877,7 +879,7 @@ func (controller *Controller) Dial(
 
 	tunnel := controller.getNextActiveTunnel()
 	if tunnel == nil {
-		return nil, ContextError(errors.New("no active tunnels"))
+		return nil, common.ContextError(errors.New("no active tunnels"))
 	}
 
 	// Perform split tunnel classification when feature is enabled, and if the remote
@@ -886,7 +888,7 @@ func (controller *Controller) Dial(
 
 		host, _, err := net.SplitHostPort(remoteAddr)
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 
 		// Note: a possible optimization, when split tunnel is active and IsUntunneled performs
@@ -903,7 +905,7 @@ func (controller *Controller) Dial(
 
 	tunneledConn, err := tunnel.Dial(remoteAddr, alwaysTunnel, downstreamConn)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return tunneledConn, nil

+ 14 - 13
psiphon/controller_test.go

@@ -35,6 +35,7 @@ import (
 	"time"
 
 	socks "github.com/Psiphon-Inc/goptlib"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/elazarl/goproxy"
 )
 
@@ -185,7 +186,7 @@ func TestSSH(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			expectNoServerEntries:    false,
-			protocol:                 TUNNEL_PROTOCOL_SSH,
+			protocol:                 common.TUNNEL_PROTOCOL_SSH,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
@@ -202,7 +203,7 @@ func TestObfuscatedSSH(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			expectNoServerEntries:    false,
-			protocol:                 TUNNEL_PROTOCOL_OBFUSCATED_SSH,
+			protocol:                 common.TUNNEL_PROTOCOL_OBFUSCATED_SSH,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
@@ -219,7 +220,7 @@ func TestUnfrontedMeek(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			expectNoServerEntries:    false,
-			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK,
+			protocol:                 common.TUNNEL_PROTOCOL_UNFRONTED_MEEK,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
@@ -236,7 +237,7 @@ func TestUnfrontedMeekWithTransformer(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			expectNoServerEntries:    false,
-			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK,
+			protocol:                 common.TUNNEL_PROTOCOL_UNFRONTED_MEEK,
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
@@ -253,7 +254,7 @@ func TestFrontedMeek(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			expectNoServerEntries:    false,
-			protocol:                 TUNNEL_PROTOCOL_FRONTED_MEEK,
+			protocol:                 common.TUNNEL_PROTOCOL_FRONTED_MEEK,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
@@ -270,7 +271,7 @@ func TestFrontedMeekWithTransformer(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			expectNoServerEntries:    false,
-			protocol:                 TUNNEL_PROTOCOL_FRONTED_MEEK,
+			protocol:                 common.TUNNEL_PROTOCOL_FRONTED_MEEK,
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
@@ -287,7 +288,7 @@ func TestFrontedMeekHTTP(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			expectNoServerEntries:    false,
-			protocol:                 TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP,
+			protocol:                 common.TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP,
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
@@ -304,7 +305,7 @@ func TestUnfrontedMeekHTTPS(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			expectNoServerEntries:    false,
-			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
+			protocol:                 common.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
@@ -321,7 +322,7 @@ func TestUnfrontedMeekHTTPSWithTransformer(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			expectNoServerEntries:    false,
-			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
+			protocol:                 common.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
@@ -355,7 +356,7 @@ func TestObfuscatedSSHWithUpstreamProxy(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			expectNoServerEntries:    false,
-			protocol:                 TUNNEL_PROTOCOL_OBFUSCATED_SSH,
+			protocol:                 common.TUNNEL_PROTOCOL_OBFUSCATED_SSH,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
@@ -372,7 +373,7 @@ func TestUnfrontedMeekWithUpstreamProxy(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			expectNoServerEntries:    false,
-			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK,
+			protocol:                 common.TUNNEL_PROTOCOL_UNFRONTED_MEEK,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
@@ -389,7 +390,7 @@ func TestUnfrontedMeekHTTPSWithUpstreamProxy(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			expectNoServerEntries:    false,
-			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
+			protocol:                 common.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
 			disableEstablishing:      false,
@@ -929,7 +930,7 @@ func hasExpectedCustomHeaders(h http.Header) bool {
 		}
 		// Order may not be the same
 		for _, value := range values {
-			if !Contains(h[name], value) {
+			if !common.Contains(h[name], value) {
 				return false
 			}
 		}

+ 34 - 33
psiphon/dataStore.go

@@ -32,6 +32,7 @@ import (
 	"time"
 
 	"github.com/Psiphon-Inc/bolt"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // The BoltDB dataStore implementation is an alternative to the sqlite3-based
@@ -163,7 +164,7 @@ func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
 	// so instead of skipping we fail with an error.
 	err := ValidateServerEntry(serverEntry)
 	if err != nil {
-		return ContextError(errors.New("invalid server entry"))
+		return common.ContextError(errors.New("invalid server entry"))
 	}
 
 	// BoltDB implementation note:
@@ -199,16 +200,16 @@ func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
 
 		data, err := json.Marshal(serverEntry)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 		err = serverEntries.Put([]byte(serverEntry.IpAddress), data)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 
 		err = insertRankedServerEntry(tx, serverEntry.IpAddress, 1)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 
 		NoticeInfo("updated server %s", serverEntry.IpAddress)
@@ -216,7 +217,7 @@ func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
 		return nil
 	})
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	return nil
@@ -237,7 +238,7 @@ func StoreServerEntries(serverEntries []*ServerEntry, replaceIfExists bool) erro
 	for _, serverEntry := range serverEntries {
 		err := StoreServerEntry(serverEntry, replaceIfExists)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 	}
 
@@ -272,7 +273,7 @@ func PromoteServerEntry(ipAddress string) error {
 	})
 
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	return nil
 }
@@ -288,7 +289,7 @@ func getRankedServerEntries(tx *bolt.Tx) ([]string, error) {
 	rankedServerEntries := make([]string, 0)
 	err := json.Unmarshal(data, &rankedServerEntries)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	return rankedServerEntries, nil
 }
@@ -296,13 +297,13 @@ func getRankedServerEntries(tx *bolt.Tx) ([]string, error) {
 func setRankedServerEntries(tx *bolt.Tx, rankedServerEntries []string) error {
 	data, err := json.Marshal(rankedServerEntries)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	bucket := tx.Bucket([]byte(rankedServerEntriesBucket))
 	err = bucket.Put([]byte(rankedServerEntriesKey), data)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	return nil
@@ -311,7 +312,7 @@ func setRankedServerEntries(tx *bolt.Tx, rankedServerEntries []string) error {
 func insertRankedServerEntry(tx *bolt.Tx, serverEntryId string, position int) error {
 	rankedServerEntries, err := getRankedServerEntries(tx)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	// BoltDB implementation note:
@@ -346,7 +347,7 @@ func insertRankedServerEntry(tx *bolt.Tx, serverEntryId string, position int) er
 
 	err = setRankedServerEntries(tx, rankedServerEntries)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	return nil
@@ -356,7 +357,7 @@ func serverEntrySupportsProtocol(serverEntry *ServerEntry, protocol string) bool
 	// Note: for meek, the capabilities are FRONTED-MEEK and UNFRONTED-MEEK
 	// and the additonal OSSH service is assumed to be available internally.
 	requiredCapability := strings.TrimSuffix(protocol, "-OSSH")
-	return Contains(serverEntry.Capabilities, requiredCapability)
+	return common.Contains(serverEntry.Capabilities, requiredCapability)
 }
 
 // ServerEntryIterator is used to iterate over
@@ -397,7 +398,7 @@ func NewServerEntryIterator(config *Config) (iterator *ServerEntryIterator, err
 // newTargetServerEntryIterator is a helper for initializing the TargetServerEntry case
 func newTargetServerEntryIterator(config *Config) (iterator *ServerEntryIterator, err error) {
 	serverEntry, err := DecodeServerEntry(
-		config.TargetServerEntry, GetCurrentTimestamp(), SERVER_ENTRY_SOURCE_TARGET)
+		config.TargetServerEntry, common.GetCurrentTimestamp(), common.SERVER_ENTRY_SOURCE_TARGET)
 	if err != nil {
 		return nil, err
 	}
@@ -407,7 +408,7 @@ func newTargetServerEntryIterator(config *Config) (iterator *ServerEntryIterator
 	if config.TunnelProtocol != "" {
 		// Note: same capability/protocol mapping as in StoreServerEntry
 		requiredCapability := strings.TrimSuffix(config.TunnelProtocol, "-OSSH")
-		if !Contains(serverEntry.Capabilities, requiredCapability) {
+		if !common.Contains(serverEntry.Capabilities, requiredCapability) {
 			return nil, errors.New("TargetServerEntry does not support TunnelProtocol")
 		}
 	}
@@ -478,7 +479,7 @@ func (iterator *ServerEntryIterator) Reset() error {
 		return nil
 	})
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	for i := len(serverEntryIds) - 1; i > iterator.shuffleHeadLength-1; i-- {
@@ -539,7 +540,7 @@ func (iterator *ServerEntryIterator) Next() (serverEntry *ServerEntry, err error
 			return nil
 		})
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 
 		if data == nil {
@@ -554,7 +555,7 @@ func (iterator *ServerEntryIterator) Next() (serverEntry *ServerEntry, err error
 		if err != nil {
 			// In case of data corruption or a bug causing this condition,
 			// do not stop iterating.
-			NoticeAlert("ServerEntryIterator.Next: %s", ContextError(err))
+			NoticeAlert("ServerEntryIterator.Next: %s", common.ContextError(err))
 			continue
 		}
 
@@ -593,7 +594,7 @@ func scanServerEntries(scanner func(*ServerEntry)) error {
 			if err != nil {
 				// In case of data corruption or a bug causing this condition,
 				// do not stop iterating.
-				NoticeAlert("scanServerEntries: %s", ContextError(err))
+				NoticeAlert("scanServerEntries: %s", common.ContextError(err))
 				continue
 			}
 			scanner(serverEntry)
@@ -603,7 +604,7 @@ func scanServerEntries(scanner func(*ServerEntry)) error {
 	})
 
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	return nil
@@ -668,7 +669,7 @@ func GetServerEntryIpAddresses() (ipAddresses []string, err error) {
 	})
 
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return ipAddresses, nil
@@ -690,7 +691,7 @@ func SetSplitTunnelRoutes(region, etag string, data []byte) error {
 	})
 
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	return nil
 }
@@ -707,7 +708,7 @@ func GetSplitTunnelRoutesETag(region string) (etag string, err error) {
 	})
 
 	if err != nil {
-		return "", ContextError(err)
+		return "", common.ContextError(err)
 	}
 	return etag, nil
 }
@@ -729,7 +730,7 @@ func GetSplitTunnelRoutesData(region string) (data []byte, err error) {
 	})
 
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	return data, nil
 }
@@ -747,7 +748,7 @@ func SetUrlETag(url, etag string) error {
 	})
 
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	return nil
 }
@@ -764,7 +765,7 @@ func GetUrlETag(url string) (etag string, err error) {
 	})
 
 	if err != nil {
-		return "", ContextError(err)
+		return "", common.ContextError(err)
 	}
 	return etag, nil
 }
@@ -780,7 +781,7 @@ func SetKeyValue(key, value string) error {
 	})
 
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	return nil
 }
@@ -797,7 +798,7 @@ func GetKeyValue(key string) (value string, err error) {
 	})
 
 	if err != nil {
-		return "", ContextError(err)
+		return "", common.ContextError(err)
 	}
 	return value, nil
 }
@@ -833,7 +834,7 @@ func StoreTunnelStats(tunnelStats []byte) error {
 	})
 
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	return nil
 }
@@ -913,7 +914,7 @@ func TakeOutUnreportedTunnelStats(maxCount int) ([][]byte, error) {
 	})
 
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	return tunnelStats, nil
 }
@@ -935,7 +936,7 @@ func PutBackUnreportedTunnelStats(tunnelStats [][]byte) error {
 	})
 
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	return nil
 }
@@ -957,7 +958,7 @@ func ClearReportedTunnelStats(tunnelStats [][]byte) error {
 	})
 
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	return nil
 }
@@ -990,7 +991,7 @@ func resetAllTunnelStatsToUnreported() error {
 	})
 
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	return nil
 }

+ 14 - 12
psiphon/httpProxy.go

@@ -29,6 +29,8 @@ import (
 	"strings"
 	"sync"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // HttpProxy is a HTTP server that relays HTTP requests through the Psiphon tunnel.
@@ -62,7 +64,7 @@ type HttpProxy struct {
 	urlProxyTunneledClient *http.Client
 	urlProxyDirectRelay    *http.Transport
 	urlProxyDirectClient   *http.Client
-	openConns              *Conns
+	openConns              *common.Conns
 	stopListeningBroadcast chan struct{}
 }
 
@@ -81,7 +83,7 @@ func NewHttpProxy(
 		if IsAddressInUseError(err) {
 			NoticeHttpProxyPortInUse(config.LocalHttpProxyPort)
 		}
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
@@ -141,7 +143,7 @@ func NewHttpProxy(
 		urlProxyTunneledClient: urlProxyTunneledClient,
 		urlProxyDirectRelay:    urlProxyDirectRelay,
 		urlProxyDirectClient:   urlProxyDirectClient,
-		openConns:              new(Conns),
+		openConns:              new(common.Conns),
 		stopListeningBroadcast: make(chan struct{}),
 	}
 	proxy.serveWaitGroup.Add(1)
@@ -198,14 +200,14 @@ func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *h
 		hijacker, _ := responseWriter.(http.Hijacker)
 		conn, _, err := hijacker.Hijack()
 		if err != nil {
-			NoticeAlert("%s", ContextError(err))
+			NoticeAlert("%s", common.ContextError(err))
 			http.Error(responseWriter, "", http.StatusInternalServerError)
 			return
 		}
 		go func() {
 			err := proxy.httpConnectHandler(conn, request.URL.Host)
 			if err != nil {
-				NoticeAlert("%s", ContextError(err))
+				NoticeAlert("%s", common.ContextError(err))
 			}
 		}()
 	} else if request.URL.IsAbs() {
@@ -224,12 +226,12 @@ func (proxy *HttpProxy) httpConnectHandler(localConn net.Conn, target string) (e
 	// open connection for data which will never arrive.
 	remoteConn, err := proxy.tunneler.Dial(target, false, localConn)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	defer remoteConn.Close()
 	_, err = localConn.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	LocalProxyRelay(_HTTP_PROXY_TYPE, localConn, remoteConn)
 	return nil
@@ -263,7 +265,7 @@ func (proxy *HttpProxy) urlProxyHandler(responseWriter http.ResponseWriter, requ
 		err = errors.New("missing origin URL")
 	}
 	if err != nil {
-		NoticeAlert("%s", ContextError(FilterUrlError(err)))
+		NoticeAlert("%s", common.ContextError(FilterUrlError(err)))
 		forceClose(responseWriter)
 		return
 	}
@@ -271,7 +273,7 @@ func (proxy *HttpProxy) urlProxyHandler(responseWriter http.ResponseWriter, requ
 	// Origin URL must be well-formed, absolute, and have a scheme of  "http" or "https"
 	url, err := url.ParseRequestURI(originUrl)
 	if err != nil {
-		NoticeAlert("%s", ContextError(FilterUrlError(err)))
+		NoticeAlert("%s", common.ContextError(FilterUrlError(err)))
 		forceClose(responseWriter)
 		return
 	}
@@ -313,7 +315,7 @@ func relayHttpRequest(
 	}
 
 	if err != nil {
-		NoticeAlert("%s", ContextError(FilterUrlError(err)))
+		NoticeAlert("%s", common.ContextError(FilterUrlError(err)))
 		forceClose(responseWriter)
 		return
 	}
@@ -336,7 +338,7 @@ func relayHttpRequest(
 	responseWriter.WriteHeader(response.StatusCode)
 	_, err = io.Copy(responseWriter, response.Body)
 	if err != nil {
-		NoticeAlert("%s", ContextError(err))
+		NoticeAlert("%s", common.ContextError(err))
 		forceClose(responseWriter)
 		return
 	}
@@ -402,7 +404,7 @@ func (proxy *HttpProxy) serve() {
 	default:
 		if err != nil {
 			proxy.tunneler.SignalComponentFailure()
-			NoticeLocalProxyError(_HTTP_PROXY_TYPE, ContextError(err))
+			NoticeLocalProxyError(_HTTP_PROXY_TYPE, common.ContextError(err))
 		}
 	}
 	NoticeInfo("HTTP proxy stopped")

+ 28 - 27
psiphon/meekConn.go

@@ -34,6 +34,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/upstreamproxy"
 	"golang.org/x/crypto/nacl/box"
 )
@@ -105,7 +106,7 @@ type MeekConn struct {
 	url                  *url.URL
 	additionalHeaders    map[string]string
 	cookie               *http.Cookie
-	pendingConns         *Conns
+	pendingConns         *common.Conns
 	transport            transporter
 	mutex                sync.Mutex
 	isClosed             bool
@@ -143,7 +144,7 @@ func DialMeek(
 	// which may be interrupted on MeekConn.Close(). This code previously used the establishTunnel
 	// pendingConns here, but that was a lifecycle mismatch: we don't want to abort HTTP transport
 	// connections while MeekConn is still in use
-	pendingConns := new(Conns)
+	pendingConns := new(common.Conns)
 
 	// Use a copy of DialConfig with the meek pendingConns
 	meekDialConfig := new(DialConfig)
@@ -220,7 +221,7 @@ func DialMeek(
 				meekConfig.DialAddress == meekConfig.HostHeader+":80") {
 			url, err := url.Parse(meekDialConfig.UpstreamProxyUrl)
 			if err != nil {
-				return nil, ContextError(err)
+				return nil, common.ContextError(err)
 			}
 			proxyUrl = http.ProxyURL(url)
 			meekDialConfig.UpstreamProxyUrl = ""
@@ -240,7 +241,7 @@ func DialMeek(
 			// Wrap transport with a transport that can perform HTTP proxy auth negotiation
 			transport, err = upstreamproxy.NewProxyAuthTransport(httpTransport, meekDialConfig.UpstreamProxyCustomHeaders)
 			if err != nil {
-				return nil, ContextError(err)
+				return nil, common.ContextError(err)
 			}
 		} else {
 			transport = httpTransport
@@ -259,7 +260,7 @@ func DialMeek(
 	if meekConfig.UseHTTPS {
 		host, _, err := net.SplitHostPort(meekConfig.DialAddress)
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		additionalHeaders = map[string]string{
 			"X-Psiphon-Fronting-Address": host,
@@ -268,7 +269,7 @@ func DialMeek(
 
 	cookie, err := makeMeekCookie(meekConfig)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// The main loop of a MeekConn is run in the relay() goroutine.
@@ -312,7 +313,7 @@ func DialMeek(
 	// Enable interruption
 	if !dialConfig.PendingConns.Add(meek) {
 		meek.Close()
-		return nil, ContextError(errors.New("pending connections already closed"))
+		return nil, common.ContextError(errors.New("pending connections already closed"))
 	}
 
 	return meek, nil
@@ -350,7 +351,7 @@ func (meek *MeekConn) closed() bool {
 // net.Conn Deadlines are ignored. net.Conn concurrency semantics are supported.
 func (meek *MeekConn) Read(buffer []byte) (n int, err error) {
 	if meek.closed() {
-		return 0, ContextError(errors.New("meek connection is closed"))
+		return 0, common.ContextError(errors.New("meek connection is closed"))
 	}
 	// Block until there is received data to consume
 	var receiveBuffer *bytes.Buffer
@@ -358,7 +359,7 @@ func (meek *MeekConn) Read(buffer []byte) (n int, err error) {
 	case receiveBuffer = <-meek.partialReceiveBuffer:
 	case receiveBuffer = <-meek.fullReceiveBuffer:
 	case <-meek.broadcastClosed:
-		return 0, ContextError(errors.New("meek connection has closed"))
+		return 0, common.ContextError(errors.New("meek connection has closed"))
 	}
 	n, err = receiveBuffer.Read(buffer)
 	meek.replaceReceiveBuffer(receiveBuffer)
@@ -369,7 +370,7 @@ func (meek *MeekConn) Read(buffer []byte) (n int, err error) {
 // net.Conn Deadlines are ignored. net.Conn concurrency semantics are supported.
 func (meek *MeekConn) Write(buffer []byte) (n int, err error) {
 	if meek.closed() {
-		return 0, ContextError(errors.New("meek connection is closed"))
+		return 0, common.ContextError(errors.New("meek connection is closed"))
 	}
 	// Repeats until all n bytes are written
 	n = len(buffer)
@@ -380,7 +381,7 @@ func (meek *MeekConn) Write(buffer []byte) (n int, err error) {
 		case sendBuffer = <-meek.emptySendBuffer:
 		case sendBuffer = <-meek.partialSendBuffer:
 		case <-meek.broadcastClosed:
-			return 0, ContextError(errors.New("meek connection has closed"))
+			return 0, common.ContextError(errors.New("meek connection has closed"))
 		}
 		writeLen := MAX_SEND_PAYLOAD_LENGTH - sendBuffer.Len()
 		if writeLen > 0 {
@@ -407,17 +408,17 @@ func (meek *MeekConn) RemoteAddr() net.Addr {
 
 // Stub implementation of net.Conn.SetDeadline
 func (meek *MeekConn) SetDeadline(t time.Time) error {
-	return ContextError(errors.New("not supported"))
+	return common.ContextError(errors.New("not supported"))
 }
 
 // Stub implementation of net.Conn.SetReadDeadline
 func (meek *MeekConn) SetReadDeadline(t time.Time) error {
-	return ContextError(errors.New("not supported"))
+	return common.ContextError(errors.New("not supported"))
 }
 
 // Stub implementation of net.Conn.SetWriteDeadline
 func (meek *MeekConn) SetWriteDeadline(t time.Time) error {
-	return ContextError(errors.New("not supported"))
+	return common.ContextError(errors.New("not supported"))
 }
 
 func (meek *MeekConn) replaceReceiveBuffer(receiveBuffer *bytes.Buffer) {
@@ -472,14 +473,14 @@ func (meek *MeekConn) relay() {
 			sendPayloadSize, err = sendBuffer.Read(sendPayload)
 			meek.replaceSendBuffer(sendBuffer)
 			if err != nil {
-				NoticeAlert("%s", ContextError(err))
+				NoticeAlert("%s", common.ContextError(err))
 				go meek.Close()
 				return
 			}
 		}
 		receivedPayload, err := meek.roundTrip(sendPayload[:sendPayloadSize])
 		if err != nil {
-			NoticeAlert("%s", ContextError(err))
+			NoticeAlert("%s", common.ContextError(err))
 			go meek.Close()
 			return
 		}
@@ -489,7 +490,7 @@ func (meek *MeekConn) relay() {
 		}
 		receivedPayloadSize, err := meek.readPayload(receivedPayload)
 		if err != nil {
-			NoticeAlert("%s", ContextError(err))
+			NoticeAlert("%s", common.ContextError(err))
 			go meek.Close()
 			return
 		}
@@ -527,7 +528,7 @@ func (meek *MeekConn) readPayload(receivedPayload io.ReadCloser) (totalSize int6
 		n, err := receiveBuffer.ReadFrom(reader)
 		meek.replaceReceiveBuffer(receiveBuffer)
 		if err != nil {
-			return 0, ContextError(err)
+			return 0, common.ContextError(err)
 		}
 		totalSize += n
 		if n == 0 {
@@ -541,7 +542,7 @@ func (meek *MeekConn) readPayload(receivedPayload io.ReadCloser) (totalSize int6
 func (meek *MeekConn) roundTrip(sendPayload []byte) (receivedPayload io.ReadCloser, err error) {
 	request, err := http.NewRequest("POST", meek.url.String(), bytes.NewReader(sendPayload))
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// Don't use the default user agent ("Go 1.1 package http").
@@ -610,10 +611,10 @@ func (meek *MeekConn) roundTrip(sendPayload []byte) (receivedPayload io.ReadClos
 		time.Sleep(MEEK_ROUND_TRIP_RETRY_DELAY)
 	}
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	if response.StatusCode != http.StatusOK {
-		return nil, ContextError(fmt.Errorf("http request failed %d", response.StatusCode))
+		return nil, common.ContextError(fmt.Errorf("http request failed %d", response.StatusCode))
 	}
 	// observe response cookies for meek session key token.
 	// Once found it must be used for all consecutive requests made to the server
@@ -655,7 +656,7 @@ func makeMeekCookie(meekConfig *MeekConfig) (cookie *http.Cookie, err error) {
 	}
 	serializedCookie, err := json.Marshal(cookieData)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// Encrypt the JSON data
@@ -669,12 +670,12 @@ func makeMeekCookie(meekConfig *MeekConfig) (cookie *http.Cookie, err error) {
 	var publicKey [32]byte
 	decodedPublicKey, err := base64.StdEncoding.DecodeString(meekConfig.MeekCookieEncryptionPublicKey)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	copy(publicKey[:], decodedPublicKey)
 	ephemeralPublicKey, ephemeralPrivateKey, err := box.GenerateKey(rand.Reader)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	box := box.Seal(nil, serializedCookie, &nonce, &publicKey, ephemeralPrivateKey)
 	encryptedCookie := make([]byte, 32+len(box))
@@ -685,7 +686,7 @@ func makeMeekCookie(meekConfig *MeekConfig) (cookie *http.Cookie, err error) {
 	obfuscator, err := NewClientObfuscator(
 		&ObfuscatorConfig{Keyword: meekConfig.MeekObfuscatedKey, MaxPadding: MEEK_COOKIE_MAX_PADDING})
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	obfuscatedCookie := obfuscator.SendSeedMessage()
 	seedLen := len(obfuscatedCookie)
@@ -697,9 +698,9 @@ func makeMeekCookie(meekConfig *MeekConfig) (cookie *http.Cookie, err error) {
 	A := int('A')
 	Z := int('Z')
 	// letterIndex is integer in range [int('A'), int('Z')]
-	letterIndex, err := MakeSecureRandomInt(Z - A + 1)
+	letterIndex, err := common.MakeSecureRandomInt(Z - A + 1)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	return &http.Cookie{
 			Name:  string(byte(A + letterIndex)),

+ 6 - 5
psiphon/migrateDataStore_windows.go

@@ -27,6 +27,7 @@ import (
 	"path/filepath"
 
 	_ "github.com/Psiphon-Inc/go-sqlite3"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 var legacyDb *sql.DB
@@ -158,7 +159,7 @@ func (iterator *legacyServerEntryIterator) Next() (serverEntry *ServerEntry, err
 	if !iterator.cursor.Next() {
 		err = iterator.cursor.Err()
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		// There is no next item
 		return nil, nil
@@ -167,12 +168,12 @@ func (iterator *legacyServerEntryIterator) Next() (serverEntry *ServerEntry, err
 	var data []byte
 	err = iterator.cursor.Scan(&data)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	serverEntry = new(ServerEntry)
 	err = json.Unmarshal(data, serverEntry)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return MakeCompatibleServerEntry(serverEntry), nil
@@ -185,7 +186,7 @@ func (iterator *legacyServerEntryIterator) Reset() error {
 
 	transaction, err := legacyDb.Begin()
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	var cursor *sql.Rows
 
@@ -213,7 +214,7 @@ func (iterator *legacyServerEntryIterator) Reset() error {
 	cursor, err = transaction.Query(query, params...)
 	if err != nil {
 		transaction.Rollback()
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	iterator.transaction = transaction
 	iterator.cursor = cursor

+ 19 - 142
psiphon/net.go

@@ -17,37 +17,6 @@
  *
  */
 
-// for HTTPSServer.ServeTLS:
-/*
-Copyright (c) 2012 The Go Authors. All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are
-met:
-
-   * Redistributions of source code must retain the above copyright
-notice, this list of conditions and the following disclaimer.
-   * Redistributions in binary form must reproduce the above
-copyright notice, this list of conditions and the following disclaimer
-in the documentation and/or other materials provided with the
-distribution.
-   * Neither the name of Google Inc. nor the names of its
-contributors may be used to endorse or promote products derived from
-this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-*/
-
 package psiphon
 
 import (
@@ -66,6 +35,7 @@ import (
 	"time"
 
 	"github.com/Psiphon-Inc/dns"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 const DNS_PORT = 53
@@ -97,7 +67,7 @@ type DialConfig struct {
 	// a conn is added to pendingConns before the network connect begins and
 	// removed from pendingConns once the connect succeeds or fails.
 	// May be nil.
-	PendingConns *Conns
+	PendingConns *common.Conns
 
 	// BindToDevice parameters are used to exclude connections and
 	// associated DNS requests from VPN routing.
@@ -177,53 +147,6 @@ func (TimeoutError) Temporary() bool { return true }
 // Dialer is a custom dialer compatible with http.Transport.Dial.
 type Dialer func(string, string) (net.Conn, error)
 
-// Conns is a synchronized list of Conns that is used to coordinate
-// interrupting a set of goroutines establishing connections, or
-// close a set of open connections, etc.
-// Once the list is closed, no more items may be added to the
-// list (unless it is reset).
-type Conns struct {
-	mutex    sync.Mutex
-	isClosed bool
-	conns    map[net.Conn]bool
-}
-
-func (conns *Conns) Reset() {
-	conns.mutex.Lock()
-	defer conns.mutex.Unlock()
-	conns.isClosed = false
-	conns.conns = make(map[net.Conn]bool)
-}
-
-func (conns *Conns) Add(conn net.Conn) bool {
-	conns.mutex.Lock()
-	defer conns.mutex.Unlock()
-	if conns.isClosed {
-		return false
-	}
-	if conns.conns == nil {
-		conns.conns = make(map[net.Conn]bool)
-	}
-	conns.conns[conn] = true
-	return true
-}
-
-func (conns *Conns) Remove(conn net.Conn) {
-	conns.mutex.Lock()
-	defer conns.mutex.Unlock()
-	delete(conns.conns, conn)
-}
-
-func (conns *Conns) CloseAll() {
-	conns.mutex.Lock()
-	defer conns.mutex.Unlock()
-	conns.isClosed = true
-	for conn, _ := range conns.conns {
-		conn.Close()
-	}
-	conns.conns = make(map[net.Conn]bool)
-}
-
 // LocalProxyRelay sends to remoteConn bytes received from localConn,
 // and sends to localConn bytes received from remoteConn.
 func LocalProxyRelay(proxyType string, localConn, remoteConn net.Conn) {
@@ -233,13 +156,13 @@ func LocalProxyRelay(proxyType string, localConn, remoteConn net.Conn) {
 		defer copyWaitGroup.Done()
 		_, err := io.Copy(localConn, remoteConn)
 		if err != nil {
-			err = fmt.Errorf("Relay failed: %s", ContextError(err))
+			err = fmt.Errorf("Relay failed: %s", common.ContextError(err))
 			NoticeLocalProxyError(proxyType, err)
 		}
 	}()
 	_, err := io.Copy(remoteConn, localConn)
 	if err != nil {
-		err = fmt.Errorf("Relay failed: %s", ContextError(err))
+		err = fmt.Errorf("Relay failed: %s", common.ContextError(err))
 		NoticeLocalProxyError(proxyType, err)
 	}
 	copyWaitGroup.Wait()
@@ -299,7 +222,7 @@ func ResolveIP(host string, conn net.Conn) (addrs []net.IP, ttls []time.Duration
 	// Process the response
 	response, err := dnsConn.ReadMsg()
 	if err != nil {
-		return nil, nil, ContextError(err)
+		return nil, nil, common.ContextError(err)
 	}
 	addrs = make([]net.IP, 0)
 	ttls = make([]time.Duration, 0)
@@ -334,7 +257,7 @@ func MakeUntunneledHttpsClient(
 
 	urlComponents, err := url.Parse(requestUrl)
 	if err != nil {
-		return nil, "", ContextError(err)
+		return nil, "", common.ContextError(err)
 	}
 
 	urlComponents.Scheme = "http"
@@ -396,13 +319,13 @@ func MakeTunneledHttpClient(
 
 	if config.UseTrustedCACertificatesForStockTLS {
 		if config.TrustedCACertificatesFilename == "" {
-			return nil, ContextError(errors.New(
+			return nil, common.ContextError(errors.New(
 				"UseTrustedCACertificatesForStockTLS requires TrustedCACertificatesFilename"))
 		}
 		rootCAs := x509.NewCertPool()
 		certData, err := ioutil.ReadFile(config.TrustedCACertificatesFilename)
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		rootCAs.AppendCertsFromPEM(certData)
 		transport.TLSClientConfig = &tls.Config{RootCAs: rootCAs}
@@ -431,13 +354,13 @@ func MakeDownloadHttpClient(
 	if tunnel != nil {
 		httpClient, err = MakeTunneledHttpClient(config, tunnel, requestTimeout)
 		if err != nil {
-			return nil, "", ContextError(err)
+			return nil, "", common.ContextError(err)
 		}
 	} else {
 		httpClient, requestUrl, err = MakeUntunneledHttpsClient(
 			untunneledDialConfig, nil, requestUrl, requestTimeout)
 		if err != nil {
-			return nil, "", ContextError(err)
+			return nil, "", common.ContextError(err)
 		}
 	}
 
@@ -470,13 +393,13 @@ func ResumeDownload(
 
 	file, err := os.OpenFile(partialFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
 	if err != nil {
-		return 0, "", ContextError(err)
+		return 0, "", common.ContextError(err)
 	}
 	defer file.Close()
 
 	fileInfo, err := file.Stat()
 	if err != nil {
-		return 0, "", ContextError(err)
+		return 0, "", common.ContextError(err)
 	}
 
 	// A partial download should have an ETag which is to be sent with the
@@ -494,14 +417,14 @@ func ResumeDownload(
 		if err != nil {
 			os.Remove(partialFilename)
 			os.Remove(partialETagFilename)
-			return 0, "", ContextError(
+			return 0, "", common.ContextError(
 				fmt.Errorf("failed to load partial download ETag: %s", err))
 		}
 	}
 
 	request, err := http.NewRequest("GET", requestUrl, nil)
 	if err != nil {
-		return 0, "", ContextError(err)
+		return 0, "", common.ContextError(err)
 	}
 
 	request.Header.Add("Range", fmt.Sprintf("bytes=%d-", fileInfo.Size()))
@@ -546,7 +469,7 @@ func ResumeDownload(
 		err = fmt.Errorf("unexpected response status code: %d", response.StatusCode)
 	}
 	if err != nil {
-		return 0, "", ContextError(err)
+		return 0, "", common.ContextError(err)
 	}
 	defer response.Body.Close()
 
@@ -557,7 +480,7 @@ func ResumeDownload(
 		// simply failing and relying on the caller's retry schedule.
 		os.Remove(partialFilename)
 		os.Remove(partialETagFilename)
-		return 0, "", ContextError(errors.New("partial download ETag mismatch"))
+		return 0, "", common.ContextError(errors.New("partial download ETag mismatch"))
 
 	} else if response.StatusCode == http.StatusNotModified {
 		// This status code is possible in the "If-None-Match" case. Don't leave
@@ -580,14 +503,14 @@ func ResumeDownload(
 	// an error; the caller may use this to report partial download progress.
 
 	if err != nil {
-		return n, "", ContextError(err)
+		return n, "", common.ContextError(err)
 	}
 
 	// Ensure the file is flushed to disk. The deferred close
 	// will be a noop when this succeeds.
 	err = file.Close()
 	if err != nil {
-		return n, "", ContextError(err)
+		return n, "", common.ContextError(err)
 	}
 
 	// Remove if exists, to enable rename
@@ -595,56 +518,10 @@ func ResumeDownload(
 
 	err = os.Rename(partialFilename, downloadFilename)
 	if err != nil {
-		return n, "", ContextError(err)
+		return n, "", common.ContextError(err)
 	}
 
 	os.Remove(partialETagFilename)
 
 	return n, responseETag, nil
 }
-
-// IPAddressFromAddr is a helper which extracts an IP address
-// from a net.Addr or returns "" if there is no IP address.
-func IPAddressFromAddr(addr net.Addr) string {
-	ipAddress := ""
-	if addr != nil {
-		host, _, err := net.SplitHostPort(addr.String())
-		if err == nil {
-			ipAddress = host
-		}
-	}
-	return ipAddress
-}
-
-// HTTPSServer is a wrapper around http.Server which adds the
-// ServeTLS function.
-type HTTPSServer struct {
-	http.Server
-}
-
-// ServeTLS is a offers the equivalent interface as http.Serve.
-// The http package has both ListenAndServe and ListenAndServeTLS higher-
-// level interfaces, but only Serve (not TLS) offers a lower-level interface that
-// allows the caller to keep a refererence to the Listener, allowing for external
-// shutdown. ListenAndServeTLS also requires the TLS cert and key to be in files
-// and we avoid that here.
-// tcpKeepAliveListener is used in http.ListenAndServeTLS but not exported,
-// so we use a copy from https://golang.org/src/net/http/server.go.
-func (server *HTTPSServer) ServeTLS(listener net.Listener) error {
-	tlsListener := tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, server.TLSConfig)
-	return server.Serve(tlsListener)
-}
-
-type tcpKeepAliveListener struct {
-	*net.TCPListener
-}
-
-func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
-	tc, err := ln.AcceptTCP()
-	if err != nil {
-		return
-	}
-	tc.SetKeepAlive(true)
-	tc.SetKeepAlivePeriod(3 * time.Minute)
-	return tc, nil
-}

+ 5 - 3
psiphon/networkInterface.go

@@ -22,6 +22,8 @@ package psiphon
 import (
 	"errors"
 	"net"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // Take in an interface name ("lo", "eth0", "any") passed from either
@@ -41,12 +43,12 @@ func GetInterfaceIPAddress(listenInterface string) (string, error) {
 	} else {
 		availableInterfaces, err := net.InterfaceByName(listenInterface)
 		if err != nil {
-			return "", ContextError(err)
+			return "", common.ContextError(err)
 		}
 
 		addrs, err := availableInterfaces.Addrs()
 		if err != nil {
-			return "", ContextError(err)
+			return "", common.ContextError(err)
 		}
 		for _, addr := range addrs {
 			iptype := addr.(*net.IPNet)
@@ -62,6 +64,6 @@ func GetInterfaceIPAddress(listenInterface string) (string, error) {
 		}
 	}
 
-	return "", ContextError(errors.New("Could not find IP address of specified interface"))
+	return "", common.ContextError(errors.New("Could not find IP address of specified interface"))
 
 }

+ 4 - 2
psiphon/notice.go

@@ -31,6 +31,8 @@ import (
 	"sync"
 	"sync/atomic"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 var noticeLoggerMutex sync.Mutex
@@ -110,7 +112,7 @@ func outputNotice(noticeType string, noticeFlags uint32, args ...interface{}) {
 	if err == nil {
 		output = string(encodedJson)
 	} else {
-		output = fmt.Sprintf("{\"Alert\":{\"message\":\"%s\"}}", ContextError(err))
+		output = fmt.Sprintf("{\"Alert\":{\"message\":\"%s\"}}", common.ContextError(err))
 	}
 	noticeLoggerMutex.Lock()
 	defer noticeLoggerMutex.Unlock()
@@ -309,7 +311,7 @@ func NoticeLocalProxyError(proxyType string, err error) {
 	// the root error that repeats (the full error often contains
 	// different specific values, e.g., local port numbers, but
 	// the same repeating root).
-	// Assumes error format of ContextError.
+	// Assumes error format of common.ContextError.
 	repetitionMessage := err.Error()
 	index := strings.LastIndex(repetitionMessage, ": ")
 	if index != -1 {

+ 26 - 24
psiphon/obfuscatedSshConn.go

@@ -25,6 +25,8 @@ import (
 	"errors"
 	"io"
 	"net"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 const (
@@ -115,7 +117,7 @@ func NewObfuscatedSshConn(
 	if mode == OBFUSCATION_CONN_MODE_CLIENT {
 		obfuscator, err = NewClientObfuscator(&ObfuscatorConfig{Keyword: obfuscationKeyword})
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		readDeobfuscate = obfuscator.ObfuscateServerToClient
 		writeObfuscate = obfuscator.ObfuscateClientToServer
@@ -126,7 +128,7 @@ func NewObfuscatedSshConn(
 			conn, &ObfuscatorConfig{Keyword: obfuscationKeyword})
 		if err != nil {
 			// TODO: readForver() equivilent
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		readDeobfuscate = obfuscator.ObfuscateClientToServer
 		writeObfuscate = obfuscator.ObfuscateServerToClient
@@ -161,7 +163,7 @@ func (conn *ObfuscatedSshConn) Write(buffer []byte) (n int, err error) {
 	}
 	err = conn.transformAndWrite(buffer)
 	if err != nil {
-		return 0, ContextError(err)
+		return 0, common.ContextError(err)
 	}
 	// Reports that we wrote all the bytes
 	// (althogh we may have buffered some or all)
@@ -218,7 +220,7 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 				conn.readBuffer, err = readSshIdentificationLine(
 					conn.Conn, conn.readDeobfuscate)
 				if err != nil {
-					return 0, ContextError(err)
+					return 0, common.ContextError(err)
 				}
 				if bytes.HasPrefix(conn.readBuffer, []byte("SSH-")) {
 					break
@@ -235,7 +237,7 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 			conn.readBuffer, isMsgNewKeys, err = readSshPacket(
 				conn.Conn, conn.readDeobfuscate)
 			if err != nil {
-				return 0, ContextError(err)
+				return 0, common.ContextError(err)
 			}
 
 			if isMsgNewKeys {
@@ -247,7 +249,7 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 		nextState = OBFUSCATION_READ_STATE_FINISHED
 
 	case OBFUSCATION_READ_STATE_FINISHED:
-		return 0, ContextError(errors.New("invalid read state"))
+		return 0, common.ContextError(errors.New("invalid read state"))
 	}
 
 	n = copy(buffer, conn.readBuffer)
@@ -306,18 +308,18 @@ func (conn *ObfuscatedSshConn) transformAndWrite(buffer []byte) (err error) {
 	if conn.writeState == OBFUSCATION_WRITE_STATE_CLIENT_SEND_SEED_MESSAGE {
 		_, err = conn.Conn.Write(conn.obfuscator.SendSeedMessage())
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 		conn.writeState = OBFUSCATION_WRITE_STATE_IDENTIFICATION_LINE
 	} else if conn.writeState == OBFUSCATION_WRITE_STATE_SERVER_SEND_IDENTIFICATION_LINE_PADDING {
 		padding, err := makeServerIdentificationLinePadding()
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 		conn.writeObfuscate(padding)
 		_, err = conn.Conn.Write(padding)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 		conn.writeState = OBFUSCATION_WRITE_STATE_IDENTIFICATION_LINE
 	}
@@ -336,21 +338,21 @@ func (conn *ObfuscatedSshConn) transformAndWrite(buffer []byte) (err error) {
 		var hasMsgNewKeys bool
 		conn.writeBuffer, sendBuffer, hasMsgNewKeys, err = extractSshPackets(conn.writeBuffer)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 		if hasMsgNewKeys {
 			conn.writeState = OBFUSCATION_WRITE_STATE_FINISHED
 		}
 
 	case OBFUSCATION_WRITE_STATE_FINISHED:
-		return ContextError(errors.New("invalid write state"))
+		return common.ContextError(errors.New("invalid write state"))
 	}
 
 	if sendBuffer != nil {
 		conn.writeObfuscate(sendBuffer)
 		_, err := conn.Conn.Write(sendBuffer)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 	}
 
@@ -358,7 +360,7 @@ func (conn *ObfuscatedSshConn) transformAndWrite(buffer []byte) (err error) {
 		// After SSH_MSG_NEWKEYS, any remaining bytes are un-obfuscated
 		_, err := conn.Conn.Write(conn.writeBuffer)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 		// The buffer memory is no longer used
 		conn.writeBuffer = nil
@@ -376,7 +378,7 @@ func readSshIdentificationLine(
 	for len(readBuffer) < SSH_MAX_SERVER_LINE_LENGTH {
 		_, err := io.ReadFull(conn, oneByte[:])
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		deobfuscate(oneByte[:])
 		readBuffer = append(readBuffer, oneByte[0])
@@ -386,7 +388,7 @@ func readSshIdentificationLine(
 		}
 	}
 	if !validLine {
-		return nil, ContextError(errors.New("invalid identification line"))
+		return nil, common.ContextError(errors.New("invalid identification line"))
 	}
 	return readBuffer, nil
 }
@@ -397,18 +399,18 @@ func readSshPacket(
 	prefix := make([]byte, SSH_PACKET_PREFIX_LENGTH)
 	_, err := io.ReadFull(conn, prefix)
 	if err != nil {
-		return nil, false, ContextError(err)
+		return nil, false, common.ContextError(err)
 	}
 	deobfuscate(prefix)
 	packetLength, _, payloadLength, messageLength := getSshPacketPrefix(prefix)
 	if packetLength > SSH_MAX_PACKET_LENGTH {
-		return nil, false, ContextError(errors.New("ssh packet length too large"))
+		return nil, false, common.ContextError(errors.New("ssh packet length too large"))
 	}
 	readBuffer := make([]byte, messageLength)
 	copy(readBuffer, prefix)
 	_, err = io.ReadFull(conn, readBuffer[len(prefix):])
 	if err != nil {
-		return nil, false, ContextError(err)
+		return nil, false, common.ContextError(err)
 	}
 	deobfuscate(readBuffer[len(prefix):])
 	isMsgNewKeys := false
@@ -424,9 +426,9 @@ func readSshPacket(
 // From the original patch to sshd.c:
 // https://bitbucket.org/psiphon/psiphon-circumvention-system/commits/f40865ce624b680be840dc2432283c8137bd896d
 func makeServerIdentificationLinePadding() ([]byte, error) {
-	paddingLength, err := MakeSecureRandomInt(OBFUSCATE_MAX_PADDING - 2) // 2 = CRLF
+	paddingLength, err := common.MakeSecureRandomInt(OBFUSCATE_MAX_PADDING - 2) // 2 = CRLF
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	paddingLength += 2
 	padding := make([]byte, paddingLength)
@@ -490,14 +492,14 @@ func extractSshPackets(writeBuffer []byte) ([]byte, []byte, bool, error) {
 		possiblePaddings := (SSH_MAX_PADDING_LENGTH - paddingLength) / SSH_PADDING_MULTIPLE
 		if possiblePaddings > 0 {
 			// selectedPadding is integer in range [0, possiblePaddings)
-			selectedPadding, err := MakeSecureRandomInt(possiblePaddings)
+			selectedPadding, err := common.MakeSecureRandomInt(possiblePaddings)
 			if err != nil {
-				return nil, nil, false, ContextError(err)
+				return nil, nil, false, common.ContextError(err)
 			}
 			extraPaddingLength := selectedPadding * SSH_PADDING_MULTIPLE
-			extraPadding, err := MakeSecureRandomBytes(extraPaddingLength)
+			extraPadding, err := common.MakeSecureRandomBytes(extraPaddingLength)
 			if err != nil {
-				return nil, nil, false, ContextError(err)
+				return nil, nil, false, common.ContextError(err)
 			}
 			setSshPacketPrefix(
 				packetBuffer, packetLength+extraPaddingLength, paddingLength+extraPaddingLength)

+ 28 - 26
psiphon/obfuscator.go

@@ -26,6 +26,8 @@ import (
 	"encoding/binary"
 	"errors"
 	"io"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 const (
@@ -58,14 +60,14 @@ type ObfuscatorConfig struct {
 func NewClientObfuscator(
 	config *ObfuscatorConfig) (obfuscator *Obfuscator, err error) {
 
-	seed, err := MakeSecureRandomBytes(OBFUSCATE_SEED_LENGTH)
+	seed, err := common.MakeSecureRandomBytes(OBFUSCATE_SEED_LENGTH)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	clientToServerCipher, serverToClientCipher, err := initObfuscatorCiphers(seed, config)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	maxPadding := OBFUSCATE_MAX_PADDING
@@ -75,7 +77,7 @@ func NewClientObfuscator(
 
 	seedMessage, err := makeSeedMessage(maxPadding, seed, clientToServerCipher)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return &Obfuscator{
@@ -92,7 +94,7 @@ func NewServerObfuscator(
 	clientToServerCipher, serverToClientCipher, err := readSeedMessage(
 		clientReader, config)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return &Obfuscator{
@@ -123,22 +125,22 @@ func initObfuscatorCiphers(
 
 	clientToServerKey, err := deriveKey(seed, []byte(config.Keyword), []byte(OBFUSCATE_CLIENT_TO_SERVER_IV))
 	if err != nil {
-		return nil, nil, ContextError(err)
+		return nil, nil, common.ContextError(err)
 	}
 
 	serverToClientKey, err := deriveKey(seed, []byte(config.Keyword), []byte(OBFUSCATE_SERVER_TO_CLIENT_IV))
 	if err != nil {
-		return nil, nil, ContextError(err)
+		return nil, nil, common.ContextError(err)
 	}
 
 	clientToServerCipher, err := rc4.NewCipher(clientToServerKey)
 	if err != nil {
-		return nil, nil, ContextError(err)
+		return nil, nil, common.ContextError(err)
 	}
 
 	serverToClientCipher, err := rc4.NewCipher(serverToClientKey)
 	if err != nil {
-		return nil, nil, ContextError(err)
+		return nil, nil, common.ContextError(err)
 	}
 
 	return clientToServerCipher, serverToClientCipher, nil
@@ -156,37 +158,37 @@ func deriveKey(seed, keyword, iv []byte) ([]byte, error) {
 		digest = h.Sum(nil)
 	}
 	if len(digest) < OBFUSCATE_KEY_LENGTH {
-		return nil, ContextError(errors.New("insufficient bytes for obfuscation key"))
+		return nil, common.ContextError(errors.New("insufficient bytes for obfuscation key"))
 	}
 	return digest[0:OBFUSCATE_KEY_LENGTH], nil
 }
 
 func makeSeedMessage(maxPadding int, seed []byte, clientToServerCipher *rc4.Cipher) ([]byte, error) {
 	// paddingLength is integer in range [0, maxPadding]
-	paddingLength, err := MakeSecureRandomInt(maxPadding + 1)
+	paddingLength, err := common.MakeSecureRandomInt(maxPadding + 1)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
-	padding, err := MakeSecureRandomBytes(paddingLength)
+	padding, err := common.MakeSecureRandomBytes(paddingLength)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	buffer := new(bytes.Buffer)
 	err = binary.Write(buffer, binary.BigEndian, seed)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	err = binary.Write(buffer, binary.BigEndian, uint32(OBFUSCATE_MAGIC_VALUE))
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	err = binary.Write(buffer, binary.BigEndian, uint32(paddingLength))
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	err = binary.Write(buffer, binary.BigEndian, padding)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	seedMessage := buffer.Bytes()
 	clientToServerCipher.XORKeyStream(seedMessage[len(seed):], seedMessage[len(seed):])
@@ -199,18 +201,18 @@ func readSeedMessage(
 	seed := make([]byte, OBFUSCATE_SEED_LENGTH)
 	_, err := io.ReadFull(clientReader, seed)
 	if err != nil {
-		return nil, nil, ContextError(err)
+		return nil, nil, common.ContextError(err)
 	}
 
 	clientToServerCipher, serverToClientCipher, err := initObfuscatorCiphers(seed, config)
 	if err != nil {
-		return nil, nil, ContextError(err)
+		return nil, nil, common.ContextError(err)
 	}
 
 	fixedLengthFields := make([]byte, 8) // 4 bytes each for magic value and padding length
 	_, err = io.ReadFull(clientReader, fixedLengthFields)
 	if err != nil {
-		return nil, nil, ContextError(err)
+		return nil, nil, common.ContextError(err)
 	}
 
 	clientToServerCipher.XORKeyStream(fixedLengthFields, fixedLengthFields)
@@ -220,25 +222,25 @@ func readSeedMessage(
 	var magicValue, paddingLength int32
 	err = binary.Read(buffer, binary.BigEndian, &magicValue)
 	if err != nil {
-		return nil, nil, ContextError(err)
+		return nil, nil, common.ContextError(err)
 	}
 	err = binary.Read(buffer, binary.BigEndian, &paddingLength)
 	if err != nil {
-		return nil, nil, ContextError(err)
+		return nil, nil, common.ContextError(err)
 	}
 
 	if magicValue != OBFUSCATE_MAGIC_VALUE {
-		return nil, nil, ContextError(errors.New("invalid magic value"))
+		return nil, nil, common.ContextError(errors.New("invalid magic value"))
 	}
 
 	if paddingLength < 0 || paddingLength > OBFUSCATE_MAX_PADDING {
-		return nil, nil, ContextError(errors.New("invalid padding length"))
+		return nil, nil, common.ContextError(errors.New("invalid padding length"))
 	}
 
 	padding := make([]byte, paddingLength)
 	_, err = io.ReadFull(clientReader, padding)
 	if err != nil {
-		return nil, nil, ContextError(err)
+		return nil, nil, common.ContextError(err)
 	}
 
 	clientToServerCipher.XORKeyStream(padding, padding)

+ 5 - 4
psiphon/opensslConn.go

@@ -27,6 +27,7 @@ import (
 	"strings"
 
 	"github.com/Psiphon-Inc/openssl"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // newOpenSSLConn wraps a connection with TLS which mimicks stock Android TLS.
@@ -37,7 +38,7 @@ func newOpenSSLConn(rawConn net.Conn, hostname string, config *CustomTLSConfig)
 
 	ctx, err := openssl.NewCtx()
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	if !config.SkipVerify {
@@ -51,7 +52,7 @@ func newOpenSSLConn(rawConn net.Conn, hostname string, config *CustomTLSConfig)
 			}
 			err = ctx.LoadVerifyLocations(config.TrustedCACertificatesFilename, "")
 			if err != nil {
-				return nil, ContextError(err)
+				return nil, common.ContextError(err)
 			}
 		}
 	} else {
@@ -93,7 +94,7 @@ func newOpenSSLConn(rawConn net.Conn, hostname string, config *CustomTLSConfig)
 
 	conn, err := openssl.Client(rawConn, ctx)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	if config.SNIServerName != "" {
@@ -103,7 +104,7 @@ func newOpenSSLConn(rawConn net.Conn, hostname string, config *CustomTLSConfig)
 		if net.ParseIP(config.SNIServerName) == nil {
 			err = conn.SetTlsExtHostName(config.SNIServerName)
 			if err != nil {
-				return nil, ContextError(err)
+				return nil, common.ContextError(err)
 			}
 		}
 	}

+ 3 - 1
psiphon/opensslConn_unsupported.go

@@ -24,9 +24,11 @@ package psiphon
 import (
 	"errors"
 	"net"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // newOpenSSLConn simply returns an error when used on an unsupported platform.
 func newOpenSSLConn(rawConn net.Conn, hostname string, config *CustomTLSConfig) (handshakeConn, error) {
-	return nil, ContextError(errors.New("newOpenSSLConn not supported on this platform"))
+	return nil, common.ContextError(errors.New("newOpenSSLConn not supported on this platform"))
 }

+ 8 - 6
psiphon/package.go

@@ -27,6 +27,8 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"errors"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // AuthenticatedDataPackage is a JSON record containing some Psiphon data
@@ -45,24 +47,24 @@ func ReadAuthenticatedDataPackage(
 	var authenticatedDataPackage *AuthenticatedDataPackage
 	err = json.Unmarshal(rawPackage, &authenticatedDataPackage)
 	if err != nil {
-		return "", ContextError(err)
+		return "", common.ContextError(err)
 	}
 
 	derEncodedPublicKey, err := base64.StdEncoding.DecodeString(signingPublicKey)
 	if err != nil {
-		return "", ContextError(err)
+		return "", common.ContextError(err)
 	}
 	publicKey, err := x509.ParsePKIXPublicKey(derEncodedPublicKey)
 	if err != nil {
-		return "", ContextError(err)
+		return "", common.ContextError(err)
 	}
 	rsaPublicKey, ok := publicKey.(*rsa.PublicKey)
 	if !ok {
-		return "", ContextError(errors.New("unexpected signing public key type"))
+		return "", common.ContextError(errors.New("unexpected signing public key type"))
 	}
 	signature, err := base64.StdEncoding.DecodeString(authenticatedDataPackage.Signature)
 	if err != nil {
-		return "", ContextError(err)
+		return "", common.ContextError(err)
 	}
 	// TODO: can distinguish signed-with-different-key from other errors:
 	// match digest(publicKey) against authenticatedDataPackage.SigningPublicKeyDigest
@@ -71,7 +73,7 @@ func ReadAuthenticatedDataPackage(
 	digest := hash.Sum(nil)
 	err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, digest, signature)
 	if err != nil {
-		return "", ContextError(err)
+		return "", common.ContextError(err)
 	}
 
 	return authenticatedDataPackage.Data, nil

+ 14 - 12
psiphon/remoteServerList.go

@@ -25,6 +25,8 @@ import (
 	"os"
 	"strings"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // FetchRemoteServerList downloads a remote server list JSON record from
@@ -47,7 +49,7 @@ func FetchRemoteServerList(
 		config.RemoteServerListUrl,
 		time.Duration(*config.FetchRemoteServerListTimeoutSeconds)*time.Second)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	// Proceed with download
@@ -60,7 +62,7 @@ func FetchRemoteServerList(
 
 	lastETag, err := GetUrlETag(config.RemoteServerListUrl)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	n, responseETag, err := ResumeDownload(
@@ -69,7 +71,7 @@ func FetchRemoteServerList(
 	NoticeRemoteServerListDownloadedBytes(n)
 
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	if responseETag == lastETag {
@@ -84,38 +86,38 @@ func FetchRemoteServerList(
 
 	downloadContent, err := os.Open(downloadFilename)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	defer downloadContent.Close()
 
 	zlibReader, err := zlib.NewReader(downloadContent)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	dataPackage, err := ioutil.ReadAll(zlibReader)
 	zlibReader.Close()
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	remoteServerList, err := ReadAuthenticatedDataPackage(
 		dataPackage, config.RemoteServerListSignaturePublicKey)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	serverEntries, err := DecodeAndValidateServerEntryList(
 		remoteServerList,
-		GetCurrentTimestamp(),
-		SERVER_ENTRY_SOURCE_REMOTE)
+		common.GetCurrentTimestamp(),
+		common.SERVER_ENTRY_SOURCE_REMOTE)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	err = StoreServerEntries(serverEntries, true)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	// Now that the server entries are successfully imported, store the response
@@ -124,7 +126,7 @@ func FetchRemoteServerList(
 	if responseETag != "" {
 		err := SetUrlETag(config.RemoteServerListUrl, responseETag)
 		if err != nil {
-			NoticeAlert("failed to set remote server list ETag: %s", ContextError(err))
+			NoticeAlert("failed to set remote server list ETag: %s", common.ContextError(err))
 			// This fetch is still reported as a success, even if we can't store the etag
 		}
 	}

+ 48 - 48
psiphon/server/api.go

@@ -29,7 +29,7 @@ import (
 	"strings"
 	"unicode"
 
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 const (
@@ -67,22 +67,22 @@ func sshAPIRequestHandler(
 	var params requestJSONObject
 	err := json.Unmarshal(requestPayload, &params)
 	if err != nil {
-		return nil, psiphon.ContextError(
+		return nil, common.ContextError(
 			fmt.Errorf("invalid payload for request name: %s: %s", name, err))
 	}
 
 	switch name {
-	case psiphon.SERVER_API_HANDSHAKE_REQUEST_NAME:
+	case common.PSIPHON_API_HANDSHAKE_REQUEST_NAME:
 		return handshakeAPIRequestHandler(support, geoIPData, params)
-	case psiphon.SERVER_API_CONNECTED_REQUEST_NAME:
+	case common.PSIPHON_API_CONNECTED_REQUEST_NAME:
 		return connectedAPIRequestHandler(support, geoIPData, params)
-	case psiphon.SERVER_API_STATUS_REQUEST_NAME:
+	case common.PSIPHON_API_STATUS_REQUEST_NAME:
 		return statusAPIRequestHandler(support, geoIPData, params)
-	case psiphon.SERVER_API_CLIENT_VERIFICATION_REQUEST_NAME:
+	case common.PSIPHON_API_CLIENT_VERIFICATION_REQUEST_NAME:
 		return clientVerificationAPIRequestHandler(support, geoIPData, params)
 	}
 
-	return nil, psiphon.ContextError(fmt.Errorf("invalid request name: %s", name))
+	return nil, common.ContextError(fmt.Errorf("invalid request name: %s", name))
 }
 
 // handshakeAPIRequestHandler implements the "handshake" API request.
@@ -98,7 +98,7 @@ func handshakeAPIRequestHandler(
 
 	err := validateRequestParams(support, params, baseRequestParams)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	log.WithContextFields(
@@ -146,7 +146,7 @@ func handshakeAPIRequestHandler(
 
 	handshakeResponse.ClientRegion = clientRegion
 
-	handshakeResponse.ServerTimestamp = psiphon.GetCurrentTimestamp()
+	handshakeResponse.ServerTimestamp = common.GetCurrentTimestamp()
 
 	handshakeResponse.ClientVerificationRequired = CLIENT_VERIFICATION_REQUIRED
 	handshakeResponse.ClientVerificationServerNonce = ""
@@ -155,7 +155,7 @@ func handshakeAPIRequestHandler(
 
 	responsePayload, err := json.Marshal(handshakeResponse)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return responsePayload, nil
@@ -179,7 +179,7 @@ func connectedAPIRequestHandler(
 
 	err := validateRequestParams(support, params, connectedRequestParams)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	log.WithContextFields(
@@ -195,11 +195,11 @@ func connectedAPIRequestHandler(
 	}
 
 	connectedResponse.ConnectedTimestamp =
-		psiphon.TruncateTimestampToHour(psiphon.GetCurrentTimestamp())
+		common.TruncateTimestampToHour(common.GetCurrentTimestamp())
 
 	responsePayload, err := json.Marshal(connectedResponse)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return responsePayload, nil
@@ -224,19 +224,19 @@ func statusAPIRequestHandler(
 
 	err := validateRequestParams(support, params, statusRequestParams)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	statusData, err := getJSONObjectRequestParam(params, "statusData")
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// Overall bytes transferred stats
 
 	bytesTransferred, err := getInt64RequestParam(statusData, "bytes_transferred")
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	bytesTransferredFields := getRequestLogFields(
 		support, "bytes_transferred", geoIPData, params, statusRequestParams)
@@ -250,7 +250,7 @@ func statusAPIRequestHandler(
 
 		hostBytes, err := getMapStringInt64RequestParam(statusData, "host_bytes")
 		if err != nil {
-			return nil, psiphon.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		domainBytesFields := getRequestLogFields(
 			support, "domain_bytes", geoIPData, params, statusRequestParams)
@@ -268,7 +268,7 @@ func statusAPIRequestHandler(
 
 		tunnelStats, err := getJSONObjectArrayRequestParam(statusData, "tunnel_stats")
 		if err != nil {
-			return nil, psiphon.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		sessionFields := getRequestLogFields(
 			support, "session", geoIPData, params, statusRequestParams)
@@ -276,48 +276,48 @@ func statusAPIRequestHandler(
 
 			sessionID, err := getStringRequestParam(tunnelStat, "session_id")
 			if err != nil {
-				return nil, psiphon.ContextError(err)
+				return nil, common.ContextError(err)
 			}
 			sessionFields["session_id"] = sessionID
 
 			tunnelNumber, err := getInt64RequestParam(tunnelStat, "tunnel_number")
 			if err != nil {
-				return nil, psiphon.ContextError(err)
+				return nil, common.ContextError(err)
 			}
 			sessionFields["tunnel_number"] = tunnelNumber
 
 			tunnelServerIPAddress, err := getStringRequestParam(tunnelStat, "tunnel_server_ip_address")
 			if err != nil {
-				return nil, psiphon.ContextError(err)
+				return nil, common.ContextError(err)
 			}
 			sessionFields["tunnel_server_ip_address"] = tunnelServerIPAddress
 
 			serverHandshakeTimestamp, err := getStringRequestParam(tunnelStat, "server_handshake_timestamp")
 			if err != nil {
-				return nil, psiphon.ContextError(err)
+				return nil, common.ContextError(err)
 			}
 			sessionFields["server_handshake_timestamp"] = serverHandshakeTimestamp
 
 			strDuration, err := getStringRequestParam(tunnelStat, "duration")
 			if err != nil {
-				return nil, psiphon.ContextError(err)
+				return nil, common.ContextError(err)
 			}
 			duration, err := strconv.ParseInt(strDuration, 10, 64)
 			if err != nil {
-				return nil, psiphon.ContextError(err)
+				return nil, common.ContextError(err)
 			}
 			// Client reports durations in nanoseconds; divide to get to milliseconds
 			sessionFields["duration"] = duration / 1000000
 
 			totalBytesSent, err := getInt64RequestParam(tunnelStat, "total_bytes_sent")
 			if err != nil {
-				return nil, psiphon.ContextError(err)
+				return nil, common.ContextError(err)
 			}
 			sessionFields["total_bytes_sent"] = totalBytesSent
 
 			totalBytesReceived, err := getInt64RequestParam(tunnelStat, "total_bytes_received")
 			if err != nil {
-				return nil, psiphon.ContextError(err)
+				return nil, common.ContextError(err)
 			}
 			sessionFields["total_bytes_received"] = totalBytesReceived
 
@@ -339,7 +339,7 @@ func clientVerificationAPIRequestHandler(
 
 	err := validateRequestParams(support, params, baseRequestParams)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// Ignoring error as params are validated
@@ -347,7 +347,7 @@ func clientVerificationAPIRequestHandler(
 
 	verificationData, err := getJSONObjectRequestParam(params, "verificationData")
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	logFields := getRequestLogFields(
@@ -424,7 +424,7 @@ func validateRequestParams(
 			if expectedParam.flags&requestParamOptional != 0 {
 				continue
 			}
-			return psiphon.ContextError(
+			return common.ContextError(
 				fmt.Errorf("missing param: %s", expectedParam.name))
 		}
 		var err error
@@ -434,7 +434,7 @@ func validateRequestParams(
 			err = validateStringRequestParam(support, expectedParam, value)
 		}
 		if err != nil {
-			return psiphon.ContextError(err)
+			return common.ContextError(err)
 		}
 	}
 
@@ -448,11 +448,11 @@ func validateStringRequestParam(
 
 	strValue, ok := value.(string)
 	if !ok {
-		return psiphon.ContextError(
+		return common.ContextError(
 			fmt.Errorf("unexpected string param type: %s", expectedParam.name))
 	}
 	if !expectedParam.validator(support, strValue) {
-		return psiphon.ContextError(
+		return common.ContextError(
 			fmt.Errorf("invalid param: %s", expectedParam.name))
 	}
 	return nil
@@ -465,13 +465,13 @@ func validateStringArrayRequestParam(
 
 	arrayValue, ok := value.([]interface{})
 	if !ok {
-		return psiphon.ContextError(
+		return common.ContextError(
 			fmt.Errorf("unexpected string param type: %s", expectedParam.name))
 	}
 	for _, value := range arrayValue {
 		err := validateStringRequestParam(support, expectedParam, value)
 		if err != nil {
-			return psiphon.ContextError(err)
+			return common.ContextError(err)
 		}
 	}
 	return nil
@@ -563,45 +563,45 @@ func getRequestLogFields(
 
 func getStringRequestParam(params requestJSONObject, name string) (string, error) {
 	if params[name] == nil {
-		return "", psiphon.ContextError(fmt.Errorf("missing param: %s", name))
+		return "", common.ContextError(fmt.Errorf("missing param: %s", name))
 	}
 	value, ok := params[name].(string)
 	if !ok {
-		return "", psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
+		return "", common.ContextError(fmt.Errorf("invalid param: %s", name))
 	}
 	return value, nil
 }
 
 func getInt64RequestParam(params requestJSONObject, name string) (int64, error) {
 	if params[name] == nil {
-		return 0, psiphon.ContextError(fmt.Errorf("missing param: %s", name))
+		return 0, common.ContextError(fmt.Errorf("missing param: %s", name))
 	}
 	value, ok := params[name].(float64)
 	if !ok {
-		return 0, psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
+		return 0, common.ContextError(fmt.Errorf("invalid param: %s", name))
 	}
 	return int64(value), nil
 }
 
 func getJSONObjectRequestParam(params requestJSONObject, name string) (requestJSONObject, error) {
 	if params[name] == nil {
-		return nil, psiphon.ContextError(fmt.Errorf("missing param: %s", name))
+		return nil, common.ContextError(fmt.Errorf("missing param: %s", name))
 	}
 	// TODO: can't use requestJSONObject type?
 	value, ok := params[name].(map[string]interface{})
 	if !ok {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
+		return nil, common.ContextError(fmt.Errorf("invalid param: %s", name))
 	}
 	return requestJSONObject(value), nil
 }
 
 func getJSONObjectArrayRequestParam(params requestJSONObject, name string) ([]requestJSONObject, error) {
 	if params[name] == nil {
-		return nil, psiphon.ContextError(fmt.Errorf("missing param: %s", name))
+		return nil, common.ContextError(fmt.Errorf("missing param: %s", name))
 	}
 	value, ok := params[name].([]interface{})
 	if !ok {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
+		return nil, common.ContextError(fmt.Errorf("invalid param: %s", name))
 	}
 
 	result := make([]requestJSONObject, len(value))
@@ -609,7 +609,7 @@ func getJSONObjectArrayRequestParam(params requestJSONObject, name string) ([]re
 		// TODO: can't use requestJSONObject type?
 		resultItem, ok := item.(map[string]interface{})
 		if !ok {
-			return nil, psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
+			return nil, common.ContextError(fmt.Errorf("invalid param: %s", name))
 		}
 		result[i] = requestJSONObject(resultItem)
 	}
@@ -619,19 +619,19 @@ func getJSONObjectArrayRequestParam(params requestJSONObject, name string) ([]re
 
 func getMapStringInt64RequestParam(params requestJSONObject, name string) (map[string]int64, error) {
 	if params[name] == nil {
-		return nil, psiphon.ContextError(fmt.Errorf("missing param: %s", name))
+		return nil, common.ContextError(fmt.Errorf("missing param: %s", name))
 	}
 	// TODO: can't use requestJSONObject type?
 	value, ok := params[name].(map[string]interface{})
 	if !ok {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
+		return nil, common.ContextError(fmt.Errorf("invalid param: %s", name))
 	}
 
 	result := make(map[string]int64)
 	for k, v := range value {
 		numValue, ok := v.(float64)
 		if !ok {
-			return nil, psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
+			return nil, common.ContextError(fmt.Errorf("invalid param: %s", name))
 		}
 		result[k] = int64(numValue)
 	}
@@ -687,7 +687,7 @@ func isClientPlatform(_ *SupportServices, value string) bool {
 }
 
 func isRelayProtocol(_ *SupportServices, value string) bool {
-	return psiphon.Contains(psiphon.SupportedTunnelProtocols, value)
+	return common.Contains(common.SupportedTunnelProtocols, value)
 }
 
 func isBooleanFlag(_ *SupportServices, value string) bool {
@@ -769,7 +769,7 @@ func isHostHeader(support *SupportServices, value string) bool {
 }
 
 func isServerEntrySource(_ *SupportServices, value string) bool {
-	return psiphon.Contains(psiphon.SupportedServerEntrySources, value)
+	return common.Contains(common.SupportedServerEntrySources, value)
 }
 
 var isISO8601DateRegex = regexp.MustCompile(

+ 37 - 36
psiphon/server/config.go

@@ -33,6 +33,7 @@ import (
 	"strings"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"golang.org/x/crypto/nacl/box"
 	"golang.org/x/crypto/ssh"
 )
@@ -208,7 +209,7 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 	var config Config
 	err := json.Unmarshal(configJSON, &config)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	if config.ServerIPAddress == "" {
@@ -223,8 +224,8 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 	}
 
 	for tunnelProtocol, _ := range config.TunnelProtocolPorts {
-		if psiphon.TunnelProtocolUsesSSH(tunnelProtocol) ||
-			psiphon.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) {
+		if common.TunnelProtocolUsesSSH(tunnelProtocol) ||
+			common.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) {
 			if config.SSHPrivateKey == "" || config.SSHServerVersion == "" ||
 				config.SSHUserName == "" || config.SSHPassword == "" {
 				return nil, fmt.Errorf(
@@ -232,22 +233,22 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 					tunnelProtocol)
 			}
 		}
-		if psiphon.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) {
+		if common.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) {
 			if config.ObfuscatedSSHKey == "" {
 				return nil, fmt.Errorf(
 					"Tunnel protocol %s requires ObfuscatedSSHKey",
 					tunnelProtocol)
 			}
 		}
-		if psiphon.TunnelProtocolUsesMeekHTTP(tunnelProtocol) ||
-			psiphon.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
+		if common.TunnelProtocolUsesMeekHTTP(tunnelProtocol) ||
+			common.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
 			if config.MeekCookieEncryptionPrivateKey == "" || config.MeekObfuscatedKey == "" {
 				return nil, fmt.Errorf(
 					"Tunnel protocol %s requires MeekCookieEncryptionPrivateKey, MeekObfuscatedKey",
 					tunnelProtocol)
 			}
 		}
-		if psiphon.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
+		if common.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
 			if config.MeekCertificateCommonName == "" {
 				return nil, fmt.Errorf(
 					"Tunnel protocol %s requires MeekCertificateCommonName",
@@ -305,11 +306,11 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 	// Input validation
 
 	if net.ParseIP(params.ServerIPAddress) == nil {
-		return nil, nil, nil, psiphon.ContextError(errors.New("invalid IP address"))
+		return nil, nil, nil, common.ContextError(errors.New("invalid IP address"))
 	}
 
 	if len(params.TunnelProtocolPorts) == 0 {
-		return nil, nil, nil, psiphon.ContextError(errors.New("no tunnel protocols"))
+		return nil, nil, nil, common.ContextError(errors.New("no tunnel protocols"))
 	}
 
 	usedPort := make(map[int]bool)
@@ -321,17 +322,17 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 
 	for protocol, port := range params.TunnelProtocolPorts {
 
-		if !psiphon.Contains(psiphon.SupportedTunnelProtocols, protocol) {
-			return nil, nil, nil, psiphon.ContextError(errors.New("invalid tunnel protocol"))
+		if !common.Contains(common.SupportedTunnelProtocols, protocol) {
+			return nil, nil, nil, common.ContextError(errors.New("invalid tunnel protocol"))
 		}
 
 		if usedPort[port] {
-			return nil, nil, nil, psiphon.ContextError(errors.New("duplicate listening port"))
+			return nil, nil, nil, common.ContextError(errors.New("duplicate listening port"))
 		}
 		usedPort[port] = true
 
-		if psiphon.TunnelProtocolUsesMeekHTTP(protocol) ||
-			psiphon.TunnelProtocolUsesMeekHTTPS(protocol) {
+		if common.TunnelProtocolUsesMeekHTTP(protocol) ||
+			common.TunnelProtocolUsesMeekHTTPS(protocol) {
 			usingMeek = true
 		}
 	}
@@ -342,14 +343,14 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 
 	if params.WebServerPort != 0 {
 		var err error
-		webServerSecret, err = psiphon.MakeRandomStringHex(WEB_SERVER_SECRET_BYTE_LENGTH)
+		webServerSecret, err = common.MakeRandomStringHex(WEB_SERVER_SECRET_BYTE_LENGTH)
 		if err != nil {
-			return nil, nil, nil, psiphon.ContextError(err)
+			return nil, nil, nil, common.ContextError(err)
 		}
 
 		webServerCertificate, webServerPrivateKey, err = GenerateWebServerCertificate("")
 		if err != nil {
-			return nil, nil, nil, psiphon.ContextError(err)
+			return nil, nil, nil, common.ContextError(err)
 		}
 	}
 
@@ -358,7 +359,7 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 	// TODO: use other key types: anti-fingerprint by varying params
 	rsaKey, err := rsa.GenerateKey(rand.Reader, SSH_RSA_HOST_KEY_BITS)
 	if err != nil {
-		return nil, nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, common.ContextError(err)
 	}
 
 	sshPrivateKey := pem.EncodeToMemory(
@@ -370,21 +371,21 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 
 	signer, err := ssh.NewSignerFromKey(rsaKey)
 	if err != nil {
-		return nil, nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, common.ContextError(err)
 	}
 
 	sshPublicKey := signer.PublicKey()
 
-	sshUserNameSuffix, err := psiphon.MakeRandomStringHex(SSH_USERNAME_SUFFIX_BYTE_LENGTH)
+	sshUserNameSuffix, err := common.MakeRandomStringHex(SSH_USERNAME_SUFFIX_BYTE_LENGTH)
 	if err != nil {
-		return nil, nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, common.ContextError(err)
 	}
 
 	sshUserName := "psiphon_" + sshUserNameSuffix
 
-	sshPassword, err := psiphon.MakeRandomStringHex(SSH_PASSWORD_BYTE_LENGTH)
+	sshPassword, err := common.MakeRandomStringHex(SSH_PASSWORD_BYTE_LENGTH)
 	if err != nil {
-		return nil, nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, common.ContextError(err)
 	}
 
 	// TODO: vary version string for anti-fingerprint
@@ -392,9 +393,9 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 
 	// Obfuscated SSH config
 
-	obfuscatedSSHKey, err := psiphon.MakeRandomStringHex(SSH_OBFUSCATED_KEY_BYTE_LENGTH)
+	obfuscatedSSHKey, err := common.MakeRandomStringHex(SSH_OBFUSCATED_KEY_BYTE_LENGTH)
 	if err != nil {
-		return nil, nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, common.ContextError(err)
 	}
 
 	// Meek config
@@ -405,23 +406,23 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 		rawMeekCookieEncryptionPublicKey, rawMeekCookieEncryptionPrivateKey, err :=
 			box.GenerateKey(rand.Reader)
 		if err != nil {
-			return nil, nil, nil, psiphon.ContextError(err)
+			return nil, nil, nil, common.ContextError(err)
 		}
 
 		meekCookieEncryptionPublicKey = base64.StdEncoding.EncodeToString(rawMeekCookieEncryptionPublicKey[:])
 		meekCookieEncryptionPrivateKey = base64.StdEncoding.EncodeToString(rawMeekCookieEncryptionPrivateKey[:])
 
-		meekObfuscatedKey, err = psiphon.MakeRandomStringHex(SSH_OBFUSCATED_KEY_BYTE_LENGTH)
+		meekObfuscatedKey, err = common.MakeRandomStringHex(SSH_OBFUSCATED_KEY_BYTE_LENGTH)
 		if err != nil {
-			return nil, nil, nil, psiphon.ContextError(err)
+			return nil, nil, nil, common.ContextError(err)
 		}
 	}
 
 	// Other config
 
-	discoveryValueHMACKey, err := psiphon.MakeRandomStringBase64(DISCOVERY_VALUE_KEY_BYTE_LENGTH)
+	discoveryValueHMACKey, err := common.MakeRandomStringBase64(DISCOVERY_VALUE_KEY_BYTE_LENGTH)
 	if err != nil {
-		return nil, nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, common.ContextError(err)
 	}
 
 	// Assemble configs and server entry
@@ -459,12 +460,12 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 
 	encodedConfig, err := json.MarshalIndent(config, "\n", "    ")
 	if err != nil {
-		return nil, nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, common.ContextError(err)
 	}
 
 	trafficRulesSet := &TrafficRulesSet{
 		DefaultRules: TrafficRules{
-			DefaultLimits: RateLimits{
+			DefaultLimits: common.RateLimits{
 				DownstreamUnlimitedBytes: 0,
 				DownstreamBytesPerSecond: 0,
 				UpstreamUnlimitedBytes:   0,
@@ -483,17 +484,17 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 
 	encodedTrafficRulesSet, err := json.MarshalIndent(trafficRulesSet, "\n", "    ")
 	if err != nil {
-		return nil, nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, common.ContextError(err)
 	}
 
 	capabilities := []string{}
 
 	if params.EnableSSHAPIRequests {
-		capabilities = append(capabilities, psiphon.CAPABILITY_SSH_API_REQUESTS)
+		capabilities = append(capabilities, common.CAPABILITY_SSH_API_REQUESTS)
 	}
 
 	if params.WebServerPort != 0 {
-		capabilities = append(capabilities, psiphon.CAPABILITY_UNTUNNELED_WEB_API_REQUESTS)
+		capabilities = append(capabilities, common.CAPABILITY_UNTUNNELED_WEB_API_REQUESTS)
 	}
 
 	for protocol, _ := range params.TunnelProtocolPorts {
@@ -549,7 +550,7 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 
 	encodedServerEntry, err := psiphon.EncodeServerEntry(serverEntry)
 	if err != nil {
-		return nil, nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, common.ContextError(err)
 	}
 
 	return encodedConfig, encodedTrafficRulesSet, []byte(encodedServerEntry), nil

+ 11 - 11
psiphon/server/dns.go

@@ -28,7 +28,7 @@ import (
 	"sync/atomic"
 	"time"
 
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 const (
@@ -41,7 +41,7 @@ const (
 // "/etc/resolv.conf" on platforms where it is available; and
 // otherwise using a default value.
 type DNSResolver struct {
-	psiphon.ReloadableFile
+	common.ReloadableFile
 	lastReloadTime int64
 	isReloading    int32
 	resolver       net.IP
@@ -71,14 +71,14 @@ func NewDNSResolver(defaultResolver string) (*DNSResolver, error) {
 		lastReloadTime: time.Now().Unix(),
 	}
 
-	dns.ReloadableFile = psiphon.NewReloadableFile(
+	dns.ReloadableFile = common.NewReloadableFile(
 		DNS_SYSTEM_CONFIG_FILENAME,
 		func(filename string) error {
 
 			resolver, err := parseResolveConf(filename)
 			if err != nil {
 				// On error, state remains the same
-				return psiphon.ContextError(err)
+				return common.ContextError(err)
 			}
 
 			dns.resolver = resolver
@@ -94,7 +94,7 @@ func NewDNSResolver(defaultResolver string) (*DNSResolver, error) {
 	_, err := dns.Reload()
 	if err != nil {
 		if defaultResolver == "" {
-			return nil, psiphon.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 
 		log.WithContextFields(
@@ -103,7 +103,7 @@ func NewDNSResolver(defaultResolver string) (*DNSResolver, error) {
 
 		resolver, err := parseResolver(defaultResolver)
 		if err != nil {
-			return nil, psiphon.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 
 		dns.resolver = resolver
@@ -122,7 +122,7 @@ func (dns *DNSResolver) Get() net.IP {
 	// atomic.LoadInt64 reload time check and the RLock (an atomic.AddInt32
 	// when no write lock is pending). An atomic.CompareAndSwapInt32 is
 	// used to ensure only one goroutine enters Reload() and blocks on
-	// its write lock. Finally, since since psiphon.ReloadableFile.Reload
+	// its write lock. Finally, since since ReloadableFile.Reload
 	// checks whether the underlying file has changed _before_ aquiring a
 	// write lock, we only incur write lock blocking when "/etc/resolv.conf"
 	// has actually changed.
@@ -160,7 +160,7 @@ func (dns *DNSResolver) Get() net.IP {
 func parseResolveConf(filename string) (net.IP, error) {
 	file, err := os.Open(filename)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	defer file.Close()
 
@@ -179,16 +179,16 @@ func parseResolveConf(filename string) (net.IP, error) {
 		}
 	}
 	if err := scanner.Err(); err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
-	return nil, psiphon.ContextError(errors.New("nameserver not found"))
+	return nil, common.ContextError(errors.New("nameserver not found"))
 }
 
 func parseResolver(resolver string) (net.IP, error) {
 
 	ipAddress := net.ParseIP(resolver)
 	if ipAddress == nil {
-		return nil, psiphon.ContextError(errors.New("invalid IP address"))
+		return nil, common.ContextError(errors.New("invalid IP address"))
 	}
 
 	return ipAddress, nil

+ 7 - 7
psiphon/server/geoip.go

@@ -27,7 +27,7 @@ import (
 
 	cache "github.com/Psiphon-Inc/go-cache"
 	maxminddb "github.com/Psiphon-Inc/maxminddb-golang"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 const (
@@ -69,7 +69,7 @@ type GeoIPService struct {
 }
 
 type geoIPDatabase struct {
-	psiphon.ReloadableFile
+	common.ReloadableFile
 	maxMindReader *maxminddb.Reader
 }
 
@@ -87,13 +87,13 @@ func NewGeoIPService(
 	for i, filename := range databaseFilenames {
 
 		database := &geoIPDatabase{}
-		database.ReloadableFile = psiphon.NewReloadableFile(
+		database.ReloadableFile = common.NewReloadableFile(
 			filename,
 			func(filename string) error {
 				maxMindReader, err := maxminddb.Open(filename)
 				if err != nil {
 					// On error, database state remains the same
-					return psiphon.ContextError(err)
+					return common.ContextError(err)
 				}
 				if database.maxMindReader != nil {
 					database.maxMindReader.Close()
@@ -104,7 +104,7 @@ func NewGeoIPService(
 
 		_, err := database.Reload()
 		if err != nil {
-			return nil, psiphon.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 
 		geoIP.databases[i] = database
@@ -116,8 +116,8 @@ func NewGeoIPService(
 // Reloaders gets the list of reloadable databases in use
 // by the GeoIPService. This list is used to hot reload
 // these databases.
-func (geoIP *GeoIPService) Reloaders() []psiphon.Reloader {
-	reloaders := make([]psiphon.Reloader, len(geoIP.databases))
+func (geoIP *GeoIPService) Reloaders() []common.Reloader {
+	reloaders := make([]common.Reloader, len(geoIP.databases))
 	for i, database := range geoIP.databases {
 		reloaders[i] = database
 	}

+ 5 - 5
psiphon/server/log.go

@@ -24,7 +24,7 @@ import (
 	"os"
 
 	"github.com/Psiphon-Inc/logrus"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // ContextLogger adds context logging functionality to the
@@ -43,7 +43,7 @@ type LogFields logrus.Fields
 func (logger *ContextLogger) WithContext() *logrus.Entry {
 	return log.WithFields(
 		logrus.Fields{
-			"context": psiphon.GetParentContext(),
+			"context": common.GetParentContext(),
 		})
 }
 
@@ -56,7 +56,7 @@ func (logger *ContextLogger) WithContextFields(fields LogFields) *logrus.Entry {
 	if ok {
 		fields["fields.context"] = fields["context"]
 	}
-	fields["context"] = psiphon.GetParentContext()
+	fields["context"] = common.GetParentContext()
 	return log.WithFields(logrus.Fields(fields))
 }
 
@@ -77,7 +77,7 @@ func InitLogging(config *Config) error {
 
 	level, err := logrus.ParseLevel(config.LogLevel)
 	if err != nil {
-		return psiphon.ContextError(err)
+		return common.ContextError(err)
 	}
 
 	logWriter := os.Stderr
@@ -86,7 +86,7 @@ func InitLogging(config *Config) error {
 		logWriter, err = os.OpenFile(
 			config.LogFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0666)
 		if err != nil {
-			return psiphon.ContextError(err)
+			return common.ContextError(err)
 		}
 	}
 

+ 23 - 22
psiphon/server/meek.go

@@ -34,6 +34,7 @@ import (
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"golang.org/x/crypto/nacl/box"
 )
 
@@ -82,7 +83,7 @@ type MeekServer struct {
 	listener      net.Listener
 	tlsConfig     *tls.Config
 	clientHandler func(clientConn net.Conn)
-	openConns     *psiphon.Conns
+	openConns     *common.Conns
 	stopBroadcast <-chan struct{}
 	sessionsLock  sync.RWMutex
 	sessions      map[string]*meekSession
@@ -100,7 +101,7 @@ func NewMeekServer(
 		support:       support,
 		listener:      listener,
 		clientHandler: clientHandler,
-		openConns:     new(psiphon.Conns),
+		openConns:     new(common.Conns),
 		stopBroadcast: stopBroadcast,
 		sessions:      make(map[string]*meekSession),
 	}
@@ -108,7 +109,7 @@ func NewMeekServer(
 	if useTLS {
 		tlsConfig, err := makeMeekTLSConfig(support)
 		if err != nil {
-			return nil, psiphon.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		meekServer.tlsConfig = tlsConfig
 	}
@@ -168,7 +169,7 @@ func (server *MeekServer) Run() error {
 	var err error
 	if server.tlsConfig != nil {
 		httpServer.TLSConfig = server.tlsConfig
-		httpsServer := psiphon.HTTPSServer{Server: *httpServer}
+		httpsServer := HTTPSServer{Server: *httpServer}
 		err = httpsServer.ServeTLS(server.listener)
 	} else {
 		err = httpServer.Serve(server.listener)
@@ -296,7 +297,7 @@ func (server *MeekServer) getSession(
 
 	payloadJSON, err := getMeekCookiePayload(server.support, meekCookie.Value)
 	if err != nil {
-		return "", nil, psiphon.ContextError(err)
+		return "", nil, common.ContextError(err)
 	}
 
 	// Note: this meek server ignores all but Version MeekProtocolVersion;
@@ -309,7 +310,7 @@ func (server *MeekServer) getSession(
 
 	err = json.Unmarshal(payloadJSON, &clientSessionData)
 	if err != nil {
-		return "", nil, psiphon.ContextError(err)
+		return "", nil, common.ContextError(err)
 	}
 
 	// Determine the client remote address, which is used for geolocation
@@ -368,7 +369,7 @@ func (server *MeekServer) getSession(
 	if clientSessionData.MeekProtocolVersion >= MEEK_PROTOCOL_VERSION_2 {
 		sessionID, err = makeMeekSessionID()
 		if err != nil {
-			return "", nil, psiphon.ContextError(err)
+			return "", nil, common.ContextError(err)
 		}
 	}
 
@@ -466,13 +467,13 @@ func makeMeekTLSConfig(support *SupportServices) (*tls.Config, error) {
 	certificate, privateKey, err := GenerateWebServerCertificate(
 		support.Config.MeekCertificateCommonName)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	tlsCertificate, err := tls.X509KeyPair(
 		[]byte(certificate), []byte(privateKey))
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return &tls.Config{
@@ -514,7 +515,7 @@ func makeMeekTLSConfig(support *SupportServices) (*tls.Config, error) {
 func getMeekCookiePayload(support *SupportServices, cookieValue string) ([]byte, error) {
 	decodedValue, err := base64.StdEncoding.DecodeString(cookieValue)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// The data consists of an obfuscated seed message prepended
@@ -528,12 +529,12 @@ func getMeekCookiePayload(support *SupportServices, cookieValue string) ([]byte,
 		reader,
 		&psiphon.ObfuscatorConfig{Keyword: support.Config.MeekObfuscatedKey})
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	offset, err := reader.Seek(0, 1)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	encryptedPayload := decodedValue[offset:]
 
@@ -545,18 +546,18 @@ func getMeekCookiePayload(support *SupportServices, cookieValue string) ([]byte,
 	decodedPrivateKey, err := base64.StdEncoding.DecodeString(
 		support.Config.MeekCookieEncryptionPrivateKey)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	copy(privateKey[:], decodedPrivateKey)
 
 	if len(encryptedPayload) < 32 {
-		return nil, psiphon.ContextError(errors.New("unexpected encrypted payload size"))
+		return nil, common.ContextError(errors.New("unexpected encrypted payload size"))
 	}
 	copy(ephemeralPublicKey[0:32], encryptedPayload[0:32])
 
 	payload, ok := box.Open(nil, encryptedPayload[32:], &nonce, &ephemeralPublicKey, &privateKey)
 	if !ok {
-		return nil, psiphon.ContextError(errors.New("open box failed"))
+		return nil, common.ContextError(errors.New("open box failed"))
 	}
 
 	return payload, nil
@@ -566,14 +567,14 @@ func getMeekCookiePayload(support *SupportServices, cookieValue string) ([]byte,
 // frustrate traffic analysis of both plaintext and TLS meek traffic.
 func makeMeekSessionID() (string, error) {
 	size := MEEK_MIN_SESSION_ID_LENGTH
-	n, err := psiphon.MakeSecureRandomInt(MEEK_MAX_SESSION_ID_LENGTH - MEEK_MIN_SESSION_ID_LENGTH)
+	n, err := common.MakeSecureRandomInt(MEEK_MAX_SESSION_ID_LENGTH - MEEK_MIN_SESSION_ID_LENGTH)
 	if err != nil {
-		return "", psiphon.ContextError(err)
+		return "", common.ContextError(err)
 	}
 	size += n
-	sessionID, err := psiphon.MakeRandomStringBase64(size)
+	sessionID, err := common.MakeRandomStringBase64(size)
 	if err != nil {
-		return "", psiphon.ContextError(err)
+		return "", common.ContextError(err)
 	}
 	return sessionID, nil
 }
@@ -785,15 +786,15 @@ func (conn *meekConn) SetDeadline(t time.Time) error {
 	if time.Now().Add(MEEK_MAX_SESSION_STALENESS).Before(t) {
 		return nil
 	}
-	return psiphon.ContextError(errors.New("not supported"))
+	return common.ContextError(errors.New("not supported"))
 }
 
 // Stub implementation of net.Conn.SetReadDeadline
 func (conn *meekConn) SetReadDeadline(t time.Time) error {
-	return psiphon.ContextError(errors.New("not supported"))
+	return common.ContextError(errors.New("not supported"))
 }
 
 // Stub implementation of net.Conn.SetWriteDeadline
 func (conn *meekConn) SetWriteDeadline(t time.Time) error {
-	return psiphon.ContextError(errors.New("not supported"))
+	return common.ContextError(errors.New("not supported"))
 }

+ 61 - 80
psiphon/server/net.go

@@ -17,18 +17,49 @@
  *
  */
 
+// for HTTPSServer.ServeTLS:
+/*
+Copyright (c) 2012 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+   * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+   * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+   * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+
 package server
 
 import (
 	"container/list"
-	"io"
+	"crypto/tls"
 	"net"
+	"net/http"
 	"sync"
 	"sync/atomic"
 	"time"
 
-	"github.com/Psiphon-Inc/ratelimit"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // LRUConns is a concurrency-safe list of net.Conns ordered
@@ -146,7 +177,7 @@ func NewActivityMonitoredConn(
 	if inactivityTimeout > 0 {
 		err := conn.SetDeadline(time.Now().Add(inactivityTimeout))
 		if err != nil {
-			return nil, psiphon.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 	}
 
@@ -182,7 +213,7 @@ func (conn *ActivityMonitoredConn) Read(buffer []byte) (int, error) {
 		if conn.inactivityTimeout > 0 {
 			err = conn.Conn.SetDeadline(time.Now().Add(conn.inactivityTimeout))
 			if err != nil {
-				return n, psiphon.ContextError(err)
+				return n, common.ContextError(err)
 			}
 		}
 		if conn.lruEntry != nil {
@@ -203,7 +234,7 @@ func (conn *ActivityMonitoredConn) Write(buffer []byte) (int, error) {
 		if conn.inactivityTimeout > 0 {
 			err = conn.Conn.SetDeadline(time.Now().Add(conn.inactivityTimeout))
 			if err != nil {
-				return n, psiphon.ContextError(err)
+				return n, common.ContextError(err)
 			}
 		}
 
@@ -216,85 +247,35 @@ func (conn *ActivityMonitoredConn) Write(buffer []byte) (int, error) {
 	return n, err
 }
 
-// ThrottledConn wraps a net.Conn with read and write rate limiters.
-// Rates are specified as bytes per second. Optional unlimited byte
-// counts allow for a number of bytes to read or write before
-// applying rate limiting. Specify limit values of 0 to set no rate
-// limit (unlimited counts are ignored in this case).
-// The underlying rate limiter uses the token bucket algorithm to
-// calculate delay times for read and write operations.
-type ThrottledConn struct {
-	net.Conn
-	unlimitedReadBytes  int64
-	limitingReads       int32
-	limitedReader       io.Reader
-	unlimitedWriteBytes int64
-	limitingWrites      int32
-	limitedWriter       io.Writer
+// HTTPSServer is a wrapper around http.Server which adds the
+// ServeTLS function.
+type HTTPSServer struct {
+	http.Server
 }
 
-// NewThrottledConn initializes a new ThrottledConn.
-func NewThrottledConn(
-	conn net.Conn,
-	unlimitedReadBytes, limitReadBytesPerSecond,
-	unlimitedWriteBytes, limitWriteBytesPerSecond int64) *ThrottledConn {
-
-	// When no limit is specified, the rate limited reader/writer
-	// is simply the base reader/writer.
-
-	var reader io.Reader
-	if limitReadBytesPerSecond == 0 {
-		reader = conn
-	} else {
-		reader = ratelimit.Reader(conn,
-			ratelimit.NewBucketWithRate(
-				float64(limitReadBytesPerSecond), limitReadBytesPerSecond))
-	}
-
-	var writer io.Writer
-	if limitWriteBytesPerSecond == 0 {
-		writer = conn
-	} else {
-		writer = ratelimit.Writer(conn,
-			ratelimit.NewBucketWithRate(
-				float64(limitWriteBytesPerSecond), limitWriteBytesPerSecond))
-	}
-
-	return &ThrottledConn{
-		Conn:                conn,
-		unlimitedReadBytes:  unlimitedReadBytes,
-		limitingReads:       0,
-		limitedReader:       reader,
-		unlimitedWriteBytes: unlimitedWriteBytes,
-		limitingWrites:      0,
-		limitedWriter:       writer,
-	}
+// ServeTLS is a offers the equivalent interface as http.Serve.
+// The http package has both ListenAndServe and ListenAndServeTLS higher-
+// level interfaces, but only Serve (not TLS) offers a lower-level interface that
+// allows the caller to keep a refererence to the Listener, allowing for external
+// shutdown. ListenAndServeTLS also requires the TLS cert and key to be in files
+// and we avoid that here.
+// tcpKeepAliveListener is used in http.ListenAndServeTLS but not exported,
+// so we use a copy from https://golang.org/src/net/http/server.go.
+func (server *HTTPSServer) ServeTLS(listener net.Listener) error {
+	tlsListener := tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, server.TLSConfig)
+	return server.Serve(tlsListener)
 }
 
-func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
-
-	// Use the base reader until the unlimited count is exhausted.
-	if atomic.LoadInt32(&conn.limitingReads) == 0 {
-		if atomic.AddInt64(&conn.unlimitedReadBytes, -int64(len(buffer))) <= 0 {
-			atomic.StoreInt32(&conn.limitingReads, 1)
-		} else {
-			return conn.Read(buffer)
-		}
-	}
-
-	return conn.limitedReader.Read(buffer)
+type tcpKeepAliveListener struct {
+	*net.TCPListener
 }
 
-func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
-
-	// Use the base writer until the unlimited count is exhausted.
-	if atomic.LoadInt32(&conn.limitingWrites) == 0 {
-		if atomic.AddInt64(&conn.unlimitedWriteBytes, -int64(len(buffer))) <= 0 {
-			atomic.StoreInt32(&conn.limitingWrites, 1)
-		} else {
-			return conn.Write(buffer)
-		}
+func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
+	tc, err := ln.AcceptTCP()
+	if err != nil {
+		return
 	}
-
-	return conn.limitedWriter.Write(buffer)
+	tc.SetKeepAlive(true)
+	tc.SetKeepAlivePeriod(3 * time.Minute)
+	return tc, nil
 }

+ 6 - 6
psiphon/server/psinet/psinet.go

@@ -34,14 +34,14 @@ import (
 	"strings"
 	"time"
 
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // Database serves Psiphon API data requests. It's safe for
 // concurrent usage. The Reload function supports hot reloading
 // of Psiphon network data while the server is running.
 type Database struct {
-	psiphon.ReloadableFile
+	common.ReloadableFile
 
 	AlternateMeekFrontingAddresses      map[string][]string        `json:"alternate_meek_fronting_addresses"`
 	AlternateMeekFrontingAddressesRegex map[string]string          `json:"alternate_meek_fronting_addresses_regex"`
@@ -130,27 +130,27 @@ func NewDatabase(filename string) (*Database, error) {
 
 	database := &Database{}
 
-	database.ReloadableFile = psiphon.NewReloadableFile(
+	database.ReloadableFile = common.NewReloadableFile(
 		filename,
 		func(filename string) error {
 			psinetJSON, err := ioutil.ReadFile(filename)
 			if err != nil {
 				// On error, state remains the same
-				return psiphon.ContextError(err)
+				return common.ContextError(err)
 			}
 			err = json.Unmarshal(psinetJSON, &database)
 			if err != nil {
 				// On error, state remains the same
 				// (Unmarshal first validates the provided
 				//  JOSN and then populates the interface)
-				return psiphon.ContextError(err)
+				return common.ContextError(err)
 			}
 			return nil
 		})
 
 	_, err := database.Reload()
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return database, nil

+ 2 - 2
psiphon/server/safetyNet.go

@@ -29,7 +29,7 @@ import (
 	"strings"
 	"time"
 
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 const (
@@ -164,7 +164,7 @@ func (body *jwtBody) verifyJWTBody() (validApkCert, validApkPackageName bool) {
 	}
 
 	// Verify apk package name
-	if psiphon.Contains(psiphonApkPackagenames, body.ApkPackageName) {
+	if common.Contains(psiphonApkPackagenames, body.ApkPackageName) {
 		validApkPackageName = true
 	}
 

+ 5 - 4
psiphon/server/server_test.go

@@ -33,6 +33,7 @@ import (
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 func TestMain(m *testing.M) {
@@ -240,7 +241,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	psiphon.SetNoticeOutput(psiphon.NewNoticeReceiver(
 		func(notice []byte) {
 
-			//fmt.Printf("%s\n", string(notice))
+			fmt.Printf("%s\n", string(notice))
 
 			noticeType, payload, err := psiphon.GetNotice(notice)
 			if err != nil {
@@ -341,10 +342,10 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 func pavePsinetDatabaseFile(t *testing.T, psinetFilename string) (string, string) {
 
-	sponsorID, _ := psiphon.MakeRandomStringHex(8)
+	sponsorID, _ := common.MakeRandomStringHex(8)
 
-	fakeDomain, _ := psiphon.MakeRandomStringHex(4)
-	fakePath, _ := psiphon.MakeRandomStringHex(4)
+	fakeDomain, _ := common.MakeRandomStringHex(4)
+	fakePath, _ := common.MakeRandomStringHex(4)
 	expectedHomepageURL := fmt.Sprintf("https://%s.com/%s", fakeDomain, fakePath)
 
 	psinetJSONFormat := `

+ 10 - 10
psiphon/server/services.go

@@ -31,7 +31,7 @@ import (
 	"syscall"
 	"time"
 
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server/psinet"
 )
 
@@ -43,19 +43,19 @@ func RunServices(configJSON []byte) error {
 	config, err := LoadConfig(configJSON)
 	if err != nil {
 		log.WithContextFields(LogFields{"error": err}).Error("load config failed")
-		return psiphon.ContextError(err)
+		return common.ContextError(err)
 	}
 
 	err = InitLogging(config)
 	if err != nil {
 		log.WithContextFields(LogFields{"error": err}).Error("init logging failed")
-		return psiphon.ContextError(err)
+		return common.ContextError(err)
 	}
 
 	supportServices, err := NewSupportServices(config)
 	if err != nil {
 		log.WithContextFields(LogFields{"error": err}).Error("init support services failed")
-		return psiphon.ContextError(err)
+		return common.ContextError(err)
 	}
 
 	waitGroup := new(sync.WaitGroup)
@@ -65,7 +65,7 @@ func RunServices(configJSON []byte) error {
 	tunnelServer, err := NewTunnelServer(supportServices, shutdownBroadcast)
 	if err != nil {
 		log.WithContextFields(LogFields{"error": err}).Error("init tunnel server failed")
-		return psiphon.ContextError(err)
+		return common.ContextError(err)
 	}
 
 	if config.RunLoadMonitor() {
@@ -188,23 +188,23 @@ type SupportServices struct {
 func NewSupportServices(config *Config) (*SupportServices, error) {
 	trafficRulesSet, err := NewTrafficRulesSet(config.TrafficRulesFilename)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	psinetDatabase, err := psinet.NewDatabase(config.PsinetDatabaseFilename)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	geoIPService, err := NewGeoIPService(
 		config.GeoIPDatabaseFilenames, config.DiscoveryValueHMACKey)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	dnsResolver, err := NewDNSResolver(config.DNSResolverIPAddress)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return &SupportServices{
@@ -225,7 +225,7 @@ func NewSupportServices(config *Config) (*SupportServices, error) {
 func (support *SupportServices) Reload() {
 
 	reloaders := append(
-		[]psiphon.Reloader{support.TrafficRulesSet, support.PsinetDatabase},
+		[]common.Reloader{support.TrafficRulesSet, support.PsinetDatabase},
 		support.GeoIPService.Reloaders()...)
 
 	for _, reloader := range reloaders {

+ 9 - 32
psiphon/server/trafficRules.go

@@ -24,14 +24,14 @@ import (
 	"io/ioutil"
 	"strings"
 
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // TrafficRulesSet represents the various traffic rules to
 // apply to Psiphon client tunnels. The Reload function supports
 // hot reloading of rules data while the server is running.
 type TrafficRulesSet struct {
-	psiphon.ReloadableFile
+	common.ReloadableFile
 
 	// DefaultRules specifies the traffic rules to be used when no
 	// regional-specific rules are set or apply to a particular
@@ -45,41 +45,18 @@ type TrafficRulesSet struct {
 	RegionalRules map[string]TrafficRules
 }
 
-// RateLimits specify the rate limits for tunneled data transfer
-// between an individual client and the server.
-type RateLimits struct {
-
-	// DownstreamUnlimitedBytes specifies the number of downstream
-	// bytes to transfer, approximately, before starting rate
-	// limiting.
-	DownstreamUnlimitedBytes int64
-
-	// DownstreamBytesPerSecond specifies a rate limit for downstream
-	// data transfer. The default, 0, is no limit.
-	DownstreamBytesPerSecond int
-
-	// UpstreamUnlimitedBytes specifies the number of upstream
-	// bytes to transfer, approximately, before starting rate
-	// limiting.
-	UpstreamUnlimitedBytes int64
-
-	// UpstreamBytesPerSecond specifies a rate limit for upstream
-	// data transfer. The default, 0, is no limit.
-	UpstreamBytesPerSecond int
-}
-
 // TrafficRules specify the limits placed on client traffic.
 type TrafficRules struct {
 	// DefaultLimits are the rate limits to be applied when
 	// no protocol-specific rates are set.
-	DefaultLimits RateLimits
+	DefaultLimits common.RateLimits
 
 	// ProtocolLimits specifies the rate limits for particular
 	// tunnel protocols. The key for each rate limit entry is one
 	// or more space delimited Psiphon tunnel protocol names. Valid
 	// tunnel protocols includes the same list as for
 	// TunnelProtocolPorts.
-	ProtocolLimits map[string]RateLimits
+	ProtocolLimits map[string]common.RateLimits
 
 	// IdleTCPPortForwardTimeoutMilliseconds is the timeout period
 	// after which idle (no bytes flowing in either direction)
@@ -130,27 +107,27 @@ func NewTrafficRulesSet(filename string) (*TrafficRulesSet, error) {
 
 	set := &TrafficRulesSet{}
 
-	set.ReloadableFile = psiphon.NewReloadableFile(
+	set.ReloadableFile = common.NewReloadableFile(
 		filename,
 		func(filename string) error {
 			configJSON, err := ioutil.ReadFile(filename)
 			if err != nil {
 				// On error, state remains the same
-				return psiphon.ContextError(err)
+				return common.ContextError(err)
 			}
 			err = json.Unmarshal(configJSON, &set)
 			if err != nil {
 				// On error, state remains the same
 				// (Unmarshal first validates the provided
 				//  JOSN and then populates the interface)
-				return psiphon.ContextError(err)
+				return common.ContextError(err)
 			}
 			return nil
 		})
 
 	_, err := set.Reload()
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return set, nil
@@ -176,7 +153,7 @@ func (set *TrafficRulesSet) GetTrafficRules(clientCountryCode string) TrafficRul
 // GetRateLimits looks up the rate limits for the specified tunnel protocol.
 // If there are no specific RateLimits for the protocol, default RateLimits are
 // returned.
-func (rules *TrafficRules) GetRateLimits(clientTunnelProtocol string) RateLimits {
+func (rules *TrafficRules) GetRateLimits(clientTunnelProtocol string) common.RateLimits {
 
 	// TODO: faster lookup?
 	for tunnelProtocols, rateLimits := range rules.ProtocolLimits {

+ 22 - 26
psiphon/server/tunnelServer.go

@@ -31,6 +31,7 @@ import (
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"golang.org/x/crypto/ssh"
 )
 
@@ -68,7 +69,7 @@ func NewTunnelServer(
 
 	sshServer, err := newSSHServer(support, shutdownBroadcast)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return &TunnelServer{
@@ -130,7 +131,7 @@ func (server *TunnelServer) Run() error {
 			for _, existingListener := range listeners {
 				existingListener.Listener.Close()
 			}
-			return psiphon.ContextError(err)
+			return common.ContextError(err)
 		}
 
 		log.WithContextFields(
@@ -209,13 +210,13 @@ func newSSHServer(
 
 	privateKey, err := ssh.ParseRawPrivateKey([]byte(support.Config.SSHPrivateKey))
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
 	signer, err := ssh.NewSignerFromKey(privateKey)
 	if err != nil {
-		return nil, psiphon.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return &sshServer{
@@ -246,18 +247,18 @@ func (sshServer *sshServer) runListener(
 	// TunnelServer.Run will properly shut down instead of remaining
 	// running.
 
-	if psiphon.TunnelProtocolUsesMeekHTTP(tunnelProtocol) ||
-		psiphon.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
+	if common.TunnelProtocolUsesMeekHTTP(tunnelProtocol) ||
+		common.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
 
 		meekServer, err := NewMeekServer(
 			sshServer.support,
 			listener,
-			psiphon.TunnelProtocolUsesMeekHTTPS(tunnelProtocol),
+			common.TunnelProtocolUsesMeekHTTPS(tunnelProtocol),
 			handleClient,
 			sshServer.shutdownBroadcast)
 		if err != nil {
 			select {
-			case listenerError <- psiphon.ContextError(err):
+			case listenerError <- common.ContextError(err):
 			default:
 			}
 			return
@@ -287,7 +288,7 @@ func (sshServer *sshServer) runListener(
 				}
 
 				select {
-				case listenerError <- psiphon.ContextError(err):
+				case listenerError <- common.ContextError(err):
 				default:
 				}
 				return
@@ -406,7 +407,7 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 	defer sshServer.unregisterAcceptedClient(tunnelProtocol)
 
 	geoIPData := sshServer.support.GeoIPService.Lookup(
-		psiphon.IPAddressFromAddr(clientConn.RemoteAddr()))
+		common.IPAddressFromAddr(clientConn.RemoteAddr()))
 
 	// TODO: apply reload of TrafficRulesSet to existing clients
 
@@ -437,13 +438,8 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 
 	// Further wrap the connection in a rate limiting ThrottledConn.
 
-	rateLimits := sshClient.trafficRules.GetRateLimits(tunnelProtocol)
-	clientConn = NewThrottledConn(
-		clientConn,
-		rateLimits.DownstreamUnlimitedBytes,
-		int64(rateLimits.DownstreamBytesPerSecond),
-		rateLimits.UpstreamUnlimitedBytes,
-		int64(rateLimits.UpstreamBytesPerSecond))
+	clientConn = common.NewThrottledConn(
+		clientConn, sshClient.trafficRules.GetRateLimits(tunnelProtocol))
 
 	// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
 	// respect shutdownBroadcast and implement a specific handshake timeout.
@@ -478,7 +474,7 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 
 		// Wrap the connection in an SSH deobfuscator when required.
 
-		if psiphon.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) {
+		if common.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) {
 			// Note: NewObfuscatedSshConn blocks on network I/O
 			// TODO: ensure this won't block shutdown
 			conn, result.err = psiphon.NewObfuscatedSshConn(
@@ -486,7 +482,7 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 				clientConn,
 				sshServer.support.Config.ObfuscatedSSHKey)
 			if result.err != nil {
-				result.err = psiphon.ContextError(result.err)
+				result.err = common.ContextError(result.err)
 			}
 		}
 
@@ -590,16 +586,16 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 		// send the hex encoded session ID prepended to the SSH password.
 		// Note: there's an even older case where clients don't send any session ID,
 		// but that's no longer supported.
-		if len(password) == 2*psiphon.PSIPHON_API_CLIENT_SESSION_ID_LENGTH+2*SSH_PASSWORD_BYTE_LENGTH {
-			sshPasswordPayload.SessionId = string(password[0 : 2*psiphon.PSIPHON_API_CLIENT_SESSION_ID_LENGTH])
-			sshPasswordPayload.SshPassword = string(password[2*psiphon.PSIPHON_API_CLIENT_SESSION_ID_LENGTH : len(password)])
+		if len(password) == 2*common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH+2*SSH_PASSWORD_BYTE_LENGTH {
+			sshPasswordPayload.SessionId = string(password[0 : 2*common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH])
+			sshPasswordPayload.SshPassword = string(password[2*common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH : len(password)])
 		} else {
-			return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
+			return nil, common.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
 		}
 	}
 
 	if !isHexDigits(sshClient.sshServer.support, sshPasswordPayload.SessionId) {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid session ID for %q", conn.User()))
+		return nil, common.ContextError(fmt.Errorf("invalid session ID for %q", conn.User()))
 	}
 
 	userOk := (subtle.ConstantTimeCompare(
@@ -609,7 +605,7 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 		[]byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.support.Config.SSHPassword)) == 1)
 
 	if !userOk || !passwordOk {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
+		return nil, common.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
 	}
 
 	psiphonSessionID := sshPasswordPayload.SessionId
@@ -806,7 +802,7 @@ func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChanne
 func (sshClient *sshClient) isPortForwardPermitted(
 	host string, port int, allowPorts []int, denyPorts []int) bool {
 
-	if psiphon.Contains(SSH_DISALLOWED_PORT_FORWARD_HOSTS, host) {
+	if common.Contains(SSH_DISALLOWED_PORT_FORWARD_HOSTS, host) {
 		return false
 	}
 

+ 7 - 7
psiphon/server/udp.go

@@ -29,7 +29,7 @@ import (
 	"sync/atomic"
 	"time"
 
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"golang.org/x/crypto/ssh"
 )
 
@@ -384,18 +384,18 @@ func readUdpgwMessage(
 
 		_, err := io.ReadFull(reader, buffer[0:2])
 		if err != nil {
-			return nil, psiphon.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 
 		size := uint16(buffer[0]) + uint16(buffer[1])<<8
 
 		if int(size) > len(buffer)-2 {
-			return nil, psiphon.ContextError(errors.New("invalid udpgw message size"))
+			return nil, common.ContextError(errors.New("invalid udpgw message size"))
 		}
 
 		_, err = io.ReadFull(reader, buffer[2:2+size])
 		if err != nil {
-			return nil, psiphon.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 
 		flags := buffer[2]
@@ -417,7 +417,7 @@ func readUdpgwMessage(
 		if flags&udpgwProtocolFlagIPv6 == udpgwProtocolFlagIPv6 {
 
 			if size < 21 {
-				return nil, psiphon.ContextError(errors.New("invalid udpgw message size"))
+				return nil, common.ContextError(errors.New("invalid udpgw message size"))
 			}
 
 			remoteIP = make([]byte, 16)
@@ -429,7 +429,7 @@ func readUdpgwMessage(
 		} else {
 
 			if size < 9 {
-				return nil, psiphon.ContextError(errors.New("invalid udpgw message size"))
+				return nil, common.ContextError(errors.New("invalid udpgw message size"))
 			}
 
 			remoteIP = make([]byte, 4)
@@ -465,7 +465,7 @@ func writeUdpgwPreamble(
 	buffer []byte) error {
 
 	if preambleSize != 7+len(remoteIP) {
-		return errors.New("invalid udpgw preamble size")
+		return common.ContextError(errors.New("invalid udpgw preamble size"))
 	}
 
 	size := uint16(preambleSize-2) + packetSize

+ 7 - 7
psiphon/server/utils.go

@@ -29,7 +29,7 @@ import (
 	"math/big"
 	"time"
 
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // GenerateWebServerCertificate creates a self-signed web server certificate,
@@ -53,15 +53,15 @@ func GenerateWebServerCertificate(commonName string) (string, string, error) {
 
 	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
 	if err != nil {
-		return "", "", psiphon.ContextError(err)
+		return "", "", common.ContextError(err)
 	}
 
 	// Validity period is ~10 years, starting some number of ~months
 	// back in the last year.
 
-	age, err := psiphon.MakeSecureRandomInt(12)
+	age, err := common.MakeSecureRandomInt(12)
 	if err != nil {
-		return "", "", psiphon.ContextError(err)
+		return "", "", common.ContextError(err)
 	}
 	age += 1
 	validityPeriod := 10 * 365 * 24 * time.Hour
@@ -71,12 +71,12 @@ func GenerateWebServerCertificate(commonName string) (string, string, error) {
 	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
 	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
 	if err != nil {
-		return "", "", psiphon.ContextError(err)
+		return "", "", common.ContextError(err)
 	}
 
 	publicKeyBytes, err := x509.MarshalPKIXPublicKey(rsaKey.Public())
 	if err != nil {
-		return "", "", psiphon.ContextError(err)
+		return "", "", common.ContextError(err)
 	}
 	// as per RFC3280 sec. 4.2.1.2
 	subjectKeyID := sha1.Sum(publicKeyBytes)
@@ -107,7 +107,7 @@ func GenerateWebServerCertificate(commonName string) (string, string, error) {
 		rsaKey.Public(),
 		rsaKey)
 	if err != nil {
-		return "", "", psiphon.ContextError(err)
+		return "", "", common.ContextError(err)
 	}
 
 	webServerCertificate := pem.EncodeToMemory(

+ 8 - 8
psiphon/server/webServer.go

@@ -30,7 +30,7 @@ import (
 	"sync"
 	"time"
 
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 const WEB_SERVER_IO_TIMEOUT = 10 * time.Second
@@ -72,7 +72,7 @@ func RunWebServer(
 		[]byte(support.Config.WebServerCertificate),
 		[]byte(support.Config.WebServerPrivateKey))
 	if err != nil {
-		return psiphon.ContextError(err)
+		return common.ContextError(err)
 	}
 
 	tlsConfig := &tls.Config{
@@ -86,7 +86,7 @@ func RunWebServer(
 	// Note: WriteTimeout includes time awaiting request, as per:
 	// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts
 
-	server := &psiphon.HTTPSServer{
+	server := &HTTPSServer{
 		http.Server{
 			MaxHeaderBytes: MAX_API_PARAMS_SIZE,
 			Handler:        serveMux,
@@ -102,7 +102,7 @@ func RunWebServer(
 			support.Config.ServerIPAddress,
 			support.Config.WebServerPort))
 	if err != nil {
-		return psiphon.ContextError(err)
+		return common.ContextError(err)
 	}
 
 	log.WithContext().Info("starting")
@@ -126,7 +126,7 @@ func RunWebServer(
 		default:
 			if err != nil {
 				select {
-				case errors <- psiphon.ContextError(err):
+				case errors <- common.ContextError(err):
 				default:
 				}
 			}
@@ -177,7 +177,7 @@ func convertHTTPRequestToAPIRequest(
 				var arrayValue []interface{}
 				err := json.Unmarshal([]byte(value), &arrayValue)
 				if err != nil {
-					return nil, psiphon.ContextError(err)
+					return nil, common.ContextError(err)
 				}
 				params[name] = arrayValue
 			} else {
@@ -192,12 +192,12 @@ func convertHTTPRequestToAPIRequest(
 		r.Body = http.MaxBytesReader(w, r.Body, MAX_API_PARAMS_SIZE)
 		body, err := ioutil.ReadAll(r.Body)
 		if err != nil {
-			return nil, psiphon.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		var bodyParams requestJSONObject
 		err = json.Unmarshal(body, &bodyParams)
 		if err != nil {
-			return nil, psiphon.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		params[requestBodyName] = bodyParams
 	}

+ 53 - 54
psiphon/serverApi.go

@@ -34,16 +34,10 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/transferstats"
 )
 
-const (
-	SERVER_API_HANDSHAKE_REQUEST_NAME           = "psiphon-handshake"
-	SERVER_API_CONNECTED_REQUEST_NAME           = "psiphon-connected"
-	SERVER_API_STATUS_REQUEST_NAME              = "psiphon-status"
-	SERVER_API_CLIENT_VERIFICATION_REQUEST_NAME = "psiphon-client-verification"
-)
-
 // ServerContext is a utility struct which holds all of the data associated
 // with a Psiphon server connection. In addition to the established tunnel, this
 // includes data and transport mechanisms for Psiphon API requests. Legacy servers
@@ -74,9 +68,9 @@ var nextTunnelNumber int64
 // Controller (e.g., the user's commanded start and stop) and we measure this
 // duration as well as the duration of each tunnel within the session.
 func MakeSessionId() (sessionId string, err error) {
-	randomId, err := MakeSecureRandomBytes(PSIPHON_API_CLIENT_SESSION_ID_LENGTH)
+	randomId, err := common.MakeSecureRandomBytes(common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH)
 	if err != nil {
-		return "", ContextError(err)
+		return "", common.ContextError(err)
 	}
 	return hex.EncodeToString(randomId), nil
 }
@@ -93,7 +87,7 @@ func NewServerContext(tunnel *Tunnel, sessionId string) (*ServerContext, error)
 		var err error
 		psiphonHttpsClient, err = makePsiphonHttpsClient(tunnel)
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 	}
 
@@ -106,7 +100,7 @@ func NewServerContext(tunnel *Tunnel, sessionId string) (*ServerContext, error)
 
 	err := serverContext.doHandshakeRequest()
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return serverContext, nil
@@ -123,7 +117,7 @@ func (serverContext *ServerContext) doHandshakeRequest() error {
 	/*
 		serverEntryIpAddresses, err := GetServerEntryIpAddresses()
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 
 		// Submit a list of known servers -- this will be used for
@@ -138,13 +132,13 @@ func (serverContext *ServerContext) doHandshakeRequest() error {
 
 		request, err := makeSSHAPIRequestPayload(params)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 
 		response, err = serverContext.tunnel.SendAPIRequest(
-			SERVER_API_HANDSHAKE_REQUEST_NAME, request)
+			common.PSIPHON_API_HANDSHAKE_REQUEST_NAME, request)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 
 	} else {
@@ -154,7 +148,7 @@ func (serverContext *ServerContext) doHandshakeRequest() error {
 		responseBody, err := serverContext.doGetRequest(
 			makeRequestUrl(serverContext.tunnel, "", "handshake", params))
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 		// Skip legacy format lines and just parse the JSON config line
 		configLinePrefix := []byte("Config: ")
@@ -165,7 +159,7 @@ func (serverContext *ServerContext) doHandshakeRequest() error {
 			}
 		}
 		if len(response) == 0 {
-			return ContextError(errors.New("no config line found"))
+			return common.ContextError(errors.New("no config line found"))
 		}
 	}
 
@@ -187,7 +181,7 @@ func (serverContext *ServerContext) doHandshakeRequest() error {
 	}
 	err := json.Unmarshal(response, &handshakeResponse)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	serverContext.clientRegion = handshakeResponse.ClientRegion
@@ -202,10 +196,10 @@ func (serverContext *ServerContext) doHandshakeRequest() error {
 
 		serverEntry, err := DecodeServerEntry(
 			encodedServerEntry,
-			TruncateTimestampToHour(handshakeResponse.ServerTimestamp),
-			SERVER_ENTRY_SOURCE_DISCOVERY)
+			common.TruncateTimestampToHour(handshakeResponse.ServerTimestamp),
+			common.SERVER_ENTRY_SOURCE_DISCOVERY)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 
 		err = ValidateServerEntry(serverEntry)
@@ -222,7 +216,7 @@ func (serverContext *ServerContext) doHandshakeRequest() error {
 	// StoreServerEntries that don't get triggered by StoreServerEntry.
 	err = StoreServerEntries(decodedServerEntries, true)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	// TODO: formally communicate the sponsor and upgrade info to an
@@ -271,7 +265,7 @@ func (serverContext *ServerContext) DoConnectedRequest() error {
 	const DATA_STORE_LAST_CONNECTED_KEY = "lastConnected"
 	lastConnected, err := GetKeyValue(DATA_STORE_LAST_CONNECTED_KEY)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	if lastConnected == "" {
 		lastConnected = "None"
@@ -284,13 +278,13 @@ func (serverContext *ServerContext) DoConnectedRequest() error {
 
 		request, err := makeSSHAPIRequestPayload(params)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 
 		response, err = serverContext.tunnel.SendAPIRequest(
-			SERVER_API_CONNECTED_REQUEST_NAME, request)
+			common.PSIPHON_API_CONNECTED_REQUEST_NAME, request)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 
 	} else {
@@ -300,7 +294,7 @@ func (serverContext *ServerContext) DoConnectedRequest() error {
 		response, err = serverContext.doGetRequest(
 			makeRequestUrl(serverContext.tunnel, "", "connected", params))
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 	}
 
@@ -309,13 +303,13 @@ func (serverContext *ServerContext) DoConnectedRequest() error {
 	}
 	err = json.Unmarshal(response, &connectedResponse)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	err = SetKeyValue(
 		DATA_STORE_LAST_CONNECTED_KEY, connectedResponse.ConnectedTimestamp)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	return nil
 }
@@ -336,7 +330,7 @@ func (serverContext *ServerContext) DoStatusRequest(tunnel *Tunnel) error {
 	statusPayload, statusPayloadInfo, err := makeStatusRequestPayload(
 		tunnel.serverEntry.IpAddress)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	if serverContext.psiphonHttpsClient == nil {
@@ -349,7 +343,7 @@ func (serverContext *ServerContext) DoStatusRequest(tunnel *Tunnel) error {
 
 		if err == nil {
 			_, err = serverContext.tunnel.SendAPIRequest(
-				SERVER_API_STATUS_REQUEST_NAME, request)
+				common.PSIPHON_API_STATUS_REQUEST_NAME, request)
 		}
 
 	} else {
@@ -368,7 +362,7 @@ func (serverContext *ServerContext) DoStatusRequest(tunnel *Tunnel) error {
 		// the request but the client failed to receive the response.
 		putBackStatusRequestPayload(statusPayloadInfo)
 
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	confirmStatusRequestPayload(statusPayloadInfo)
@@ -385,7 +379,12 @@ func (serverContext *ServerContext) getStatusParams(isTunneled bool) requestJSON
 	// TODO: base64 encoding of padding means the padding size is not exactly
 	// [0, PADDING_MAX_BYTES].
 
-	randomPadding := MakeSecureRandomPadding(0, PSIPHON_API_STATUS_REQUEST_PADDING_MAX_BYTES)
+	randomPadding, err := common.MakeSecureRandomPadding(0, PSIPHON_API_STATUS_REQUEST_PADDING_MAX_BYTES)
+	if err != nil {
+		NoticeAlert("MakeSecureRandomPadding failed: %s", err)
+		// Proceed without random padding
+		randomPadding = make([]byte, 0)
+	}
 	params["padding"] = base64.StdEncoding.EncodeToString(randomPadding)
 
 	// Legacy clients set "connected" to "0" when disconnecting, and this value
@@ -422,7 +421,7 @@ func makeStatusRequestPayload(
 		PSIPHON_API_TUNNEL_STATS_MAX_COUNT)
 	if err != nil {
 		NoticeAlert(
-			"TakeOutUnreportedTunnelStats failed: %s", ContextError(err))
+			"TakeOutUnreportedTunnelStats failed: %s", common.ContextError(err))
 		tunnelStats = nil
 		// Proceed with transferStats only
 	}
@@ -452,7 +451,7 @@ func makeStatusRequestPayload(
 		// Send the transfer stats and tunnel stats later
 		putBackStatusRequestPayload(payloadInfo)
 
-		return nil, nil, ContextError(err)
+		return nil, nil, common.ContextError(err)
 	}
 
 	return jsonPayload, payloadInfo, nil
@@ -466,7 +465,7 @@ func putBackStatusRequestPayload(payloadInfo *statusRequestPayloadInfo) {
 		// These tunnel stats records won't be resent under after a
 		// datastore re-initialization.
 		NoticeAlert(
-			"PutBackUnreportedTunnelStats failed: %s", ContextError(err))
+			"PutBackUnreportedTunnelStats failed: %s", common.ContextError(err))
 	}
 }
 
@@ -475,7 +474,7 @@ func confirmStatusRequestPayload(payloadInfo *statusRequestPayloadInfo) {
 	if err != nil {
 		// These tunnel stats records may be resent.
 		NoticeAlert(
-			"ClearReportedTunnelStats failed: %s", ContextError(err))
+			"ClearReportedTunnelStats failed: %s", common.ContextError(err))
 	}
 }
 
@@ -507,7 +506,7 @@ func (serverContext *ServerContext) doUntunneledStatusRequest(
 
 	certificate, err := DecodeCertificate(tunnel.serverEntry.WebServerCertificate)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	timeout := time.Duration(*tunnel.config.PsiphonApiServerTimeoutSeconds) * time.Second
@@ -532,12 +531,12 @@ func (serverContext *ServerContext) doUntunneledStatusRequest(
 		url,
 		timeout)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	statusPayload, statusPayloadInfo, err := makeStatusRequestPayload(tunnel.serverEntry.IpAddress)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	bodyType := "application/json"
@@ -556,7 +555,7 @@ func (serverContext *ServerContext) doUntunneledStatusRequest(
 		putBackStatusRequestPayload(statusPayloadInfo)
 
 		// Trim this error since it may include long URLs
-		return ContextError(TrimError(err))
+		return common.ContextError(TrimError(err))
 	}
 	confirmStatusRequestPayload(statusPayloadInfo)
 	response.Body.Close()
@@ -634,7 +633,7 @@ func RecordTunnelStats(
 
 	tunnelStatsJson, err := json.Marshal(tunnelStats)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	return StoreTunnelStats(tunnelStatsJson)
@@ -659,13 +658,13 @@ func (serverContext *ServerContext) DoClientVerificationRequest(
 
 		request, err := makeSSHAPIRequestPayload(params)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 
 		response, err = serverContext.tunnel.SendAPIRequest(
-			SERVER_API_CLIENT_VERIFICATION_REQUEST_NAME, request)
+			common.PSIPHON_API_CLIENT_VERIFICATION_REQUEST_NAME, request)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 
 	} else {
@@ -676,7 +675,7 @@ func (serverContext *ServerContext) DoClientVerificationRequest(
 			"application/json",
 			bytes.NewReader([]byte(verificationPayload)))
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 	}
 
@@ -691,7 +690,7 @@ func (serverContext *ServerContext) DoClientVerificationRequest(
 
 	err = json.Unmarshal(response, &clientVerificationResponse)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	if clientVerificationResponse.ClientVerificationTTLSeconds > 0 {
@@ -717,12 +716,12 @@ func (serverContext *ServerContext) doGetRequest(
 	}
 	if err != nil {
 		// Trim this error since it may include long URLs
-		return nil, ContextError(TrimError(err))
+		return nil, common.ContextError(TrimError(err))
 	}
 	defer response.Body.Close()
 	body, err := ioutil.ReadAll(response.Body)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	return body, nil
 }
@@ -738,12 +737,12 @@ func (serverContext *ServerContext) doPostRequest(
 	}
 	if err != nil {
 		// Trim this error since it may include long URLs
-		return nil, ContextError(TrimError(err))
+		return nil, common.ContextError(TrimError(err))
 	}
 	defer response.Body.Close()
 	responseBody, err = ioutil.ReadAll(response.Body)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	return responseBody, nil
 }
@@ -814,7 +813,7 @@ func (serverContext *ServerContext) getBaseParams() requestJSONObject {
 	// a precise handshake request server timestamp, is truncated
 	// to hour granularity to avoid introducing a reconstructable
 	// cross-session user trace into server logs.
-	localServerEntryTimestamp := TruncateTimestampToHour(tunnel.serverEntry.LocalTimestamp)
+	localServerEntryTimestamp := common.TruncateTimestampToHour(tunnel.serverEntry.LocalTimestamp)
 	if localServerEntryTimestamp != "" {
 		params["server_entry_timestamp"] = localServerEntryTimestamp
 	}
@@ -826,7 +825,7 @@ func (serverContext *ServerContext) getBaseParams() requestJSONObject {
 func makeSSHAPIRequestPayload(params requestJSONObject) ([]byte, error) {
 	jsonPayload, err := json.Marshal(params)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	return jsonPayload, nil
 }
@@ -890,7 +889,7 @@ func makeRequestUrl(tunnel *Tunnel, port, path string, params requestJSONObject)
 func makePsiphonHttpsClient(tunnel *Tunnel) (httpsClient *http.Client, err error) {
 	certificate, err := DecodeCertificate(tunnel.serverEntry.WebServerCertificate)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
 		// TODO: check tunnel.isClosed, and apply TUNNEL_PORT_FORWARD_DIAL_TIMEOUT as in Tunnel.Dial?

+ 12 - 61
psiphon/serverEntry.go

@@ -27,41 +27,10 @@ import (
 	"fmt"
 	"net"
 	"strings"
-)
-
-const (
-	TUNNEL_PROTOCOL_SSH                  = "SSH"
-	TUNNEL_PROTOCOL_OBFUSCATED_SSH       = "OSSH"
-	TUNNEL_PROTOCOL_UNFRONTED_MEEK       = "UNFRONTED-MEEK-OSSH"
-	TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS = "UNFRONTED-MEEK-HTTPS-OSSH"
-	TUNNEL_PROTOCOL_FRONTED_MEEK         = "FRONTED-MEEK-OSSH"
-	TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP    = "FRONTED-MEEK-HTTP-OSSH"
 
-	SERVER_ENTRY_SOURCE_EMBEDDED  = "EMBEDDED"
-	SERVER_ENTRY_SOURCE_REMOTE    = "REMOTE"
-	SERVER_ENTRY_SOURCE_DISCOVERY = "DISCOVERY"
-	SERVER_ENTRY_SOURCE_TARGET    = "TARGET"
-
-	CAPABILITY_SSH_API_REQUESTS            = "ssh-api-requests"
-	CAPABILITY_UNTUNNELED_WEB_API_REQUESTS = "handshake"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
-var SupportedTunnelProtocols = []string{
-	TUNNEL_PROTOCOL_FRONTED_MEEK,
-	TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP,
-	TUNNEL_PROTOCOL_UNFRONTED_MEEK,
-	TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
-	TUNNEL_PROTOCOL_OBFUSCATED_SSH,
-	TUNNEL_PROTOCOL_SSH,
-}
-
-var SupportedServerEntrySources = []string{
-	SERVER_ENTRY_SOURCE_EMBEDDED,
-	SERVER_ENTRY_SOURCE_REMOTE,
-	SERVER_ENTRY_SOURCE_DISCOVERY,
-	SERVER_ENTRY_SOURCE_TARGET,
-}
-
 // ServerEntry represents a Psiphon server. It contains information
 // about how to establish a tunnel connection to the server through
 // several protocols. Server entries are JSON records downloaded from
@@ -96,24 +65,6 @@ type ServerEntry struct {
 	LocalTimestamp string `json:"localTimestamp"`
 }
 
-func TunnelProtocolUsesSSH(protocol string) bool {
-	return true
-}
-
-func TunnelProtocolUsesObfuscatedSSH(protocol string) bool {
-	return protocol != TUNNEL_PROTOCOL_SSH
-}
-
-func TunnelProtocolUsesMeekHTTP(protocol string) bool {
-	return protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK ||
-		protocol == TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP
-}
-
-func TunnelProtocolUsesMeekHTTPS(protocol string) bool {
-	return protocol == TUNNEL_PROTOCOL_FRONTED_MEEK ||
-		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS
-}
-
 // GetCapability returns the server capability corresponding
 // to the protocol.
 func GetCapability(protocol string) string {
@@ -124,14 +75,14 @@ func GetCapability(protocol string) string {
 // the necessary capability to support the specified tunnel protocol.
 func (serverEntry *ServerEntry) SupportsProtocol(protocol string) bool {
 	requiredCapability := GetCapability(protocol)
-	return Contains(serverEntry.Capabilities, requiredCapability)
+	return common.Contains(serverEntry.Capabilities, requiredCapability)
 }
 
 // GetSupportedProtocols returns a list of tunnel protocols supported
 // by the ServerEntry's capabilities.
 func (serverEntry *ServerEntry) GetSupportedProtocols() []string {
 	supportedProtocols := make([]string, 0)
-	for _, protocol := range SupportedTunnelProtocols {
+	for _, protocol := range common.SupportedTunnelProtocols {
 		if serverEntry.SupportsProtocol(protocol) {
 			supportedProtocols = append(supportedProtocols, protocol)
 		}
@@ -163,16 +114,16 @@ func (serverEntry *ServerEntry) DisableImpairedProtocols(impairedProtocols []str
 // SupportsSSHAPIRequests returns true when the server supports
 // SSH API requests.
 func (serverEntry *ServerEntry) SupportsSSHAPIRequests() bool {
-	return Contains(serverEntry.Capabilities, CAPABILITY_SSH_API_REQUESTS)
+	return common.Contains(serverEntry.Capabilities, common.CAPABILITY_SSH_API_REQUESTS)
 }
 
 func (serverEntry *ServerEntry) GetUntunneledWebRequestPorts() []string {
 	ports := make([]string, 0)
-	if Contains(serverEntry.Capabilities, CAPABILITY_UNTUNNELED_WEB_API_REQUESTS) {
+	if common.Contains(serverEntry.Capabilities, common.CAPABILITY_UNTUNNELED_WEB_API_REQUESTS) {
 		// Server-side configuration quirk: there's a port forward from
 		// port 443 to the web server, which we can try, except on servers
 		// running FRONTED_MEEK, which listens on port 443.
-		if !serverEntry.SupportsProtocol(TUNNEL_PROTOCOL_FRONTED_MEEK) {
+		if !serverEntry.SupportsProtocol(common.TUNNEL_PROTOCOL_FRONTED_MEEK) {
 			ports = append(ports, "443")
 		}
 		ports = append(ports, serverEntry.WebServerPort)
@@ -185,7 +136,7 @@ func (serverEntry *ServerEntry) GetUntunneledWebRequestPorts() []string {
 func EncodeServerEntry(serverEntry *ServerEntry) (string, error) {
 	serverEntryContents, err := json.Marshal(serverEntry)
 	if err != nil {
-		return "", ContextError(err)
+		return "", common.ContextError(err)
 	}
 
 	return hex.EncodeToString([]byte(fmt.Sprintf(
@@ -213,19 +164,19 @@ func DecodeServerEntry(
 
 	hexDecodedServerEntry, err := hex.DecodeString(encodedServerEntry)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// Skip past legacy format (4 space delimited fields) and just parse the JSON config
 	fields := bytes.SplitN(hexDecodedServerEntry, []byte(" "), 5)
 	if len(fields) != 5 {
-		return nil, ContextError(errors.New("invalid encoded server entry"))
+		return nil, common.ContextError(errors.New("invalid encoded server entry"))
 	}
 
 	serverEntry = new(ServerEntry)
 	err = json.Unmarshal(fields[4], &serverEntry)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// NOTE: if the source JSON happens to have values in these fields, they get clobbered.
@@ -247,7 +198,7 @@ func ValidateServerEntry(serverEntry *ServerEntry) error {
 		// Some callers skip invalid server entries without propagating
 		// the error mesage, so issue a notice.
 		NoticeAlert(errMsg)
-		return ContextError(errors.New(errMsg))
+		return common.ContextError(errors.New(errMsg))
 	}
 	return nil
 }
@@ -269,7 +220,7 @@ func DecodeAndValidateServerEntryList(
 		// TODO: skip this entry and continue if can't decode?
 		serverEntry, err := DecodeServerEntry(encodedServerEntry, timestamp, serverEntrySource)
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 
 		if ValidateServerEntry(serverEntry) != nil {

+ 4 - 2
psiphon/serverEntry_test.go

@@ -22,6 +22,8 @@ package psiphon
 import (
 	"encoding/hex"
 	"testing"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 const (
@@ -41,7 +43,7 @@ func TestDecodeAndValidateServerEntryList(t *testing.T) {
 		hex.EncodeToString([]byte(_INVALID_MALFORMED_IP_ADDRESS_SERVER_ENTRY))
 
 	serverEntries, err := DecodeAndValidateServerEntryList(
-		testEncodedServerEntryList, GetCurrentTimestamp(), SERVER_ENTRY_SOURCE_EMBEDDED)
+		testEncodedServerEntryList, common.GetCurrentTimestamp(), common.SERVER_ENTRY_SOURCE_EMBEDDED)
 	if err != nil {
 		t.Error(err.Error())
 		t.FailNow()
@@ -64,7 +66,7 @@ func TestInvalidServerEntries(t *testing.T) {
 	for _, testCase := range testCases {
 		encodedServerEntry := hex.EncodeToString([]byte(testCase))
 		serverEntry, err := DecodeServerEntry(
-			encodedServerEntry, GetCurrentTimestamp(), SERVER_ENTRY_SOURCE_EMBEDDED)
+			encodedServerEntry, common.GetCurrentTimestamp(), common.SERVER_ENTRY_SOURCE_EMBEDDED)
 		if err != nil {
 			t.Error(err.Error())
 		}

+ 7 - 6
psiphon/socksProxy.go

@@ -25,6 +25,7 @@ import (
 	"sync"
 
 	socks "github.com/Psiphon-Inc/goptlib"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // SocksProxy is a SOCKS server that accepts local host connections
@@ -35,7 +36,7 @@ type SocksProxy struct {
 	tunneler               Tunneler
 	listener               *socks.SocksListener
 	serveWaitGroup         *sync.WaitGroup
-	openConns              *Conns
+	openConns              *common.Conns
 	stopListeningBroadcast chan struct{}
 }
 
@@ -55,13 +56,13 @@ func NewSocksProxy(
 		if IsAddressInUseError(err) {
 			NoticeSocksProxyPortInUse(config.LocalSocksProxyPort)
 		}
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	proxy = &SocksProxy{
 		tunneler:               tunneler,
 		listener:               listener,
 		serveWaitGroup:         new(sync.WaitGroup),
-		openConns:              new(Conns),
+		openConns:              new(common.Conns),
 		stopListeningBroadcast: make(chan struct{}),
 	}
 	proxy.serveWaitGroup.Add(1)
@@ -88,12 +89,12 @@ func (proxy *SocksProxy) socksConnectionHandler(localConn *socks.SocksConn) (err
 	// open connection for data which will never arrive.
 	remoteConn, err := proxy.tunneler.Dial(localConn.Req.Target, false, localConn)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	defer remoteConn.Close()
 	err = localConn.Grant(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 0})
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	LocalProxyRelay(_SOCKS_PROXY_TYPE, localConn, remoteConn)
 	return nil
@@ -127,7 +128,7 @@ loop:
 		go func() {
 			err := proxy.socksConnectionHandler(socksConnection)
 			if err != nil {
-				NoticeLocalProxyError(_SOCKS_PROXY_TYPE, ContextError(err))
+				NoticeLocalProxyError(_SOCKS_PROXY_TYPE, common.ContextError(err))
 			}
 		}()
 	}

+ 18 - 16
psiphon/splitTunnel.go

@@ -34,6 +34,8 @@ import (
 	"strings"
 	"sync"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // SplitTunnelClassifier determines whether a network destination
@@ -219,12 +221,12 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
 	url := fmt.Sprintf(classifier.fetchRoutesUrlFormat, tunnel.serverContext.clientRegion)
 	request, err := http.NewRequest("GET", url, nil)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	etag, err := GetSplitTunnelRoutesETag(tunnel.serverContext.clientRegion)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	if etag != "" {
 		request.Header.Add("If-None-Match", etag)
@@ -255,7 +257,7 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
 		err = fmt.Errorf("unexpected response status code: %d", response.StatusCode)
 	}
 	if err != nil {
-		NoticeAlert("failed to request split tunnel routes package: %s", ContextError(err))
+		NoticeAlert("failed to request split tunnel routes package: %s", common.ContextError(err))
 		useCachedRoutes = true
 	}
 
@@ -270,7 +272,7 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
 	if !useCachedRoutes {
 		routesDataPackage, err = ioutil.ReadAll(response.Body)
 		if err != nil {
-			NoticeAlert("failed to download split tunnel routes package: %s", ContextError(err))
+			NoticeAlert("failed to download split tunnel routes package: %s", common.ContextError(err))
 			useCachedRoutes = true
 		}
 	}
@@ -280,7 +282,7 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
 		encodedRoutesData, err = ReadAuthenticatedDataPackage(
 			routesDataPackage, classifier.routesSignaturePublicKey)
 		if err != nil {
-			NoticeAlert("failed to read split tunnel routes package: %s", ContextError(err))
+			NoticeAlert("failed to read split tunnel routes package: %s", common.ContextError(err))
 			useCachedRoutes = true
 		}
 	}
@@ -289,7 +291,7 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
 	if !useCachedRoutes {
 		compressedRoutesData, err = base64.StdEncoding.DecodeString(encodedRoutesData)
 		if err != nil {
-			NoticeAlert("failed to decode split tunnel routes: %s", ContextError(err))
+			NoticeAlert("failed to decode split tunnel routes: %s", common.ContextError(err))
 			useCachedRoutes = true
 		}
 	}
@@ -301,7 +303,7 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
 			zlibReader.Close()
 		}
 		if err != nil {
-			NoticeAlert("failed to decompress split tunnel routes: %s", ContextError(err))
+			NoticeAlert("failed to decompress split tunnel routes: %s", common.ContextError(err))
 			useCachedRoutes = true
 		}
 	}
@@ -311,7 +313,7 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
 		if etag != "" {
 			err := SetSplitTunnelRoutes(tunnel.serverContext.clientRegion, etag, routesData)
 			if err != nil {
-				NoticeAlert("failed to cache split tunnel routes: %s", ContextError(err))
+				NoticeAlert("failed to cache split tunnel routes: %s", common.ContextError(err))
 				// Proceed with fetched data, even when we can't cache it
 			}
 		}
@@ -320,10 +322,10 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
 	if useCachedRoutes {
 		routesData, err = GetSplitTunnelRoutesData(tunnel.serverContext.clientRegion)
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		if routesData == nil {
-			return nil, ContextError(errors.New("no cached routes"))
+			return nil, common.ContextError(errors.New("no cached routes"))
 		}
 	}
 
@@ -346,7 +348,7 @@ func (classifier *SplitTunnelClassifier) installRoutes(routesData []byte) (err e
 
 	classifier.routes, err = NewNetworkList(routesData)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	classifier.isRoutesSet = true
@@ -392,7 +394,7 @@ func NewNetworkList(routesData []byte) (networkList, error) {
 		list = append(list, net.IPNet{IP: ip.Mask(mask), Mask: mask})
 	}
 	if len(list) == 0 {
-		return nil, ContextError(errors.New("Routes data contains no networks"))
+		return nil, common.ContextError(errors.New("Routes data contains no networks"))
 	}
 
 	// Sort data for fast lookup
@@ -480,7 +482,7 @@ func tunneledLookupIP(
 	// dnsServerAddress must be an IP address
 	ipAddr = net.ParseIP(dnsServerAddress)
 	if ipAddr == nil {
-		return nil, 0, ContextError(errors.New("invalid IP address"))
+		return nil, 0, common.ContextError(errors.New("invalid IP address"))
 	}
 
 	// Dial's alwaysTunnel is set to true to ensure this connection
@@ -490,15 +492,15 @@ func tunneledLookupIP(
 	conn, err := dnsTunneler.Dial(fmt.Sprintf(
 		"%s:%d", dnsServerAddress, DNS_PORT), true, nil)
 	if err != nil {
-		return nil, 0, ContextError(err)
+		return nil, 0, common.ContextError(err)
 	}
 
 	ipAddrs, ttls, err := ResolveIP(host, conn)
 	if err != nil {
-		return nil, 0, ContextError(err)
+		return nil, 0, common.ContextError(err)
 	}
 	if len(ipAddrs) < 1 {
-		return nil, 0, ContextError(errors.New("no IP address"))
+		return nil, 0, common.ContextError(errors.New("no IP address"))
 	}
 
 	return ipAddrs[0], ttls[0], nil

+ 9 - 7
psiphon/tlsDialer.go

@@ -77,6 +77,8 @@ import (
 	"errors"
 	"net"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // CustomTLSConfig contains parameters to determine the behavior
@@ -159,13 +161,13 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (net.Conn, err
 
 	rawConn, err := config.Dial(network, dialAddr)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	hostname, _, err := net.SplitHostPort(dialAddr)
 	if err != nil {
 		rawConn.Close()
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	tlsConfig := &tls.Config{}
@@ -199,7 +201,7 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (net.Conn, err
 		conn, err = newOpenSSLConn(rawConn, hostname, config)
 		if err != nil {
 			rawConn.Close()
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 	} else {
 		conn = tls.Client(rawConn, tlsConfig)
@@ -233,7 +235,7 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (net.Conn, err
 
 	if err != nil {
 		rawConn.Close()
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	return conn, nil
@@ -242,10 +244,10 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (net.Conn, err
 func verifyLegacyCertificate(conn *tls.Conn, expectedCertificate *x509.Certificate) error {
 	certs := conn.ConnectionState().PeerCertificates
 	if len(certs) < 1 {
-		return ContextError(errors.New("no certificate to verify"))
+		return common.ContextError(errors.New("no certificate to verify"))
 	}
 	if !bytes.Equal(certs[0].Raw, expectedCertificate.Raw) {
-		return ContextError(errors.New("unexpected certificate"))
+		return common.ContextError(errors.New("unexpected certificate"))
 	}
 	return nil
 }
@@ -269,7 +271,7 @@ func verifyServerCerts(conn *tls.Conn, hostname string, config *tls.Config) erro
 
 	_, err := certs[0].Verify(opts)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 	return nil
 }

+ 66 - 48
psiphon/tunnel.go

@@ -33,6 +33,7 @@ import (
 	"time"
 
 	regen "github.com/Psiphon-Inc/goregen"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/transferstats"
 	"golang.org/x/crypto/ssh"
 )
@@ -113,27 +114,30 @@ func EstablishTunnel(
 	config *Config,
 	untunneledDialConfig *DialConfig,
 	sessionId string,
-	pendingConns *Conns,
+	pendingConns *common.Conns,
 	serverEntry *ServerEntry,
 	tunnelOwner TunnelOwner) (tunnel *Tunnel, err error) {
 
 	selectedProtocol, err := selectProtocol(config, serverEntry)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// Build transport layers and establish SSH connection
-	conn, sshClient, dialStats, err := dialSsh(
+	dialConn, sshClient, dialStats, err := dialSsh(
 		config, pendingConns, serverEntry, selectedProtocol, sessionId)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
+	// Apply throttling (if configured)
+	conn := common.NewThrottledConn(dialConn, config.RateLimits)
+
 	// Cleanup on error
 	defer func() {
 		if err != nil {
 			sshClient.Close()
-			conn.Close()
+			dialConn.Close()
 		}
 	}()
 
@@ -165,7 +169,7 @@ func EstablishTunnel(
 		NoticeInfo("starting server context for %s", tunnel.serverEntry.IpAddress)
 		tunnel.serverContext, err = NewServerContext(tunnel, sessionId)
 		if err != nil {
-			return nil, ContextError(
+			return nil, common.ContextError(
 				fmt.Errorf("error starting server context for %s: %s",
 					tunnel.serverEntry.IpAddress, err))
 		}
@@ -174,7 +178,7 @@ func EstablishTunnel(
 	tunnel.startTime = time.Now()
 
 	// Now that network operations are complete, cancel interruptibility
-	pendingConns.Remove(conn)
+	pendingConns.Remove(dialConn)
 
 	// Spawn the operateTunnel goroutine, which monitors the tunnel and handles periodic stats updates.
 	tunnel.operateWaitGroup.Add(1)
@@ -236,18 +240,18 @@ func (tunnel *Tunnel) SendAPIRequest(
 	name string, requestPayload []byte) ([]byte, error) {
 
 	if tunnel.IsClosed() {
-		return nil, ContextError(errors.New("tunnel is closed"))
+		return nil, common.ContextError(errors.New("tunnel is closed"))
 	}
 
 	ok, responsePayload, err := tunnel.sshClient.Conn.SendRequest(
 		name, true, requestPayload)
 
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	if !ok {
-		return nil, ContextError(errors.New("API request rejected"))
+		return nil, common.ContextError(errors.New("API request rejected"))
 	}
 
 	return responsePayload, nil
@@ -259,7 +263,7 @@ func (tunnel *Tunnel) Dial(
 	remoteAddr string, alwaysTunnel bool, downstreamConn net.Conn) (conn net.Conn, err error) {
 
 	if tunnel.IsClosed() {
-		return nil, ContextError(errors.New("tunnel is closed"))
+		return nil, common.ContextError(errors.New("tunnel is closed"))
 	}
 
 	type tunnelDialResult struct {
@@ -284,7 +288,7 @@ func (tunnel *Tunnel) Dial(
 		case tunnel.signalPortForwardFailure <- *new(struct{}):
 		default:
 		}
-		return nil, ContextError(result.err)
+		return nil, common.ContextError(result.err)
 	}
 
 	conn = &TunneledConn{
@@ -372,7 +376,7 @@ func selectProtocol(config *Config, serverEntry *ServerEntry) (selectedProtocol
 	// for now, the code is simply assuming that MEEK capabilities imply OSSH capability.
 	if config.TunnelProtocol != "" {
 		if !serverEntry.SupportsProtocol(config.TunnelProtocol) {
-			return "", ContextError(fmt.Errorf("server does not have required capability"))
+			return "", common.ContextError(fmt.Errorf("server does not have required capability"))
 		}
 		selectedProtocol = config.TunnelProtocol
 	} else {
@@ -384,12 +388,12 @@ func selectProtocol(config *Config, serverEntry *ServerEntry) (selectedProtocol
 
 		candidateProtocols := serverEntry.GetSupportedProtocols()
 		if len(candidateProtocols) == 0 {
-			return "", ContextError(fmt.Errorf("server does not have any supported capabilities"))
+			return "", common.ContextError(fmt.Errorf("server does not have any supported capabilities"))
 		}
 
-		index, err := MakeSecureRandomInt(len(candidateProtocols))
+		index, err := common.MakeSecureRandomInt(len(candidateProtocols))
 		if err != nil {
-			return "", ContextError(err)
+			return "", common.ContextError(err)
 		}
 		selectedProtocol = candidateProtocols[index]
 	}
@@ -407,7 +411,7 @@ func selectFrontingParameters(
 
 		frontingAddress, err = regen.Generate(serverEntry.MeekFrontingAddressesRegex)
 		if err != nil {
-			return "", "", ContextError(err)
+			return "", "", common.ContextError(err)
 		}
 	} else {
 
@@ -415,19 +419,19 @@ func selectFrontingParameters(
 		// fronting-capable servers.
 
 		if len(serverEntry.MeekFrontingAddresses) == 0 {
-			return "", "", ContextError(errors.New("MeekFrontingAddresses is empty"))
+			return "", "", common.ContextError(errors.New("MeekFrontingAddresses is empty"))
 		}
-		index, err := MakeSecureRandomInt(len(serverEntry.MeekFrontingAddresses))
+		index, err := common.MakeSecureRandomInt(len(serverEntry.MeekFrontingAddresses))
 		if err != nil {
-			return "", "", ContextError(err)
+			return "", "", common.ContextError(err)
 		}
 		frontingAddress = serverEntry.MeekFrontingAddresses[index]
 	}
 
 	if len(serverEntry.MeekFrontingHosts) > 0 {
-		index, err := MakeSecureRandomInt(len(serverEntry.MeekFrontingHosts))
+		index, err := common.MakeSecureRandomInt(len(serverEntry.MeekFrontingHosts))
 		if err != nil {
-			return "", "", ContextError(err)
+			return "", "", common.ContextError(err)
 		}
 		frontingHost = serverEntry.MeekFrontingHosts[index]
 	} else {
@@ -455,10 +459,10 @@ func initMeekConfig(
 	transformedHostName := false
 
 	switch selectedProtocol {
-	case TUNNEL_PROTOCOL_FRONTED_MEEK:
+	case common.TUNNEL_PROTOCOL_FRONTED_MEEK:
 		frontingAddress, frontingHost, err := selectFrontingParameters(serverEntry)
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		dialAddress = fmt.Sprintf("%s:443", frontingAddress)
 		useHTTPS = true
@@ -468,15 +472,15 @@ func initMeekConfig(
 		}
 		hostHeader = frontingHost
 
-	case TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP:
+	case common.TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP:
 		frontingAddress, frontingHost, err := selectFrontingParameters(serverEntry)
 		if err != nil {
-			return nil, ContextError(err)
+			return nil, common.ContextError(err)
 		}
 		dialAddress = fmt.Sprintf("%s:80", frontingAddress)
 		hostHeader = frontingHost
 
-	case TUNNEL_PROTOCOL_UNFRONTED_MEEK:
+	case common.TUNNEL_PROTOCOL_UNFRONTED_MEEK:
 		dialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.MeekServerPort)
 		hostname := serverEntry.IpAddress
 		hostname, transformedHostName = config.HostNameTransformer.TransformHostName(hostname)
@@ -486,7 +490,7 @@ func initMeekConfig(
 			hostHeader = fmt.Sprintf("%s:%d", hostname, serverEntry.MeekServerPort)
 		}
 
-	case TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS:
+	case common.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS:
 		dialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.MeekServerPort)
 		useHTTPS = true
 		SNIServerName, transformedHostName =
@@ -498,7 +502,7 @@ func initMeekConfig(
 		}
 
 	default:
-		return nil, ContextError(errors.New("unexpected selectedProtocol"))
+		return nil, common.ContextError(errors.New("unexpected selectedProtocol"))
 	}
 
 	// The unnderlying TLS will automatically disable SNI for IP address server name
@@ -524,7 +528,7 @@ func initMeekConfig(
 // When additional dial configuration is used, DialStats are recorded and returned.
 func dialSsh(
 	config *Config,
-	pendingConns *Conns,
+	pendingConns *common.Conns,
 	serverEntry *ServerEntry,
 	selectedProtocol,
 	sessionId string) (net.Conn, *ssh.Client, *TunnelDialStats, error) {
@@ -538,18 +542,18 @@ func dialSsh(
 	var err error
 
 	switch selectedProtocol {
-	case TUNNEL_PROTOCOL_OBFUSCATED_SSH:
+	case common.TUNNEL_PROTOCOL_OBFUSCATED_SSH:
 		useObfuscatedSsh = true
 		directTCPDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedPort)
 
-	case TUNNEL_PROTOCOL_SSH:
+	case common.TUNNEL_PROTOCOL_SSH:
 		directTCPDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshPort)
 
 	default:
 		useObfuscatedSsh = true
 		meekConfig, err = initMeekConfig(config, serverEntry, selectedProtocol, sessionId)
 		if err != nil {
-			return nil, nil, nil, ContextError(err)
+			return nil, nil, nil, common.ContextError(err)
 		}
 	}
 
@@ -589,12 +593,12 @@ func dialSsh(
 	if meekConfig != nil {
 		conn, err = DialMeek(meekConfig, dialConfig)
 		if err != nil {
-			return nil, nil, nil, ContextError(err)
+			return nil, nil, nil, common.ContextError(err)
 		}
 	} else {
 		conn, err = DialTCP(directTCPDialAddress, dialConfig)
 		if err != nil {
-			return nil, nil, nil, ContextError(err)
+			return nil, nil, nil, common.ContextError(err)
 		}
 	}
 
@@ -612,19 +616,19 @@ func dialSsh(
 		sshConn, err = NewObfuscatedSshConn(
 			OBFUSCATION_CONN_MODE_CLIENT, conn, serverEntry.SshObfuscatedKey)
 		if err != nil {
-			return nil, nil, nil, ContextError(err)
+			return nil, nil, nil, common.ContextError(err)
 		}
 	}
 
 	// Now establish the SSH session over the conn transport
 	expectedPublicKey, err := base64.StdEncoding.DecodeString(serverEntry.SshHostKey)
 	if err != nil {
-		return nil, nil, nil, ContextError(err)
+		return nil, nil, nil, common.ContextError(err)
 	}
 	sshCertChecker := &ssh.CertChecker{
 		HostKeyFallback: func(addr string, remote net.Addr, publicKey ssh.PublicKey) error {
 			if !bytes.Equal(expectedPublicKey, publicKey.Marshal()) {
-				return ContextError(errors.New("unexpected host public key"))
+				return common.ContextError(errors.New("unexpected host public key"))
 			}
 			return nil
 		},
@@ -635,7 +639,7 @@ func dialSsh(
 			SshPassword string `json:"SshPassword"`
 		}{sessionId, serverEntry.SshPassword})
 	if err != nil {
-		return nil, nil, nil, ContextError(err)
+		return nil, nil, nil, common.ContextError(err)
 	}
 	sshClientConfig := &ssh.ClientConfig{
 		User: serverEntry.SshUsername,
@@ -681,7 +685,7 @@ func dialSsh(
 
 	result := <-resultChannel
 	if result.err != nil {
-		return nil, nil, nil, ContextError(result.err)
+		return nil, nil, nil, common.ContextError(result.err)
 	}
 
 	var dialStats *TunnelDialStats
@@ -719,6 +723,16 @@ func dialSsh(
 	return conn, result.sshClient, dialStats, nil
 }
 
+func makeRandomPeriod(min, max time.Duration) time.Duration {
+	period, err := common.MakeRandomPeriod(min, max)
+	if err != nil {
+		NoticeAlert("MakeRandomPeriod failed: %s", err)
+		// Proceed without random period
+		period = max
+	}
+	return period
+}
+
 // operateTunnel monitors the health of the tunnel and performs
 // periodic work.
 //
@@ -781,7 +795,7 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 	// from a range, to make the resulting traffic less fingerprintable,
 	// Note: not using Tickers since these are not fixed time periods.
 	nextStatusRequestPeriod := func() time.Duration {
-		return MakeRandomPeriod(
+		return makeRandomPeriod(
 			PSIPHON_API_STATUS_REQUEST_PERIOD_MIN,
 			PSIPHON_API_STATUS_REQUEST_PERIOD_MAX)
 	}
@@ -800,13 +814,13 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 	unreported := CountUnreportedTunnelStats()
 	if unreported > 0 {
 		NoticeInfo("Unreported tunnel stats: %d", unreported)
-		statsTimer.Reset(MakeRandomPeriod(
+		statsTimer.Reset(makeRandomPeriod(
 			PSIPHON_API_STATUS_REQUEST_SHORT_PERIOD_MIN,
 			PSIPHON_API_STATUS_REQUEST_SHORT_PERIOD_MAX))
 	}
 
 	nextSshKeepAlivePeriod := func() time.Duration {
-		return MakeRandomPeriod(
+		return makeRandomPeriod(
 			TUNNEL_SSH_KEEP_ALIVE_PERIOD_MIN,
 			TUNNEL_SSH_KEEP_ALIVE_PERIOD_MAX)
 	}
@@ -975,7 +989,7 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 			totalSent,
 			totalReceived)
 		if err != nil {
-			NoticeAlert("RecordTunnelStats failed: %s", ContextError(err))
+			NoticeAlert("RecordTunnelStats failed: %s", common.ContextError(err))
 		}
 	}
 
@@ -1032,9 +1046,13 @@ func sendSshKeepAlive(
 
 	go func() {
 		// Random padding to frustrate fingerprinting
-		_, _, err := sshClient.SendRequest(
-			"keepalive@openssh.com", true,
-			MakeSecureRandomPadding(0, TUNNEL_SSH_KEEP_ALIVE_PAYLOAD_MAX_BYTES))
+		randomPadding, err := common.MakeSecureRandomPadding(0, TUNNEL_SSH_KEEP_ALIVE_PAYLOAD_MAX_BYTES)
+		if err != nil {
+			NoticeAlert("MakeSecureRandomPadding failed: %s", err)
+			// Proceed without random padding
+			randomPadding = make([]byte, 0)
+		}
+		_, _, err = sshClient.SendRequest("keepalive@openssh.com", true, randomPadding)
 		errChannel <- err
 	}()
 
@@ -1044,7 +1062,7 @@ func sendSshKeepAlive(
 		conn.Close()
 	}
 
-	return ContextError(err)
+	return common.ContextError(err)
 }
 
 // sendStats is a helper for sending session stats to the server.

+ 7 - 5
psiphon/upgradeDownload.go

@@ -24,6 +24,8 @@ import (
 	"net/http"
 	"os"
 	"strconv"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 // DownloadUpgrade performs a resumable download of client upgrade files.
@@ -80,7 +82,7 @@ func DownloadUpgrade(
 	if availableClientVersion == "" {
 		request, err := http.NewRequest("HEAD", requestUrl, nil)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 		response, err := httpClient.Do(request)
 		if err == nil && response.StatusCode != http.StatusOK {
@@ -88,13 +90,13 @@ func DownloadUpgrade(
 			err = fmt.Errorf("unexpected response status code: %d", response.StatusCode)
 		}
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 		defer response.Body.Close()
 
 		currentClientVersion, err := strconv.Atoi(config.ClientVersion)
 		if err != nil {
-			return ContextError(err)
+			return common.ContextError(err)
 		}
 
 		// Note: if the header is missing, Header.Get returns "" and then
@@ -134,12 +136,12 @@ func DownloadUpgrade(
 	NoticeClientUpgradeDownloadedBytes(n)
 
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	err = os.Rename(downloadFilename, config.UpgradeDownloadFilename)
 	if err != nil {
-		return ContextError(err)
+		return common.ContextError(err)
 	}
 
 	NoticeClientUpgradeDownloaded(config.UpgradeDownloadFilename)

+ 0 - 290
psiphon/upstreamproxy/upstreamproxy_test.go

@@ -1,290 +0,0 @@
-/*
- * Copyright (c) 2016, Psiphon Inc.
- * All rights reserved.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program.  If not, see <http://www.gnu.org/licenses/>.
- *
- */
-
-package upstreamproxy
-
-import (
-	"encoding/json"
-	"flag"
-	"fmt"
-	"io/ioutil"
-	"net/http"
-	"net/url"
-	"os"
-	"sync"
-	"testing"
-	"time"
-
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server"
-	"github.com/elazarl/goproxy"
-)
-
-// Note: upstreamproxy_test is redundant -- it doesn't test any cases not
-// covered by controller_test; and its code is largely copied from server_test
-// and controller_test. upstreamproxy_test exists so that coverage within the
-// upstreamproxy package can be measured and reported.
-
-func TestMain(m *testing.M) {
-	flag.Parse()
-	os.Remove(psiphon.DATA_STORE_FILENAME)
-	initUpstreamProxy()
-	psiphon.SetEmitDiagnosticNotices(true)
-	os.Exit(m.Run())
-}
-
-func TestSSHViaUpstreamProxy(t *testing.T) {
-	runServer(t, "SSH")
-}
-
-func TestOSSHViaUpstreamProxy(t *testing.T) {
-	runServer(t, "OSSH")
-}
-
-func TestUnfrontedMeekViaUpstreamProxy(t *testing.T) {
-	runServer(t, "UNFRONTED-MEEK-OSSH")
-}
-
-func TestUnfrontedMeekHTTPSViaUpstreamProxy(t *testing.T) {
-	runServer(t, "UNFRONTED-MEEK-HTTPS-OSSH")
-}
-
-func runServer(t *testing.T, tunnelProtocol string) {
-
-	// create a server
-
-	var err error
-	serverIPaddress := ""
-	for _, interfaceName := range []string{"eth0", "en0"} {
-		serverIPaddress, err = psiphon.GetInterfaceIPAddress(interfaceName)
-		if err == nil {
-			break
-		}
-	}
-	if err != nil {
-		t.Fatalf("error getting server IP address: %s", err)
-	}
-
-	serverConfigJSON, _, encodedServerEntry, err := server.GenerateConfig(
-		&server.GenerateConfigParams{
-			ServerIPAddress:      serverIPaddress,
-			EnableSSHAPIRequests: true,
-			WebServerPort:        8000,
-			TunnelProtocolPorts:  map[string]int{tunnelProtocol: 4000},
-		})
-	if err != nil {
-		t.Fatalf("error generating server config: %s", err)
-	}
-
-	// customize server config
-
-	var serverConfig interface{}
-	json.Unmarshal(serverConfigJSON, &serverConfig)
-	serverConfig.(map[string]interface{})["GeoIPDatabaseFilename"] = ""
-	serverConfig.(map[string]interface{})["PsinetDatabaseFilename"] = ""
-	serverConfig.(map[string]interface{})["TrafficRulesFilename"] = ""
-	serverConfigJSON, _ = json.Marshal(serverConfig)
-
-	// run server
-
-	serverWaitGroup := new(sync.WaitGroup)
-	serverWaitGroup.Add(1)
-	go func() {
-		defer serverWaitGroup.Done()
-		err := server.RunServices(serverConfigJSON)
-		if err != nil {
-			// TODO: wrong goroutine for t.FatalNow()
-			t.Fatalf("error running server: %s", err)
-		}
-	}()
-	defer func() {
-		p, _ := os.FindProcess(os.Getpid())
-		p.Signal(os.Interrupt)
-		serverWaitGroup.Wait()
-	}()
-
-	// connect to server with client
-
-	// TODO: currently, TargetServerEntry only works with one tunnel
-	numTunnels := 1
-	localHTTPProxyPort := 8081
-	establishTunnelPausePeriodSeconds := 1
-
-	// Note: calling LoadConfig ensures all *int config fields are initialized
-	clientConfigJSON := `
-    {
-        "ClientVersion" : "0",
-        "SponsorId" : "0",
-        "PropagationChannelId" : "0"
-    }`
-	clientConfig, _ := psiphon.LoadConfig([]byte(clientConfigJSON))
-
-	clientConfig.ConnectionWorkerPoolSize = numTunnels
-	clientConfig.TunnelPoolSize = numTunnels
-	clientConfig.DisableRemoteServerListFetcher = true
-	clientConfig.EstablishTunnelPausePeriodSeconds = &establishTunnelPausePeriodSeconds
-	clientConfig.TargetServerEntry = string(encodedServerEntry)
-	clientConfig.TunnelProtocol = tunnelProtocol
-	clientConfig.LocalHttpProxyPort = localHTTPProxyPort
-
-	clientConfig.UpstreamProxyUrl = upstreamProxyURL
-	clientConfig.UpstreamProxyCustomHeaders = upstreamProxyCustomHeaders
-
-	err = psiphon.InitDataStore(clientConfig)
-	if err != nil {
-		t.Fatalf("error initializing client datastore: %s", err)
-	}
-
-	controller, err := psiphon.NewController(clientConfig)
-	if err != nil {
-		t.Fatalf("error creating client controller: %s", err)
-	}
-
-	tunnelsEstablished := make(chan struct{}, 1)
-
-	psiphon.SetNoticeOutput(psiphon.NewNoticeReceiver(
-		func(notice []byte) {
-
-			fmt.Printf("%s\n", string(notice))
-
-			noticeType, payload, err := psiphon.GetNotice(notice)
-			if err != nil {
-				return
-			}
-
-			switch noticeType {
-			case "Tunnels":
-				count := int(payload["count"].(float64))
-				if count >= numTunnels {
-					select {
-					case tunnelsEstablished <- *new(struct{}):
-					default:
-					}
-				}
-			}
-		}))
-
-	controllerShutdownBroadcast := make(chan struct{})
-	controllerWaitGroup := new(sync.WaitGroup)
-	controllerWaitGroup.Add(1)
-	go func() {
-		defer controllerWaitGroup.Done()
-		controller.Run(controllerShutdownBroadcast)
-	}()
-	defer func() {
-		close(controllerShutdownBroadcast)
-		controllerWaitGroup.Wait()
-	}()
-
-	// Test: tunnels must be established within 30 seconds
-
-	establishTimeout := time.NewTimer(30 * time.Second)
-	select {
-	case <-tunnelsEstablished:
-	case <-establishTimeout.C:
-		t.Fatalf("tunnel establish timeout exceeded")
-	}
-
-	// Test: tunneled web site fetch
-
-	testUrl := "https://psiphon.ca"
-	roundTripTimeout := 30 * time.Second
-
-	proxyUrl, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", localHTTPProxyPort))
-	if err != nil {
-		t.Fatalf("error initializing proxied HTTP request: %s", err)
-	}
-
-	httpClient := &http.Client{
-		Transport: &http.Transport{
-			Proxy: http.ProxyURL(proxyUrl),
-		},
-		Timeout: roundTripTimeout,
-	}
-
-	response, err := httpClient.Get(testUrl)
-	if err != nil {
-		t.Fatalf("error sending proxied HTTP request: %s", err)
-	}
-
-	_, err = ioutil.ReadAll(response.Body)
-	if err != nil {
-		t.Fatalf("error reading proxied HTTP response: %s", err)
-	}
-	response.Body.Close()
-}
-
-const upstreamProxyURL = "http://127.0.0.1:2161"
-
-var upstreamProxyCustomHeaders = map[string][]string{"X-Test-Header-Name": []string{"test-header-value1", "test-header-value2"}}
-
-func hasExpectedCustomHeaders(h http.Header) bool {
-	for name, values := range upstreamProxyCustomHeaders {
-		if h[name] == nil {
-			return false
-		}
-		// Order may not be the same
-		for _, value := range values {
-			if !psiphon.Contains(h[name], value) {
-				return false
-			}
-		}
-	}
-	return true
-}
-
-func initUpstreamProxy() {
-	go func() {
-		proxy := goproxy.NewProxyHttpServer()
-
-		proxy.OnRequest().DoFunc(
-			func(r *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
-				if !hasExpectedCustomHeaders(r.Header) {
-					ctx.Logf("missing expected headers: %+v", ctx.Req.Header)
-					return nil, goproxy.NewResponse(r, goproxy.ContentTypeText, http.StatusUnauthorized, "")
-				}
-				return r, nil
-			})
-
-		proxy.OnRequest().HandleConnectFunc(
-			func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
-				// TODO: enable this check. Currently the headers aren't send because the
-				// following type assertion in upstreamproxy.newHTTP fails (but only in this
-				// test context, not in controller_test):
-				//   if upstreamProxyConfig, ok := forward.(*UpstreamProxyConfig); ok {
-				//       hp.customHeaders = upstreamProxyConfig.CustomHeaders
-				//   }
-				//
-				/*
-					if !hasExpectedCustomHeaders(ctx.Req.Header) {
-						ctx.Logf("missing expected headers: %+v", ctx.Req.Header)
-						return goproxy.RejectConnect, host
-					}
-				*/
-				return goproxy.OkConnect, host
-			})
-
-		err := http.ListenAndServe("127.0.0.1:2161", proxy)
-		if err != nil {
-			fmt.Printf("upstream proxy failed: %s", err)
-		}
-	}()
-
-	// TODO: wait until listener is active?
-}

+ 4 - 280
psiphon/utils.go

@@ -20,134 +20,26 @@
 package psiphon
 
 import (
-	"crypto/rand"
 	"crypto/x509"
 	"encoding/base64"
-	"encoding/hex"
 	"errors"
 	"fmt"
-	"math/big"
 	"net"
 	"net/url"
 	"os"
-	"runtime"
-	"strings"
-	"sync"
 	"syscall"
-	"time"
-)
-
-// Contains is a helper function that returns true
-// if the target string is in the list.
-func Contains(list []string, target string) bool {
-	for _, listItem := range list {
-		if listItem == target {
-			return true
-		}
-	}
-	return false
-}
-
-// FlipCoin is a helper function that randomly
-// returns true or false. If the underlying random
-// number generator fails, FlipCoin still returns
-// a result.
-func FlipCoin() bool {
-	randomInt, _ := MakeSecureRandomInt(2)
-	return randomInt == 1
-}
-
-// MakeSecureRandomInt is a helper function that wraps
-// MakeSecureRandomInt64.
-func MakeSecureRandomInt(max int) (int, error) {
-	randomInt, err := MakeSecureRandomInt64(int64(max))
-	return int(randomInt), err
-}
-
-// MakeSecureRandomInt64 is a helper function that wraps
-// crypto/rand.Int, which returns a uniform random value in [0, max).
-func MakeSecureRandomInt64(max int64) (int64, error) {
-	randomInt, err := rand.Int(rand.Reader, big.NewInt(max))
-	if err != nil {
-		return 0, ContextError(err)
-	}
-	return randomInt.Int64(), nil
-}
-
-// MakeSecureRandomBytes is a helper function that wraps
-// crypto/rand.Read.
-func MakeSecureRandomBytes(length int) ([]byte, error) {
-	randomBytes := make([]byte, length)
-	n, err := rand.Read(randomBytes)
-	if err != nil {
-		return nil, ContextError(err)
-	}
-	if n != length {
-		return nil, ContextError(errors.New("insufficient random bytes"))
-	}
-	return randomBytes, nil
-}
 
-// MakeSecureRandomPadding selects a random padding length in the indicated
-// range and returns a random byte array of the selected length.
-// In the unlikely case where an underlying MakeRandom functions fails,
-// the padding is length 0.
-func MakeSecureRandomPadding(minLength, maxLength int) []byte {
-	var padding []byte
-	paddingSize, err := MakeSecureRandomInt(maxLength - minLength)
-	if err != nil {
-		NoticeAlert("MakeSecureRandomPadding: MakeSecureRandomInt failed")
-		return make([]byte, 0)
-	}
-	paddingSize += minLength
-	padding, err = MakeSecureRandomBytes(paddingSize)
-	if err != nil {
-		NoticeAlert("MakeSecureRandomPadding: MakeSecureRandomBytes failed")
-		return make([]byte, 0)
-	}
-	return padding
-}
-
-// MakeRandomPeriod returns a random duration, within a given range.
-// In the unlikely case where an  underlying MakeRandom functions fails,
-// the period is the minimum.
-func MakeRandomPeriod(min, max time.Duration) (duration time.Duration) {
-	period, err := MakeSecureRandomInt64(max.Nanoseconds() - min.Nanoseconds())
-	if err != nil {
-		NoticeAlert("NextRandomRangePeriod: MakeSecureRandomInt64 failed")
-	}
-	duration = min + time.Duration(period)
-	return
-}
-
-// MakeRandomStringHex returns a hex encoded random string.
-// byteLength specifies the pre-encoded data length.
-func MakeRandomStringHex(byteLength int) (string, error) {
-	bytes, err := MakeSecureRandomBytes(byteLength)
-	if err != nil {
-		return "", ContextError(err)
-	}
-	return hex.EncodeToString(bytes), nil
-}
-
-// MakeRandomStringBase64 returns a base64 encoded random string.
-// byteLength specifies the pre-encoded data length.
-func MakeRandomStringBase64(byteLength int) (string, error) {
-	bytes, err := MakeSecureRandomBytes(byteLength)
-	if err != nil {
-		return "", ContextError(err)
-	}
-	return base64.RawURLEncoding.EncodeToString(bytes), nil
-}
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+)
 
 func DecodeCertificate(encodedCertificate string) (certificate *x509.Certificate, err error) {
 	derEncodedCertificate, err := base64.StdEncoding.DecodeString(encodedCertificate)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	certificate, err = x509.ParseCertificate(derEncodedCertificate)
 	if err != nil {
-		return nil, ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	return certificate, nil
 }
@@ -182,35 +74,6 @@ func TrimError(err error) error {
 	return err
 }
 
-// getFunctionName is a helper that extracts a simple function name from
-// full name returned byruntime.Func.Name(). This is used to declutter
-// log messages containing function names.
-func getFunctionName(pc uintptr) string {
-	funcName := runtime.FuncForPC(pc).Name()
-	index := strings.LastIndex(funcName, "/")
-	if index != -1 {
-		funcName = funcName[index+1:]
-	}
-	return funcName
-}
-
-// GetParentContext returns the parent function name and source file
-// line number.
-func GetParentContext() string {
-	pc, _, line, _ := runtime.Caller(2)
-	return fmt.Sprintf("%s#%d", getFunctionName(pc), line)
-}
-
-// ContextError prefixes an error message with the current function
-// name and source file line number.
-func ContextError(err error) error {
-	if err == nil {
-		return nil
-	}
-	pc, _, line, _ := runtime.Caller(1)
-	return fmt.Errorf("%s#%d: %s", getFunctionName(pc), line, err)
-}
-
 // IsAddressInUseError returns true when the err is due to EADDRINUSE/WSAEADDRINUSE.
 func IsAddressInUseError(err error) bool {
 	if err, ok := err.(*net.OpError); ok {
@@ -258,142 +121,3 @@ func (writer *SyncFileWriter) Write(p []byte) (n int, err error) {
 	}
 	return
 }
-
-// GetCurrentTimestamp returns the current time in UTC as
-// an RFC 3339 formatted string.
-func GetCurrentTimestamp() string {
-	return time.Now().UTC().Format(time.RFC3339)
-}
-
-// TruncateTimestampToHour truncates an RFC 3339 formatted string
-// to hour granularity. If the input is not a valid format, the
-// result is "".
-func TruncateTimestampToHour(timestamp string) string {
-	t, err := time.Parse(time.RFC3339, timestamp)
-	if err != nil {
-		NoticeAlert("failed to truncate timestamp: %s", err)
-		return ""
-	}
-	return t.Truncate(1 * time.Hour).Format(time.RFC3339)
-}
-
-// IsFileChanged uses os.Stat to check if the name, size, or last mod time of the
-// file has changed (which is a heuristic, but sufficiently robust for users of this
-// function). Returns nil if file has not changed; otherwise, returns a changed
-// os.FileInfo which may be used to check for subsequent changes.
-func IsFileChanged(path string, previousFileInfo os.FileInfo) (os.FileInfo, error) {
-
-	fileInfo, err := os.Stat(path)
-	if err != nil {
-		return nil, ContextError(err)
-	}
-
-	changed := previousFileInfo == nil ||
-		fileInfo.Name() != previousFileInfo.Name() ||
-		fileInfo.Size() != previousFileInfo.Size() ||
-		fileInfo.ModTime() != previousFileInfo.ModTime()
-
-	if !changed {
-		return nil, nil
-	}
-
-	return fileInfo, nil
-}
-
-// Reloader represents a read-only, in-memory reloadable data object. For example,
-// a JSON data file that is loaded into memory and accessed for read-only lookups;
-// and from time to time may be reloaded from the same file, updating the memory
-// copy.
-type Reloader interface {
-
-	// Reload reloads the data object. Reload returns a flag indicating if the
-	// reloadable target has changed and reloaded or remains unchanged. By
-	// convention, when reloading fails the Reloader should revert to its previous
-	// in-memory state.
-	Reload() (bool, error)
-
-	// WillReload indicates if the data object is capable of reloading.
-	WillReload() bool
-
-	// LogDescription returns a description to be used for logging
-	// events related to the Reloader.
-	LogDescription() string
-}
-
-// ReloadableFile is a file-backed Reloader. This type is intended to be embedded
-// in other types that add the actual reloadable data structures.
-//
-// ReloadableFile has a multi-reader mutex for synchronization. Its Reload() function
-// will obtain a write lock before reloading the data structures. Actually reloading
-// action is to be provided via the reloadAction callback (for example, read the contents
-// of the file and unmarshall the contents into data structures). All read access to
-// the data structures should be guarded by RLocks on the ReloadableFile mutex.
-//
-// reloadAction must ensure that data structures revert to their previous state when
-// a reload fails.
-//
-type ReloadableFile struct {
-	sync.RWMutex
-	fileName     string
-	fileInfo     os.FileInfo
-	reloadAction func(string) error
-}
-
-// NewReloadableFile initializes a new ReloadableFile
-func NewReloadableFile(
-	fileName string,
-	reloadAction func(string) error) ReloadableFile {
-
-	return ReloadableFile{
-		fileName:     fileName,
-		reloadAction: reloadAction,
-	}
-}
-
-// WillReload indicates whether the ReloadableFile is capable
-// of reloading.
-func (reloadable *ReloadableFile) WillReload() bool {
-	return reloadable.fileName != ""
-}
-
-// Reload checks if the underlying file has changed (using IsFileChanged semantics, which
-// are heuristics) and, when changed, invokes the reloadAction callback which should
-// reload, from the file, the in-memory data structures.
-// All data structure readers should be blocked by the ReloadableFile mutex.
-func (reloadable *ReloadableFile) Reload() (bool, error) {
-
-	if !reloadable.WillReload() {
-		return false, nil
-	}
-
-	// Check whether the file has changed _before_ blocking readers
-
-	reloadable.RLock()
-	changedFileInfo, err := IsFileChanged(reloadable.fileName, reloadable.fileInfo)
-	reloadable.RUnlock()
-	if err != nil {
-		return false, ContextError(err)
-	}
-
-	if changedFileInfo == nil {
-		return false, nil
-	}
-
-	// ...now block readers
-
-	reloadable.Lock()
-	defer reloadable.Unlock()
-
-	err = reloadable.reloadAction(reloadable.fileName)
-	if err != nil {
-		return false, ContextError(err)
-	}
-
-	reloadable.fileInfo = changedFileInfo
-
-	return true, nil
-}
-
-func (reloadable *ReloadableFile) LogDescription() string {
-	return reloadable.fileName
-}