Kaynağa Gözat

Merge pull request #663 from adotkhan/split-ch

Added TLS ClientHello fragmentation
Rod Hynes 2 yıl önce
ebeveyn
işleme
c6287f9a62

+ 7 - 0
psiphon/common/parameters/parameters.go

@@ -243,6 +243,7 @@ const (
 	ReplayHTTPTransformerParameters                  = "ReplayHTTPTransformerParameters"
 	ReplayHTTPTransformerParameters                  = "ReplayHTTPTransformerParameters"
 	ReplayOSSHSeedTransformerParameters              = "ReplayOSSHSeedTransformerParameters"
 	ReplayOSSHSeedTransformerParameters              = "ReplayOSSHSeedTransformerParameters"
 	ReplayOSSHPrefix                                 = "ReplayOSSHPrefix"
 	ReplayOSSHPrefix                                 = "ReplayOSSHPrefix"
+	ReplayTLSFragmentClientHello                     = "ReplayTLSFragmentClientHello"
 	APIRequestUpstreamPaddingMinBytes                = "APIRequestUpstreamPaddingMinBytes"
 	APIRequestUpstreamPaddingMinBytes                = "APIRequestUpstreamPaddingMinBytes"
 	APIRequestUpstreamPaddingMaxBytes                = "APIRequestUpstreamPaddingMaxBytes"
 	APIRequestUpstreamPaddingMaxBytes                = "APIRequestUpstreamPaddingMaxBytes"
 	APIRequestDownstreamPaddingMinBytes              = "APIRequestDownstreamPaddingMinBytes"
 	APIRequestDownstreamPaddingMinBytes              = "APIRequestDownstreamPaddingMinBytes"
@@ -347,6 +348,8 @@ const (
 	TLSTunnelTrafficShapingProbability               = "TLSTunnelTrafficShapingProbability"
 	TLSTunnelTrafficShapingProbability               = "TLSTunnelTrafficShapingProbability"
 	TLSTunnelMinTLSPadding                           = "TLSTunnelMinTLSPadding"
 	TLSTunnelMinTLSPadding                           = "TLSTunnelMinTLSPadding"
 	TLSTunnelMaxTLSPadding                           = "TLSTunnelMaxTLSPadding"
 	TLSTunnelMaxTLSPadding                           = "TLSTunnelMaxTLSPadding"
+	TLSFragmentClientHelloProbability                = "TLSFragmentClientHelloProbability"
+	TLSFragmentClientHelloLimitProtocols             = "TLSFragmentClientHelloLimitProtocols"
 
 
 	// Retired parameters
 	// Retired parameters
 
 
@@ -609,6 +612,7 @@ var defaultParameters = map[string]struct {
 	ReplayHTTPTransformerParameters:        {value: true},
 	ReplayHTTPTransformerParameters:        {value: true},
 	ReplayOSSHSeedTransformerParameters:    {value: true},
 	ReplayOSSHSeedTransformerParameters:    {value: true},
 	ReplayOSSHPrefix:                       {value: true},
 	ReplayOSSHPrefix:                       {value: true},
+	ReplayTLSFragmentClientHello:           {value: true},
 
 
 	APIRequestUpstreamPaddingMinBytes:   {value: 0, minimum: 0},
 	APIRequestUpstreamPaddingMinBytes:   {value: 0, minimum: 0},
 	APIRequestUpstreamPaddingMaxBytes:   {value: 1024, minimum: 0},
 	APIRequestUpstreamPaddingMaxBytes:   {value: 1024, minimum: 0},
@@ -741,6 +745,9 @@ var defaultParameters = map[string]struct {
 	TLSTunnelTrafficShapingProbability: {value: 1.0, minimum: 0.0},
 	TLSTunnelTrafficShapingProbability: {value: 1.0, minimum: 0.0},
 	TLSTunnelMinTLSPadding:             {value: 0, minimum: 0},
 	TLSTunnelMinTLSPadding:             {value: 0, minimum: 0},
 	TLSTunnelMaxTLSPadding:             {value: 0, minimum: 0},
 	TLSTunnelMaxTLSPadding:             {value: 0, minimum: 0},
+
+	TLSFragmentClientHelloProbability:    {value: 0.0, minimum: 0.0},
+	TLSFragmentClientHelloLimitProtocols: {value: protocol.TunnelProtocols{}},
 }
 }
 
 
 // IsServerSideOnly indicates if the parameter specified by name is used
 // IsServerSideOnly indicates if the parameter specified by name is used

+ 24 - 0
psiphon/config.go

@@ -872,6 +872,10 @@ type Config struct {
 	TLSTunnelMinTLSPadding             *int
 	TLSTunnelMinTLSPadding             *int
 	TLSTunnelMaxTLSPadding             *int
 	TLSTunnelMaxTLSPadding             *int
 
 
+	// TLSFragmentClientHello fields are for testing purposes only.
+	TLSFragmentClientHelloProbability    *float64
+	TLSFragmentClientHelloLimitProtocols []string
+
 	// AdditionalParameters is used for testing.
 	// AdditionalParameters is used for testing.
 	AdditionalParameters string
 	AdditionalParameters string
 
 
@@ -2057,6 +2061,14 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.TLSTunnelMaxTLSPadding] = *config.TLSTunnelMaxTLSPadding
 		applyParameters[parameters.TLSTunnelMaxTLSPadding] = *config.TLSTunnelMaxTLSPadding
 	}
 	}
 
 
+	if config.TLSFragmentClientHelloProbability != nil {
+		applyParameters[parameters.TLSFragmentClientHelloProbability] = *config.TLSFragmentClientHelloProbability
+	}
+
+	if len(config.TLSFragmentClientHelloLimitProtocols) > 0 {
+		applyParameters[parameters.TLSFragmentClientHelloLimitProtocols] = protocol.TunnelProtocols(config.TLSFragmentClientHelloLimitProtocols)
+	}
+
 	// When adding new config dial parameters that may override tactics, also
 	// When adding new config dial parameters that may override tactics, also
 	// update setDialParametersHash.
 	// update setDialParametersHash.
 
 
@@ -2588,6 +2600,18 @@ func (config *Config) setDialParametersHash() {
 		binary.Write(hash, binary.LittleEndian, int64(*config.TLSTunnelMaxTLSPadding))
 		binary.Write(hash, binary.LittleEndian, int64(*config.TLSTunnelMaxTLSPadding))
 	}
 	}
 
 
+	if config.TLSFragmentClientHelloProbability != nil {
+		hash.Write([]byte("TLSFragmentClientHelloProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.TLSFragmentClientHelloProbability)
+	}
+
+	if len(config.TLSFragmentClientHelloLimitProtocols) > 0 {
+		hash.Write([]byte("TLSFragmentClientHelloLimitProtocols"))
+		for _, protocol := range config.TLSFragmentClientHelloLimitProtocols {
+			hash.Write([]byte(protocol))
+		}
+	}
+
 	config.dialParametersHash = hash.Sum(nil)
 	config.dialParametersHash = hash.Sum(nil)
 }
 }
 
 

+ 30 - 0
psiphon/dialParameters.go

@@ -122,6 +122,7 @@ type DialParameters struct {
 	NoDefaultTLSSessionID    bool
 	NoDefaultTLSSessionID    bool
 	TLSVersion               string
 	TLSVersion               string
 	RandomizedTLSProfileSeed *prng.Seed
 	RandomizedTLSProfileSeed *prng.Seed
+	TLSFragmentClientHello   bool
 
 
 	QUICVersion                              string
 	QUICVersion                              string
 	QUICDialSNIAddress                       string
 	QUICDialSNIAddress                       string
@@ -198,6 +199,7 @@ func MakeDialParameters(
 	replayObfuscatorPadding := p.Bool(parameters.ReplayObfuscatorPadding)
 	replayObfuscatorPadding := p.Bool(parameters.ReplayObfuscatorPadding)
 	replayFragmentor := p.Bool(parameters.ReplayFragmentor)
 	replayFragmentor := p.Bool(parameters.ReplayFragmentor)
 	replayTLSProfile := p.Bool(parameters.ReplayTLSProfile)
 	replayTLSProfile := p.Bool(parameters.ReplayTLSProfile)
+	replayTLSFragmentClientHello := p.Bool(parameters.ReplayTLSFragmentClientHello)
 	replayFronting := p.Bool(parameters.ReplayFronting)
 	replayFronting := p.Bool(parameters.ReplayFronting)
 	replayHostname := p.Bool(parameters.ReplayHostname)
 	replayHostname := p.Bool(parameters.ReplayHostname)
 	replayQUICVersion := p.Bool(parameters.ReplayQUICVersion)
 	replayQUICVersion := p.Bool(parameters.ReplayQUICVersion)
@@ -1015,6 +1017,32 @@ func MakeDialParameters(
 		}
 		}
 	}
 	}
 
 
+	// TLS ClientHello fragmentation is applied only after the state
+	// of SNI is determined above.
+	if (!isReplay || !replayTLSFragmentClientHello) && usingTLS {
+
+		limitProtocols := p.TunnelProtocols(parameters.TLSFragmentClientHelloLimitProtocols)
+		if len(limitProtocols) == 0 || common.Contains(limitProtocols, dialParams.TunnelProtocol) {
+
+			// Note: The TLS stack automatically drops the SNI extension when
+			// the host is an IP address.
+
+			usingSNI := false
+			if dialParams.TLSOSSHSNIServerName != "" {
+				usingSNI = net.ParseIP(dialParams.TLSOSSHSNIServerName) == nil
+
+			} else if dialParams.MeekSNIServerName != "" {
+				usingSNI = net.ParseIP(dialParams.MeekSNIServerName) == nil
+			}
+
+			// TLS ClientHello fragmentor expects SNI to be present.
+			if usingSNI {
+				dialParams.TLSFragmentClientHello = p.WeightedCoinFlip(
+					parameters.TLSFragmentClientHelloProbability)
+			}
+		}
+	}
+
 	// Initialize/replay User-Agent header for HTTP upstream proxy and meek protocols.
 	// Initialize/replay User-Agent header for HTTP upstream proxy and meek protocols.
 
 
 	if config.UseUpstreamProxy() {
 	if config.UseUpstreamProxy() {
@@ -1128,6 +1156,7 @@ func MakeDialParameters(
 			QUICDisablePathMTUDiscovery:   dialParams.QUICDisablePathMTUDiscovery,
 			QUICDisablePathMTUDiscovery:   dialParams.QUICDisablePathMTUDiscovery,
 			UseHTTPS:                      usingTLS,
 			UseHTTPS:                      usingTLS,
 			TLSProfile:                    dialParams.TLSProfile,
 			TLSProfile:                    dialParams.TLSProfile,
+			TLSFragmentClientHello:        dialParams.TLSFragmentClientHello,
 			LegacyPassthrough:             serverEntry.ProtocolUsesLegacyPassthrough(dialParams.TunnelProtocol),
 			LegacyPassthrough:             serverEntry.ProtocolUsesLegacyPassthrough(dialParams.TunnelProtocol),
 			NoDefaultTLSSessionID:         dialParams.NoDefaultTLSSessionID,
 			NoDefaultTLSSessionID:         dialParams.NoDefaultTLSSessionID,
 			RandomizedTLSProfileSeed:      dialParams.RandomizedTLSProfileSeed,
 			RandomizedTLSProfileSeed:      dialParams.RandomizedTLSProfileSeed,
@@ -1185,6 +1214,7 @@ func (dialParams *DialParameters) GetTLSOSSHConfig(config *Config) *TLSTunnelCon
 			TLSProfile:               dialParams.TLSProfile,
 			TLSProfile:               dialParams.TLSProfile,
 			NoDefaultTLSSessionID:    &dialParams.NoDefaultTLSSessionID,
 			NoDefaultTLSSessionID:    &dialParams.NoDefaultTLSSessionID,
 			RandomizedTLSProfileSeed: dialParams.RandomizedTLSProfileSeed,
 			RandomizedTLSProfileSeed: dialParams.RandomizedTLSProfileSeed,
+			FragmentClientHello:      dialParams.TLSFragmentClientHello,
 		},
 		},
 		// Obfuscated session tickets are not used because TLS-OSSH uses TLS 1.3.
 		// Obfuscated session tickets are not used because TLS-OSSH uses TLS 1.3.
 		UseObfuscatedSessionTickets: false,
 		UseObfuscatedSessionTickets: false,

+ 4 - 0
psiphon/meekConn.go

@@ -131,6 +131,9 @@ type MeekConfig struct {
 	// underlying TLS connections created by this meek connection.
 	// underlying TLS connections created by this meek connection.
 	TLSProfile string
 	TLSProfile string
 
 
+	// TLSFragmentClientHello specifies whether to fragment the TLS Client Hello.
+	TLSFragmentClientHello bool
+
 	// LegacyPassthrough indicates that the server expects a legacy passthrough
 	// LegacyPassthrough indicates that the server expects a legacy passthrough
 	// message.
 	// message.
 	LegacyPassthrough bool
 	LegacyPassthrough bool
@@ -458,6 +461,7 @@ func DialMeek(
 			RandomizedTLSProfileSeed:      meekConfig.RandomizedTLSProfileSeed,
 			RandomizedTLSProfileSeed:      meekConfig.RandomizedTLSProfileSeed,
 			TLSPadding:                    meek.tlsPadding,
 			TLSPadding:                    meek.tlsPadding,
 			TrustedCACertificatesFilename: dialConfig.TrustedCACertificatesFilename,
 			TrustedCACertificatesFilename: dialConfig.TrustedCACertificatesFilename,
+			FragmentClientHello:           meekConfig.TLSFragmentClientHello,
 		}
 		}
 		tlsConfig.EnableClientSessionCache()
 		tlsConfig.EnableClientSessionCache()
 
 

+ 25 - 20
psiphon/meekConn_test.go

@@ -72,30 +72,35 @@ func TestMeekModePlaintextRoundTrip(t *testing.T) {
 		CustomDialer:                  dialer,
 		CustomDialer:                  dialer,
 	}
 	}
 
 
-	ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
-	defer cancelFunc()
+	for _, tlsFragmentClientHello := range []bool{false, true} {
 
 
-	meekConn, err := DialMeek(ctx, meekConfig, dialConfig)
-	if err != nil {
-		t.Fatalf("DialMeek failed: %v", err)
-	}
+		ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
+		defer cancelFunc()
 
 
-	client := &http.Client{
-		Transport: meekConn,
-	}
+		meekConfig.TLSFragmentClientHello = tlsFragmentClientHello
 
 
-	response, err := client.Get("https://" + serverAddr + "/")
-	if err != nil {
-		t.Fatalf("http.Client.Get failed: %v", err)
-	}
-	response.Body.Close()
+		meekConn, err := DialMeek(ctx, meekConfig, dialConfig)
+		if err != nil {
+			t.Fatalf("DialMeek failed: %v", err)
+		}
 
 
-	if response.StatusCode != http.StatusOK {
-		t.Fatalf("unexpected response code: %v", response.StatusCode)
-	}
+		client := &http.Client{
+			Transport: meekConn,
+		}
 
 
-	err = meekConn.Close()
-	if err != nil {
-		t.Fatalf("MeekConn.Close failed: %v", err)
+		response, err := client.Get("https://" + serverAddr + "/")
+		if err != nil {
+			t.Fatalf("http.Client.Get failed: %v", err)
+		}
+		response.Body.Close()
+
+		if response.StatusCode != http.StatusOK {
+			t.Fatalf("unexpected response code: %v", response.StatusCode)
+		}
+
+		err = meekConn.Close()
+		if err != nil {
+			t.Fatalf("MeekConn.Close failed: %v", err)
+		}
 	}
 	}
 }
 }

+ 15 - 0
psiphon/net.go

@@ -474,6 +474,16 @@ func makeFrontedHTTPClient(
 		networkLatencyMultiplierMax,
 		networkLatencyMultiplierMax,
 		p.Float(parameters.NetworkLatencyMultiplierLambda))
 		p.Float(parameters.NetworkLatencyMultiplierLambda))
 
 
+	tlsFragmentClientHello := false
+	if meekSNIServerName != "" {
+		tlsFragmentorLimitProtocols := p.TunnelProtocols(parameters.TLSFragmentClientHelloLimitProtocols)
+		if len(tlsFragmentorLimitProtocols) == 0 || common.Contains(tlsFragmentorLimitProtocols, effectiveTunnelProtocol) {
+			if net.ParseIP(meekSNIServerName) == nil {
+				tlsFragmentClientHello = p.WeightedCoinFlip(parameters.TLSFragmentClientHelloProbability)
+			}
+		}
+	}
+
 	meekConfig := &MeekConfig{
 	meekConfig := &MeekConfig{
 		DiagnosticID:             frontingProviderID,
 		DiagnosticID:             frontingProviderID,
 		Parameters:               config.GetParameters(),
 		Parameters:               config.GetParameters(),
@@ -481,6 +491,7 @@ func makeFrontedHTTPClient(
 		DialAddress:              meekDialAddress,
 		DialAddress:              meekDialAddress,
 		UseHTTPS:                 true,
 		UseHTTPS:                 true,
 		TLSProfile:               tlsProfile,
 		TLSProfile:               tlsProfile,
+		TLSFragmentClientHello:   tlsFragmentClientHello,
 		NoDefaultTLSSessionID:    noDefaultTLSSessionID,
 		NoDefaultTLSSessionID:    noDefaultTLSSessionID,
 		RandomizedTLSProfileSeed: randomizedTLSProfileSeed,
 		RandomizedTLSProfileSeed: randomizedTLSProfileSeed,
 		SNIServerName:            meekSNIServerName,
 		SNIServerName:            meekSNIServerName,
@@ -603,6 +614,10 @@ func makeFrontedHTTPClient(
 			params["tls_version"] = getTLSVersionForMetrics(tlsVersion, meekConfig.NoDefaultTLSSessionID)
 			params["tls_version"] = getTLSVersionForMetrics(tlsVersion, meekConfig.NoDefaultTLSSessionID)
 		}
 		}
 
 
+		if meekConfig.TLSFragmentClientHello {
+			params["tls_fragmented"] = "1"
+		}
+
 		return params
 		return params
 	}
 	}
 
 

+ 4 - 0
psiphon/notice.go

@@ -513,6 +513,10 @@ func noticeWithDialParameters(noticeType string, dialParams *DialParameters, pos
 			args = append(args, "tlsOSSHTransformedSNIServerName", dialParams.TLSOSSHTransformedSNIServerName)
 			args = append(args, "tlsOSSHTransformedSNIServerName", dialParams.TLSOSSHTransformedSNIServerName)
 		}
 		}
 
 
+		if dialParams.TLSFragmentClientHello {
+			args = append(args, "tlsFragmentClientHello", dialParams.TLSFragmentClientHello)
+		}
+
 		if dialParams.SelectedUserAgent {
 		if dialParams.SelectedUserAgent {
 			args = append(args, "userAgent", dialParams.UserAgent)
 			args = append(args, "userAgent", dialParams.UserAgent)
 		}
 		}

+ 4 - 1
psiphon/server/api.go

@@ -541,7 +541,9 @@ var remoteServerListStatParams = append(
 		{"meek_transformed_host_name", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
 		{"meek_transformed_host_name", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
 		{"user_agent", isAnyString, requestParamOptional},
 		{"user_agent", isAnyString, requestParamOptional},
 		{"tls_profile", isAnyString, requestParamOptional},
 		{"tls_profile", isAnyString, requestParamOptional},
-		{"tls_version", isAnyString, requestParamOptional}},
+		{"tls_version", isAnyString, requestParamOptional},
+		{"tls_fragmented", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
+	},
 
 
 	baseSessionParams...)
 	baseSessionParams...)
 
 
@@ -950,6 +952,7 @@ var baseDialParams = []requestParamSpec{
 	{"http_transform", isAnyString, requestParamOptional},
 	{"http_transform", isAnyString, requestParamOptional},
 	{"seed_transform", isAnyString, requestParamOptional},
 	{"seed_transform", isAnyString, requestParamOptional},
 	{"ossh_prefix", isAnyString, requestParamOptional},
 	{"ossh_prefix", isAnyString, requestParamOptional},
+	{"tls_fragmented", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
 }
 }
 
 
 // baseSessionAndDialParams adds baseDialParams to baseSessionParams.
 // baseSessionAndDialParams adds baseDialParams to baseSessionParams.

+ 4 - 0
psiphon/serverApi.go

@@ -1007,6 +1007,10 @@ func getBaseAPIParameters(
 			params["tls_ossh_transformed_host_name"] = "1"
 			params["tls_ossh_transformed_host_name"] = "1"
 		}
 		}
 
 
+		if dialParams.TLSFragmentClientHello {
+			params["tls_fragmented"] = "1"
+		}
+
 		if dialParams.SelectedUserAgent {
 		if dialParams.SelectedUserAgent {
 			params["user_agent"] = dialParams.UserAgent
 			params["user_agent"] = dialParams.UserAgent
 		}
 		}

+ 233 - 0
psiphon/tlsDialer.go

@@ -57,11 +57,14 @@ import (
 	"crypto/sha256"
 	"crypto/sha256"
 	"crypto/x509"
 	"crypto/x509"
 	"encoding/base64"
 	"encoding/base64"
+	"encoding/binary"
 	"encoding/hex"
 	"encoding/hex"
 	std_errors "errors"
 	std_errors "errors"
+	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"math"
 	"math"
 	"net"
 	"net"
+	"sync/atomic"
 
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
@@ -178,6 +181,9 @@ type CustomTLSConfig struct {
 	// obfuscator.MakeTLSPassthroughMessage.
 	// obfuscator.MakeTLSPassthroughMessage.
 	PassthroughMessage []byte
 	PassthroughMessage []byte
 
 
+	// FragmentClientHello specifies whether to fragment the ClientHello.
+	FragmentClientHello bool
+
 	clientSessionCache utls.ClientSessionCache
 	clientSessionCache utls.ClientSessionCache
 }
 }
 
 
@@ -236,6 +242,10 @@ func CustomTLSDial(
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
 	}
 	}
 
 
+	if config.FragmentClientHello {
+		rawConn = NewTLSFragmentorConn(rawConn)
+	}
+
 	hostname, _, err := net.SplitHostPort(dialAddr)
 	hostname, _, err := net.SplitHostPort(dialAddr)
 	if err != nil {
 	if err != nil {
 		rawConn.Close()
 		rawConn.Close()
@@ -1010,3 +1020,226 @@ func init() {
 	// downloads, don't depend on this TLS for its security properties.
 	// downloads, don't depend on this TLS for its security properties.
 	utls.EnableWeakCiphers()
 	utls.EnableWeakCiphers()
 }
 }
+
+type TLSFragmentorConn struct {
+	net.Conn
+	clientHelloSent int32
+}
+
+func NewTLSFragmentorConn(
+	conn net.Conn,
+) net.Conn {
+	return &TLSFragmentorConn{
+		Conn: conn,
+	}
+}
+
+func (c *TLSFragmentorConn) Close() error {
+	return c.Conn.Close()
+}
+
+func (c *TLSFragmentorConn) Read(b []byte) (n int, err error) {
+	return c.Conn.Read(b)
+}
+
+// Write transparently splits the first TLS record containing ClientHello into
+// two fragments and writes them separately to the underlying conn.
+// The second fragment contains the data portion of the SNI extension (i.e. the server name).
+// Write assumes a non-fragmented and complete ClientHello on the first call.
+func (c *TLSFragmentorConn) Write(b []byte) (n int, err error) {
+
+	if atomic.LoadInt32(&c.clientHelloSent) == 0 {
+
+		buf := bytes.NewReader(b)
+
+		var contentType uint8
+		err := binary.Read(buf, binary.BigEndian, &contentType)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+		if contentType != 0x16 {
+			return 0, errors.TraceNew("expected Handshake content type")
+		}
+
+		var version uint16
+		err = binary.Read(buf, binary.BigEndian, &version)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+		if version != 0x0303 && version != 0x0302 && version != 0x0301 {
+			return 0, errors.TraceNew("expected TLS version 0x0303 or 0x0302 or 0x0301")
+		}
+
+		var msgLen uint16
+		err = binary.Read(buf, binary.BigEndian, &msgLen)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+		if len(b) != int(msgLen)+5 {
+			return 0, errors.TraceNew("unexpected TLS message length")
+		}
+
+		var handshakeType uint8
+		err = binary.Read(buf, binary.BigEndian, &handshakeType)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+		if handshakeType != 0x01 {
+			return 0, errors.TraceNew("expected ClientHello(1) handshake type")
+		}
+
+		var handshakeLen uint32
+		err = binary.Read(buf, binary.BigEndian, &handshakeLen)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+		handshakeLen >>= 8 // 24-bit value
+		buf.UnreadByte()   // Unread the last byte
+
+		var legacyVersion uint16
+		err = binary.Read(buf, binary.BigEndian, &legacyVersion)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+		if legacyVersion != 0x0303 {
+			return 0, errors.TraceNew("expected TLS version 0x0303")
+		}
+
+		// Skip random
+		_, err = buf.Seek(32, io.SeekCurrent)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+
+		var sessionIdLen uint8
+		err = binary.Read(buf, binary.BigEndian, &sessionIdLen)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+		if sessionIdLen > 32 {
+			return 0, errors.TraceNew("unexpected session ID length")
+		}
+
+		// Skip session ID
+		_, err = buf.Seek(int64(sessionIdLen), io.SeekCurrent)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+
+		var cipherSuitesLen uint16
+		err = binary.Read(buf, binary.BigEndian, &cipherSuitesLen)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+		if cipherSuitesLen < 2 || cipherSuitesLen > 65535 {
+			return 0, errors.TraceNew("unexpected cipher suites length")
+		}
+
+		// Skip cipher suites
+		_, err = buf.Seek(int64(cipherSuitesLen), io.SeekCurrent)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+
+		var compressionMethodsLen int8
+		err = binary.Read(buf, binary.BigEndian, &compressionMethodsLen)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+		if compressionMethodsLen < 1 || compressionMethodsLen > 32 {
+			return 0, errors.TraceNew("unexpected compression methods length")
+		}
+
+		// Skip compression methods
+		_, err = buf.Seek(int64(compressionMethodsLen), io.SeekCurrent)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+
+		var extensionsLen uint16
+		err = binary.Read(buf, binary.BigEndian, &extensionsLen)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+		if extensionsLen < 2 || extensionsLen > 65535 {
+			return 0, errors.TraceNew("unexpected extensions length")
+		}
+
+		// Finds SNI extension.
+		for {
+			if buf.Len() == 0 {
+				return 0, errors.TraceNew("missing SNI extension")
+			}
+
+			var extensionType uint16
+			err = binary.Read(buf, binary.BigEndian, &extensionType)
+			if err != nil {
+				return 0, errors.Trace(err)
+			}
+
+			var extensionLen uint16
+			err = binary.Read(buf, binary.BigEndian, &extensionLen)
+			if err != nil {
+				return 0, errors.Trace(err)
+			}
+
+			// server_name(0) extension type
+			if extensionType == 0x0000 {
+				break
+			}
+
+			// Skip extension data
+			_, err = buf.Seek(int64(extensionLen), io.SeekCurrent)
+			if err != nil {
+				return 0, errors.Trace(err)
+			}
+		}
+
+		sniStartIndex := len(b) - buf.Len()
+
+		// Splits the ClientHello message into two fragments at sniStartIndex,
+		// and writes them separately to the underlying conn.
+		tlsMessage := b[5:]
+		frag1, frag2, err := splitTLSMessage(contentType, version, tlsMessage, sniStartIndex)
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+		n, err = c.Conn.Write(frag1)
+		if err != nil {
+			return n, errors.Trace(err)
+		}
+		n2, err := c.Conn.Write(frag2)
+		if err != nil {
+			return n + n2, errors.Trace(err)
+		}
+
+		atomic.CompareAndSwapInt32(&c.clientHelloSent, 0, 1)
+
+		return len(b), nil
+	}
+
+	return c.Conn.Write(b)
+}
+
+// splitTLSMessage splits a TLS message into two fragments.
+// The two fragments are wrapped in TLS records.
+func splitTLSMessage(contentType uint8, version uint16, msg []byte, splitIndex int) ([]byte, []byte, error) {
+	if splitIndex > len(msg)-1 {
+		return nil, nil, errors.TraceNew("split index out of range")
+	}
+
+	frag1 := make([]byte, splitIndex+5)
+	frag2 := make([]byte, len(msg)-splitIndex+5)
+
+	frag1[0] = byte(contentType)
+	binary.BigEndian.PutUint16(frag1[1:3], version)
+	binary.BigEndian.PutUint16(frag1[3:5], uint16(splitIndex))
+	copy(frag1[5:], msg[:splitIndex])
+
+	frag2[0] = byte(contentType)
+	binary.BigEndian.PutUint16(frag2[1:3], version)
+	binary.BigEndian.PutUint16(frag2[3:5], uint16(len(msg)-splitIndex))
+	copy(frag2[5:], msg[splitIndex:])
+
+	return frag1, frag2, nil
+}

+ 78 - 11
psiphon/tlsDialer_test.go

@@ -418,7 +418,8 @@ func initTestCertificatesAndWebServer(
 
 
 func TestTLSDialerCompatibility(t *testing.T) {
 func TestTLSDialerCompatibility(t *testing.T) {
 
 
-	// This test checks that each TLS profile can successfully complete a TLS
+	// This test checks that each TLS profile in combination with TLS ClientHello
+	// fragmentation can successfully complete a TLS
 	// handshake with various servers. By default, only the "psiphon" case is
 	// handshake with various servers. By default, only the "psiphon" case is
 	// run, which runs the same TLS listener used by a Psiphon server.
 	// run, which runs the same TLS listener used by a Psiphon server.
 	//
 	//
@@ -432,22 +433,25 @@ func TestTLSDialerCompatibility(t *testing.T) {
 		configAddresses = strings.Split(string(config), "\n")
 		configAddresses = strings.Split(string(config), "\n")
 	}
 	}
 
 
-	runner := func(address string) func(t *testing.T) {
+	runner := func(address string, fragmentClientHello bool) func(t *testing.T) {
 		return func(t *testing.T) {
 		return func(t *testing.T) {
-			testTLSDialerCompatibility(t, address)
+			testTLSDialerCompatibility(t, address, fragmentClientHello)
 		}
 		}
 	}
 	}
 
 
 	for _, address := range configAddresses {
 	for _, address := range configAddresses {
-		if len(address) > 0 {
-			t.Run(address, runner(address))
+		for _, fragmentClientHello := range []bool{false, true} {
+			if len(address) > 0 {
+				t.Run(fmt.Sprintf("%s (fragmentClientHello: %v)", address, fragmentClientHello),
+					runner(address, fragmentClientHello))
+			}
 		}
 		}
 	}
 	}
 
 
-	t.Run("psiphon", runner(""))
+	t.Run("psiphon", runner("", false))
 }
 }
 
 
-func testTLSDialerCompatibility(t *testing.T, address string) {
+func testTLSDialerCompatibility(t *testing.T, address string, fragmentClientHello bool) {
 
 
 	if address == "" {
 	if address == "" {
 
 
@@ -519,10 +523,11 @@ func testTLSDialerCompatibility(t *testing.T, address string) {
 			transformHostname := i%2 == 0
 			transformHostname := i%2 == 0
 
 
 			tlsConfig := &CustomTLSConfig{
 			tlsConfig := &CustomTLSConfig{
-				Parameters: params,
-				Dial:       dialer,
-				SkipVerify: true,
-				TLSProfile: tlsProfile,
+				Parameters:          params,
+				Dial:                dialer,
+				SkipVerify:          true,
+				TLSProfile:          tlsProfile,
+				FragmentClientHello: fragmentClientHello,
 			}
 			}
 
 
 			if transformHostname {
 			if transformHostname {
@@ -760,6 +765,68 @@ func TestSelectTLSProfile(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestTLSFragmentorWithoutSNI(t *testing.T) {
+	testDataDirName, err := ioutil.TempDir("", "psiphon-tls-certificate-verification-test")
+	if err != nil {
+		t.Fatalf("TempDir failed: %v", err)
+	}
+	defer os.RemoveAll(testDataDirName)
+
+	serverName := "example.org"
+
+	rootCAsFileName,
+		_,
+		serverCertificatePin,
+		shutdown,
+		serverAddr,
+		dialer := initTestCertificatesAndWebServer(
+		t, testDataDirName, serverName)
+	defer shutdown()
+
+	params, err := parameters.NewParameters(nil)
+	if err != nil {
+		t.Fatalf("parameters.NewParameters failed: %v", err)
+	}
+
+	// Test: missing SNI, the TLS dial fails
+
+	conn, err := CustomTLSDial(
+		context.Background(), "tcp", serverAddr,
+		&CustomTLSConfig{
+			Parameters:                    params,
+			Dial:                          dialer,
+			SNIServerName:                 "",
+			VerifyServerName:              serverName,
+			VerifyPins:                    []string{serverCertificatePin},
+			TrustedCACertificatesFilename: rootCAsFileName,
+		})
+
+	if err == nil {
+		t.Errorf("unexpected success without SNI")
+		conn.Close()
+	}
+
+	// Test: with SNI, the TLS dial succeeds
+
+	conn, err = CustomTLSDial(
+		context.Background(), "tcp", serverAddr,
+		&CustomTLSConfig{
+			Parameters:                    params,
+			Dial:                          dialer,
+			SNIServerName:                 serverName,
+			VerifyServerName:              serverName,
+			VerifyPins:                    []string{serverCertificatePin},
+			TrustedCACertificatesFilename: rootCAsFileName,
+		})
+
+	if err != nil {
+		t.Errorf("CustomTLSDial failed: %v", err)
+	} else {
+		conn.Close()
+	}
+
+}
+
 func BenchmarkRandomizedGetClientHelloVersion(b *testing.B) {
 func BenchmarkRandomizedGetClientHelloVersion(b *testing.B) {
 	for n := 0; n < b.N; n++ {
 	for n := 0; n < b.N; n++ {
 		utlsClientHelloID := utls.HelloRandomized
 		utlsClientHelloID := utls.HelloRandomized

+ 1 - 0
psiphon/tlsTunnelConn.go

@@ -86,6 +86,7 @@ func DialTLSTunnel(
 		RandomizedTLSProfileSeed:      tlsTunnelConfig.CustomTLSConfig.RandomizedTLSProfileSeed,
 		RandomizedTLSProfileSeed:      tlsTunnelConfig.CustomTLSConfig.RandomizedTLSProfileSeed,
 		TLSPadding:                    tlsPadding,
 		TLSPadding:                    tlsPadding,
 		TrustedCACertificatesFilename: dialConfig.TrustedCACertificatesFilename,
 		TrustedCACertificatesFilename: dialConfig.TrustedCACertificatesFilename,
+		FragmentClientHello:           tlsTunnelConfig.CustomTLSConfig.FragmentClientHello,
 	}
 	}
 	tlsConfig.EnableClientSessionCache()
 	tlsConfig.EnableClientSessionCache()