Przeglądaj źródła

Merge pull request #352 from rod-hynes/master

Misc. changes (see description)
Rod Hynes 9 lat temu
rodzic
commit
338cd27372

+ 16 - 0
psiphon/common/utils.go

@@ -28,6 +28,7 @@ import (
 	"errors"
 	"fmt"
 	"io/ioutil"
+	"math"
 	"math/big"
 	"runtime"
 	"strings"
@@ -134,6 +135,21 @@ func MakeRandomStringBase64(byteLength int) (string, error) {
 	return base64.RawURLEncoding.EncodeToString(bytes), nil
 }
 
+// Jitter returns n +/- the given factor.
+// For example, for n = 100 and factor = 0.1, the
+// return value will be in the range [90, 110].
+func Jitter(n int64, factor float64) int64 {
+	a := int64(math.Ceil(float64(n) * factor))
+	r, _ := MakeSecureRandomInt64(2*a + 1)
+	return n + r - a
+}
+
+// JitterDuration is a helper function that wraps Jitter.
+func JitterDuration(
+	d time.Duration, factor float64) time.Duration {
+	return time.Duration(Jitter(int64(d), factor))
+}
+
 // GetCurrentTimestamp returns the current time in UTC as
 // an RFC 3339 formatted string.
 func GetCurrentTimestamp() string {

+ 42 - 0
psiphon/common/utils_test.go

@@ -21,6 +21,8 @@ package common
 
 import (
 	"bytes"
+	"fmt"
+	"math"
 	"testing"
 	"time"
 )
@@ -54,6 +56,46 @@ func TestMakeRandomPeriod(t *testing.T) {
 	}
 }
 
+func TestJitter(t *testing.T) {
+
+	testCases := []struct {
+		n           int64
+		factor      float64
+		expectedMin int64
+		expectedMax int64
+	}{
+		{100, 0.1, 90, 110},
+		{1000, 0.3, 700, 1300},
+	}
+
+	for _, testCase := range testCases {
+		t.Run(fmt.Sprintf("jitter case: %+v", testCase), func(t *testing.T) {
+
+			min := int64(math.MaxInt64)
+			max := int64(0)
+
+			for i := 0; i < 100000; i++ {
+
+				x := Jitter(testCase.n, testCase.factor)
+				if x < min {
+					min = x
+				}
+				if x > max {
+					max = x
+				}
+			}
+
+			if min != testCase.expectedMin {
+				t.Errorf("unexpected minimum jittered value: %d", min)
+			}
+
+			if max != testCase.expectedMax {
+				t.Errorf("unexpected maximum jittered value: %d", max)
+			}
+		})
+	}
+}
+
 func TestCompress(t *testing.T) {
 
 	originalData := []byte("test data")

+ 1 - 1
psiphon/controller_test.go

@@ -497,7 +497,7 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 	// tests and ensure tests complete within fixed deadlines.
 	fetchRemoteServerListRetryPeriodSeconds := 0
 	config.FetchRemoteServerListRetryPeriodSeconds = &fetchRemoteServerListRetryPeriodSeconds
-	downloadUpgradeRetryPeriodSeconds := 0
+	downloadUpgradeRetryPeriodSeconds := 1
 	config.DownloadUpgradeRetryPeriodSeconds = &downloadUpgradeRetryPeriodSeconds
 	establishTunnelPausePeriodSeconds := 1
 	config.EstablishTunnelPausePeriodSeconds = &establishTunnelPausePeriodSeconds

+ 39 - 5
psiphon/meekConn.go

@@ -54,8 +54,11 @@ const (
 	FULL_RECEIVE_BUFFER_LENGTH     = 4194304
 	READ_PAYLOAD_CHUNK_LENGTH      = 65536
 	MIN_POLL_INTERVAL              = 100 * time.Millisecond
+	MIN_POLL_INTERVAL_JITTER       = 0.3
 	MAX_POLL_INTERVAL              = 5 * time.Second
-	POLL_INTERNAL_MULTIPLIER       = 1.5
+	MAX_POLL_INTERVAL_JITTER       = 0.1
+	POLL_INTERVAL_MULTIPLIER       = 1.5
+	POLL_INTERVAL_JITTER           = 0.1
 	MEEK_ROUND_TRIP_RETRY_DEADLINE = 1 * time.Second
 	MEEK_ROUND_TRIP_RETRY_DELAY    = 50 * time.Millisecond
 	MEEK_ROUND_TRIP_TIMEOUT        = 20 * time.Second
@@ -474,9 +477,15 @@ func (meek *MeekConn) relay() {
 	// Note: meek.Close() calls here in relay() are made asynchronously
 	// (using goroutines) since Close() will wait on this WaitGroup.
 	defer meek.relayWaitGroup.Done()
-	interval := MIN_POLL_INTERVAL
+
+	interval := common.JitterDuration(
+		MIN_POLL_INTERVAL,
+		MIN_POLL_INTERVAL_JITTER)
+
 	timeout := time.NewTimer(interval)
+
 	sendPayload := make([]byte, MAX_SEND_PAYLOAD_LENGTH)
+
 	for {
 		timeout.Reset(interval)
 		// Block until there is payload to send or it is time to poll
@@ -517,14 +526,39 @@ func (meek *MeekConn) relay() {
 			go meek.Close()
 			return
 		}
+
+		// Calculate polling interval. When data is received,
+		// immediately request more. Otherwise, schedule next
+		// poll with exponential back off. Jitter and coin
+		// flips are used to avoid trivial, static traffic
+		// timing patterns.
+
 		if receivedPayloadSize > 0 || sendPayloadSize > 0 {
+
 			interval = 0
+
 		} else if interval == 0 {
-			interval = MIN_POLL_INTERVAL
+
+			interval = common.JitterDuration(
+				MIN_POLL_INTERVAL,
+				MIN_POLL_INTERVAL_JITTER)
+
 		} else {
-			interval = time.Duration(float64(interval) * POLL_INTERNAL_MULTIPLIER)
+
+			if common.FlipCoin() {
+				interval = common.JitterDuration(
+					interval,
+					POLL_INTERVAL_JITTER)
+			} else {
+				interval = common.JitterDuration(
+					time.Duration(float64(interval)*POLL_INTERVAL_MULTIPLIER),
+					POLL_INTERVAL_JITTER)
+			}
+
 			if interval >= MAX_POLL_INTERVAL {
-				interval = MAX_POLL_INTERVAL
+				interval = common.JitterDuration(
+					MAX_POLL_INTERVAL,
+					MAX_POLL_INTERVAL_JITTER)
 			}
 		}
 	}

+ 8 - 4
psiphon/common/pluginProtocol.go → psiphon/pluginProtocol.go

@@ -17,12 +17,14 @@
  *
  */
 
-package common
+package psiphon
 
 import (
 	"io"
 	"net"
 	"sync/atomic"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 var registeredPluginProtocolDialer atomic.Value
@@ -41,8 +43,9 @@ type PluginProtocolNetDialer func(network, addr string) (net.Conn, error)
 // PluginProtocolDialer must add its connection to pendingConns
 // before the initial dial to allow for interruption.
 type PluginProtocolDialer func(
+	config *Config,
 	loggerOutput io.Writer,
-	pendingConns *Conns,
+	pendingConns *common.Conns,
 	netDialer PluginProtocolNetDialer,
 	addr string) (
 	bool, net.Conn, error)
@@ -56,8 +59,9 @@ func RegisterPluginProtocol(protcolDialer PluginProtocolDialer) {
 // DialPluginProtocol uses the current plugin protocol dialer,
 // if set, to connect to addr over the plugin protocol.
 func DialPluginProtocol(
+	config *Config,
 	loggerOutput io.Writer,
-	pendingConns *Conns,
+	pendingConns *common.Conns,
 	netDialer PluginProtocolNetDialer,
 	addr string) (
 	bool, net.Conn, error) {
@@ -65,7 +69,7 @@ func DialPluginProtocol(
 	dialer := registeredPluginProtocolDialer.Load()
 	if dialer != nil {
 		return dialer.(PluginProtocolDialer)(
-			loggerOutput, pendingConns, netDialer, addr)
+			config, loggerOutput, pendingConns, netDialer, addr)
 	}
 	return false, nil, nil
 }

+ 3 - 0
psiphon/remoteServerList_test.go

@@ -76,6 +76,9 @@ func TestObfuscatedRemoteServerLists(t *testing.T) {
 			TunnelProtocolPorts:  map[string]int{"OSSH": 4001},
 			LogFilename:          filepath.Join(testDataDirName, "psiphond.log"),
 			LogLevel:             "debug",
+
+			// "defer os.RemoveAll" will cause a log write error
+			SkipPanickingLogWriter: true,
 		})
 	if err != nil {
 		t.Fatalf("error generating server config: %s", err)

+ 6 - 2
psiphon/server/api.go

@@ -98,8 +98,12 @@ func dispatchAPIRequestHandler(
 	// terminating in the case of a bug.
 	defer func() {
 		if e := recover(); e != nil {
-			log.LogPanicRecover(e, debug.Stack())
-			reterr = common.ContextError(errors.New("request handler panic"))
+			if intentionalPanic, ok := e.(IntentionalPanicError); ok {
+				panic(intentionalPanic)
+			} else {
+				log.LogPanicRecover(e, debug.Stack())
+				reterr = common.ContextError(errors.New("request handler panic"))
+			}
 		}
 	}()
 

+ 12 - 13
psiphon/server/config.go

@@ -63,10 +63,9 @@ type Config struct {
 	// to. When blank, logs are written to stderr.
 	LogFilename string
 
-	// PanicLogFilename specifies the path of the file to
-	// log unrecovered panics to. When blank, logs are
-	// written to stderr
-	PanicLogFilename string
+	// SkipPanickingLogWriter disables panicking when
+	// unable to write any logs.
+	SkipPanickingLogWriter bool
 
 	// DiscoveryValueHMACKey is the network-wide secret value
 	// used to determine a unique discovery strategy.
@@ -343,14 +342,14 @@ func validateNetworkAddress(address string, requireIPaddress bool) error {
 // GenerateConfigParams specifies customizations to be applied to
 // a generated server config.
 type GenerateConfigParams struct {
-	LogFilename          string
-	PanicLogFilename     string
-	LogLevel             string
-	ServerIPAddress      string
-	WebServerPort        int
-	EnableSSHAPIRequests bool
-	TunnelProtocolPorts  map[string]int
-	TrafficRulesFilename string
+	LogFilename            string
+	SkipPanickingLogWriter bool
+	LogLevel               string
+	ServerIPAddress        string
+	WebServerPort          int
+	EnableSSHAPIRequests   bool
+	TunnelProtocolPorts    map[string]int
+	TrafficRulesFilename   string
 }
 
 // GenerateConfig creates a new Psiphon server config. It returns JSON
@@ -501,7 +500,7 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 	config := &Config{
 		LogLevel:                       logLevel,
 		LogFilename:                    params.LogFilename,
-		PanicLogFilename:               params.PanicLogFilename,
+		SkipPanickingLogWriter:         params.SkipPanickingLogWriter,
 		GeoIPDatabaseFilenames:         nil,
 		HostID:                         "example-host-id",
 		ServerIPAddress:                params.ServerIPAddress,

+ 27 - 0
psiphon/server/log.go

@@ -203,6 +203,33 @@ func InitLogging(config *Config) (retErr error) {
 				retErr = common.ContextError(err)
 				return
 			}
+
+			if !config.SkipPanickingLogWriter {
+
+				// Use PanickingLogWriter, which will intentionally
+				// panic when a Write fails. Set SkipPanickingLogWriter
+				// if this behavior is not desired.
+				//
+				// Note that NewRotatableFileWriter will first attempt
+				// a retry when a Write fails.
+				//
+				// It is assumed that continuing operation while unable
+				// to log is unacceptable; and that the psiphond service
+				// is managed and will restart when it terminates.
+				//
+				// It is further assumed that panicking will result in
+				// an error that is externally logged and reported to a
+				// monitoring system.
+				//
+				// TODO: An orderly shutdown may be preferred, as some
+				// data will be lost in a panic (e.g., server_tunnel logs).
+				// It may be possible to perform an orderly shutdown first
+				// and then panic, or perform an orderly shutdown and
+				// simulate a panic message that will be reported.
+
+				logWriter = NewPanickingLogWriter(config.LogFilename, logWriter)
+			}
+
 		} else {
 			logWriter = os.Stderr
 		}

+ 3 - 3
psiphon/server/server_test.go

@@ -913,7 +913,7 @@ func paveOSLConfigFile(t *testing.T, oslConfigFilename string) string {
             }
           ],
           "SeedSpecThreshold" : 2,
-          "SeedPeriodNanoseconds" : 3600000000000,
+          "SeedPeriodNanoseconds" : 2592000000000000,
           "SeedPeriodKeySplits": [
             {
               "Total": 2,
@@ -939,7 +939,7 @@ func paveOSLConfigFile(t *testing.T, oslConfigFilename string) string {
             }
           ],
           "SeedSpecThreshold" : 1,
-          "SeedPeriodNanoseconds" : 3600000000000,
+          "SeedPeriodNanoseconds" : 2592000000000000,
           "SeedPeriodKeySplits": [
             {
               "Total": 1,
@@ -954,7 +954,7 @@ func paveOSLConfigFile(t *testing.T, oslConfigFilename string) string {
 	propagationChannelID, _ := common.MakeRandomStringHex(8)
 
 	now := time.Now().UTC()
-	epoch := now.Truncate(1 * time.Hour)
+	epoch := now.Truncate(720 * time.Hour)
 	epochStr := epoch.Format(time.RFC3339Nano)
 
 	oslConfigJSON := fmt.Sprintf(

+ 58 - 30
psiphon/server/tunnelServer.go

@@ -191,7 +191,7 @@ func (server *TunnelServer) Run() error {
 // broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
 // include current connected client count, total number of current port
 // forwards.
-func (server *TunnelServer) GetLoadStats() map[string]map[string]int64 {
+func (server *TunnelServer) GetLoadStats() map[string]interface{} {
 	return server.sshServer.getLoadStats()
 }
 
@@ -239,7 +239,7 @@ type sshServer struct {
 	sshHostKey           ssh.Signer
 	clientsMutex         sync.Mutex
 	stoppingClients      bool
-	acceptedClientCounts map[string]int64
+	acceptedClientCounts map[string]map[string]int64
 	clients              map[string]*sshClient
 }
 
@@ -263,7 +263,7 @@ func newSSHServer(
 		establishTunnels:     1,
 		shutdownBroadcast:    shutdownBroadcast,
 		sshHostKey:           signer,
-		acceptedClientCounts: make(map[string]int64),
+		acceptedClientCounts: make(map[string]map[string]int64),
 		clients:              make(map[string]*sshClient),
 	}, nil
 }
@@ -375,20 +375,24 @@ func (sshServer *sshServer) runListener(
 
 // An accepted client has completed a direct TCP or meek connection and has a net.Conn. Registration
 // is for tracking the number of connections.
-func (sshServer *sshServer) registerAcceptedClient(tunnelProtocol string) {
+func (sshServer *sshServer) registerAcceptedClient(tunnelProtocol, region string) {
 
 	sshServer.clientsMutex.Lock()
 	defer sshServer.clientsMutex.Unlock()
 
-	sshServer.acceptedClientCounts[tunnelProtocol] += 1
+	if sshServer.acceptedClientCounts[tunnelProtocol] == nil {
+		sshServer.acceptedClientCounts[tunnelProtocol] = make(map[string]int64)
+	}
+
+	sshServer.acceptedClientCounts[tunnelProtocol][region] += 1
 }
 
-func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol string) {
+func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol, region string) {
 
 	sshServer.clientsMutex.Lock()
 	defer sshServer.clientsMutex.Unlock()
 
-	sshServer.acceptedClientCounts[tunnelProtocol] -= 1
+	sshServer.acceptedClientCounts[tunnelProtocol][region] -= 1
 }
 
 // An established client has completed its SSH handshake and has a ssh.Conn. Registration is
@@ -437,43 +441,61 @@ func (sshServer *sshServer) unregisterEstablishedClient(sessionID string) {
 	}
 }
 
-func (sshServer *sshServer) getLoadStats() map[string]map[string]int64 {
+func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 
 	sshServer.clientsMutex.Lock()
 	defer sshServer.clientsMutex.Unlock()
 
-	loadStats := make(map[string]map[string]int64)
+	protocolStats := make(map[string]map[string]map[string]int64)
 
 	// Explicitly populate with zeros to get 0 counts in log messages derived from getLoadStats()
 
 	for tunnelProtocol, _ := range sshServer.support.Config.TunnelProtocolPorts {
-		loadStats[tunnelProtocol] = make(map[string]int64)
-		loadStats[tunnelProtocol]["accepted_clients"] = 0
-		loadStats[tunnelProtocol]["established_clients"] = 0
-		loadStats[tunnelProtocol]["tcp_port_forwards"] = 0
-		loadStats[tunnelProtocol]["total_tcp_port_forwards"] = 0
-		loadStats[tunnelProtocol]["udp_port_forwards"] = 0
-		loadStats[tunnelProtocol]["total_udp_port_forwards"] = 0
+		protocolStats[tunnelProtocol] = make(map[string]map[string]int64)
+		protocolStats[tunnelProtocol]["ALL"] = make(map[string]int64)
+		protocolStats[tunnelProtocol]["ALL"]["accepted_clients"] = 0
+		protocolStats[tunnelProtocol]["ALL"]["established_clients"] = 0
+		protocolStats[tunnelProtocol]["ALL"]["tcp_port_forwards"] = 0
+		protocolStats[tunnelProtocol]["ALL"]["total_tcp_port_forwards"] = 0
+		protocolStats[tunnelProtocol]["ALL"]["udp_port_forwards"] = 0
+		protocolStats[tunnelProtocol]["ALL"]["total_udp_port_forwards"] = 0
 	}
 
 	// Note: as currently tracked/counted, each established client is also an accepted client
 
-	for tunnelProtocol, acceptedClientCount := range sshServer.acceptedClientCounts {
-		loadStats[tunnelProtocol]["accepted_clients"] = acceptedClientCount
+	for tunnelProtocol, regionAcceptedClientCounts := range sshServer.acceptedClientCounts {
+		total := int64(0)
+		for region, acceptedClientCount := range regionAcceptedClientCounts {
+			if protocolStats[tunnelProtocol][region] == nil {
+				protocolStats[tunnelProtocol][region] = make(map[string]int64)
+			}
+			protocolStats[tunnelProtocol][region]["accepted_clients"] = acceptedClientCount
+			total += acceptedClientCount
+		}
+		protocolStats[tunnelProtocol]["ALL"]["accepted_clients"] = total
 	}
 
 	var aggregatedQualityMetrics qualityMetrics
 
 	for _, client := range sshServer.clients {
-		// Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
-		loadStats[client.tunnelProtocol]["established_clients"] += 1
 
 		client.Lock()
 
-		loadStats[client.tunnelProtocol]["tcp_port_forwards"] += client.tcpTrafficState.concurrentPortForwardCount
-		loadStats[client.tunnelProtocol]["total_tcp_port_forwards"] += client.tcpTrafficState.totalPortForwardCount
-		loadStats[client.tunnelProtocol]["udp_port_forwards"] += client.udpTrafficState.concurrentPortForwardCount
-		loadStats[client.tunnelProtocol]["total_udp_port_forwards"] += client.udpTrafficState.totalPortForwardCount
+		for _, region := range []string{"ALL", client.geoIPData.Country} {
+
+			if protocolStats[tunnelProtocol][region] == nil {
+				protocolStats[tunnelProtocol][region] = make(map[string]int64)
+			}
+
+			// Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
+			protocolStats[client.tunnelProtocol][region]["established_clients"] += 1
+
+			protocolStats[client.tunnelProtocol][region]["tcp_port_forwards"] += client.tcpTrafficState.concurrentPortForwardCount
+			protocolStats[client.tunnelProtocol][region]["total_tcp_port_forwards"] += client.tcpTrafficState.totalPortForwardCount
+			protocolStats[client.tunnelProtocol][region]["udp_port_forwards"] += client.udpTrafficState.concurrentPortForwardCount
+			protocolStats[client.tunnelProtocol][region]["total_udp_port_forwards"] += client.udpTrafficState.totalPortForwardCount
+
+		}
 
 		aggregatedQualityMetrics.tcpPortForwardDialedCount += client.qualityMetrics.tcpPortForwardDialedCount
 		aggregatedQualityMetrics.tcpPortForwardDialedDuration +=
@@ -504,13 +526,19 @@ func (sshServer *sshServer) getLoadStats() map[string]map[string]int64 {
 	allProtocolsStats["tcp_port_forward_failed_count"] = aggregatedQualityMetrics.tcpPortForwardFailedCount
 	allProtocolsStats["tcp_port_forward_failed_duration"] = int64(aggregatedQualityMetrics.tcpPortForwardFailedDuration)
 
-	for _, stats := range loadStats {
-		for name, value := range stats {
-			allProtocolsStats[name] += value
+	for _, regionStats := range protocolStats {
+		for _, stats := range regionStats {
+			for name, value := range stats {
+				allProtocolsStats[name] += value
+			}
 		}
 	}
 
+	loadStats := make(map[string]interface{})
 	loadStats["ALL"] = allProtocolsStats
+	for tunnelProtocol, stats := range protocolStats {
+		loadStats[tunnelProtocol] = stats
+	}
 
 	return loadStats
 }
@@ -577,12 +605,12 @@ func (sshServer *sshServer) stopClients() {
 
 func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {
 
-	sshServer.registerAcceptedClient(tunnelProtocol)
-	defer sshServer.unregisterAcceptedClient(tunnelProtocol)
-
 	geoIPData := sshServer.support.GeoIPService.Lookup(
 		common.IPAddressFromAddr(clientConn.RemoteAddr()))
 
+	sshServer.registerAcceptedClient(tunnelProtocol, geoIPData.Country)
+	defer sshServer.unregisterAcceptedClient(tunnelProtocol, geoIPData.Country)
+
 	sshClient := newSshClient(sshServer, tunnelProtocol, geoIPData)
 
 	sshClient.run(clientConn)

+ 48 - 0
psiphon/server/utils.go

@@ -26,6 +26,8 @@ import (
 	"crypto/x509"
 	"crypto/x509/pkix"
 	"encoding/pem"
+	"fmt"
+	"io"
 	"math/big"
 	"time"
 
@@ -126,3 +128,49 @@ func GenerateWebServerCertificate(commonName string) (string, string, error) {
 
 	return string(webServerCertificate), string(webServerPrivateKey), nil
 }
+
+// IntentionalPanicError is an error type that is used
+// when calling panic() in a situation where recovers
+// should propagate the panic.
+type IntentionalPanicError struct {
+	message string
+}
+
+// NewIntentionalPanicError creates a new IntentionalPanicError.
+func NewIntentionalPanicError(errorMessage string) error {
+	return IntentionalPanicError{
+		message: fmt.Sprintf("intentional panic error: %s", errorMessage)}
+}
+
+// Error implements the error interface.
+func (err IntentionalPanicError) Error() string {
+	return err.message
+}
+
+// PanickingLogWriter wraps an io.Writer and intentionally
+// panics when a Write() fails.
+type PanickingLogWriter struct {
+	name   string
+	writer io.Writer
+}
+
+// NewPanickingLogWriter creates a new PanickingLogWriter.
+func NewPanickingLogWriter(
+	name string, writer io.Writer) *PanickingLogWriter {
+
+	return &PanickingLogWriter{
+		name:   name,
+		writer: writer,
+	}
+}
+
+// Write implements the io.Writer interface.
+func (w *PanickingLogWriter) Write(p []byte) (n int, err error) {
+	n, err = w.writer.Write(p)
+	if err != nil {
+		panic(
+			NewIntentionalPanicError(
+				fmt.Sprintf("fatal write to %s failed: %s", w.name, err)))
+	}
+	return
+}

+ 3 - 2
psiphon/tunnel.go

@@ -616,7 +616,7 @@ func dialSsh(
 		}
 	}
 
-	dialCustomHeaders, selectedUserAgent = common.UserAgentIfUnset(config.CustomHeaders)
+	dialCustomHeaders, selectedUserAgent = UserAgentIfUnset(config.CustomHeaders)
 
 	// Use an asynchronous callback to record the resolved IP address when
 	// dialing a domain name. Note that DialMeek doesn't immediately
@@ -705,7 +705,8 @@ func dialSsh(
 		// For some direct connect servers, DialPluginProtocol
 		// will layer on another obfuscation protocol.
 		var dialedPlugin bool
-		dialedPlugin, dialConn, err = common.DialPluginProtocol(
+		dialedPlugin, dialConn, err = DialPluginProtocol(
+			config,
 			NewNoticeWriter("DialPluginProtocol"),
 			pendingConns,
 			func(_, addr string) (net.Conn, error) {

+ 4 - 2
psiphon/common/userAgentPicker.go → psiphon/userAgentPicker.go

@@ -17,11 +17,13 @@
  *
  */
 
-package common
+package psiphon
 
 import (
 	"net/http"
 	"sync/atomic"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
 var registeredUserAgentPicker atomic.Value
@@ -54,7 +56,7 @@ func UserAgentIfUnset(h http.Header) (http.Header, bool) {
 			}
 		}
 
-		if FlipCoin() {
+		if common.FlipCoin() {
 			dialHeaders.Set("User-Agent", pickUserAgent())
 		} else {
 			dialHeaders.Set("User-Agent", "")