|
|
@@ -57,11 +57,14 @@ import (
|
|
|
"crypto/sha256"
|
|
|
"crypto/x509"
|
|
|
"encoding/base64"
|
|
|
+ "encoding/binary"
|
|
|
"encoding/hex"
|
|
|
std_errors "errors"
|
|
|
+ "io"
|
|
|
"io/ioutil"
|
|
|
"math"
|
|
|
"net"
|
|
|
+ "sync/atomic"
|
|
|
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
|
|
|
@@ -178,6 +181,9 @@ type CustomTLSConfig struct {
|
|
|
// obfuscator.MakeTLSPassthroughMessage.
|
|
|
PassthroughMessage []byte
|
|
|
|
|
|
+ // FragmentClientHello specifies whether to fragment the ClientHello.
|
|
|
+ FragmentClientHello bool
|
|
|
+
|
|
|
clientSessionCache utls.ClientSessionCache
|
|
|
}
|
|
|
|
|
|
@@ -236,6 +242,10 @@ func CustomTLSDial(
|
|
|
return nil, errors.Trace(err)
|
|
|
}
|
|
|
|
|
|
+ if config.FragmentClientHello {
|
|
|
+ rawConn = NewTLSFragmentorConn(rawConn)
|
|
|
+ }
|
|
|
+
|
|
|
hostname, _, err := net.SplitHostPort(dialAddr)
|
|
|
if err != nil {
|
|
|
rawConn.Close()
|
|
|
@@ -1010,3 +1020,226 @@ func init() {
|
|
|
// downloads, don't depend on this TLS for its security properties.
|
|
|
utls.EnableWeakCiphers()
|
|
|
}
|
|
|
+
|
|
|
+type TLSFragmentorConn struct {
|
|
|
+ net.Conn
|
|
|
+ clientHelloSent int32
|
|
|
+}
|
|
|
+
|
|
|
+func NewTLSFragmentorConn(
|
|
|
+ conn net.Conn,
|
|
|
+) net.Conn {
|
|
|
+ return &TLSFragmentorConn{
|
|
|
+ Conn: conn,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (c *TLSFragmentorConn) Close() error {
|
|
|
+ return c.Conn.Close()
|
|
|
+}
|
|
|
+
|
|
|
+func (c *TLSFragmentorConn) Read(b []byte) (n int, err error) {
|
|
|
+ return c.Conn.Read(b)
|
|
|
+}
|
|
|
+
|
|
|
+// Write transparently splits the first TLS record containing ClientHello into
|
|
|
+// two fragments and writes them separately to the underlying conn.
|
|
|
+// The second fragment contains the data portion of the SNI extension (i.e. the server name).
|
|
|
+// Write assumes a non-fragmented and complete ClientHello on the first call.
|
|
|
+func (c *TLSFragmentorConn) Write(b []byte) (n int, err error) {
|
|
|
+
|
|
|
+ if atomic.LoadInt32(&c.clientHelloSent) == 0 {
|
|
|
+
|
|
|
+ buf := bytes.NewReader(b)
|
|
|
+
|
|
|
+ var contentType uint8
|
|
|
+ err := binary.Read(buf, binary.BigEndian, &contentType)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ if contentType != 0x16 {
|
|
|
+ return 0, errors.TraceNew("expected Handshake content type")
|
|
|
+ }
|
|
|
+
|
|
|
+ var version uint16
|
|
|
+ err = binary.Read(buf, binary.BigEndian, &version)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ if version != 0x0303 && version != 0x0302 && version != 0x0301 {
|
|
|
+ return 0, errors.TraceNew("expected TLS version 0x0303 or 0x0302 or 0x0301")
|
|
|
+ }
|
|
|
+
|
|
|
+ var msgLen uint16
|
|
|
+ err = binary.Read(buf, binary.BigEndian, &msgLen)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ if len(b) != int(msgLen)+5 {
|
|
|
+ return 0, errors.TraceNew("unexpected TLS message length")
|
|
|
+ }
|
|
|
+
|
|
|
+ var handshakeType uint8
|
|
|
+ err = binary.Read(buf, binary.BigEndian, &handshakeType)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ if handshakeType != 0x01 {
|
|
|
+ return 0, errors.TraceNew("expected ClientHello(1) handshake type")
|
|
|
+ }
|
|
|
+
|
|
|
+ var handshakeLen uint32
|
|
|
+ err = binary.Read(buf, binary.BigEndian, &handshakeLen)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ handshakeLen >>= 8 // 24-bit value
|
|
|
+ buf.UnreadByte() // Unread the last byte
|
|
|
+
|
|
|
+ var legacyVersion uint16
|
|
|
+ err = binary.Read(buf, binary.BigEndian, &legacyVersion)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ if legacyVersion != 0x0303 {
|
|
|
+ return 0, errors.TraceNew("expected TLS version 0x0303")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Skip random
|
|
|
+ _, err = buf.Seek(32, io.SeekCurrent)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ var sessionIdLen uint8
|
|
|
+ err = binary.Read(buf, binary.BigEndian, &sessionIdLen)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ if sessionIdLen > 32 {
|
|
|
+ return 0, errors.TraceNew("unexpected session ID length")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Skip session ID
|
|
|
+ _, err = buf.Seek(int64(sessionIdLen), io.SeekCurrent)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ var cipherSuitesLen uint16
|
|
|
+ err = binary.Read(buf, binary.BigEndian, &cipherSuitesLen)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ if cipherSuitesLen < 2 || cipherSuitesLen > 65535 {
|
|
|
+ return 0, errors.TraceNew("unexpected cipher suites length")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Skip cipher suites
|
|
|
+ _, err = buf.Seek(int64(cipherSuitesLen), io.SeekCurrent)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ var compressionMethodsLen int8
|
|
|
+ err = binary.Read(buf, binary.BigEndian, &compressionMethodsLen)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ if compressionMethodsLen < 1 || compressionMethodsLen > 32 {
|
|
|
+ return 0, errors.TraceNew("unexpected compression methods length")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Skip compression methods
|
|
|
+ _, err = buf.Seek(int64(compressionMethodsLen), io.SeekCurrent)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ var extensionsLen uint16
|
|
|
+ err = binary.Read(buf, binary.BigEndian, &extensionsLen)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ if extensionsLen < 2 || extensionsLen > 65535 {
|
|
|
+ return 0, errors.TraceNew("unexpected extensions length")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Finds SNI extension.
|
|
|
+ for {
|
|
|
+ if buf.Len() == 0 {
|
|
|
+ return 0, errors.TraceNew("missing SNI extension")
|
|
|
+ }
|
|
|
+
|
|
|
+ var extensionType uint16
|
|
|
+ err = binary.Read(buf, binary.BigEndian, &extensionType)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ var extensionLen uint16
|
|
|
+ err = binary.Read(buf, binary.BigEndian, &extensionLen)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // server_name(0) extension type
|
|
|
+ if extensionType == 0x0000 {
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ // Skip extension data
|
|
|
+ _, err = buf.Seek(int64(extensionLen), io.SeekCurrent)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ sniStartIndex := len(b) - buf.Len()
|
|
|
+
|
|
|
+ // Splits the ClientHello message into two fragments at sniStartIndex,
|
|
|
+ // and writes them separately to the underlying conn.
|
|
|
+ tlsMessage := b[5:]
|
|
|
+ frag1, frag2, err := splitTLSMessage(contentType, version, tlsMessage, sniStartIndex)
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ n, err = c.Conn.Write(frag1)
|
|
|
+ if err != nil {
|
|
|
+ return n, errors.Trace(err)
|
|
|
+ }
|
|
|
+ n2, err := c.Conn.Write(frag2)
|
|
|
+ if err != nil {
|
|
|
+ return n + n2, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ atomic.CompareAndSwapInt32(&c.clientHelloSent, 0, 1)
|
|
|
+
|
|
|
+ return len(b), nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return c.Conn.Write(b)
|
|
|
+}
|
|
|
+
|
|
|
+// splitTLSMessage splits a TLS message into two fragments.
|
|
|
+// The two fragments are wrapped in TLS records.
|
|
|
+func splitTLSMessage(contentType uint8, version uint16, msg []byte, splitIndex int) ([]byte, []byte, error) {
|
|
|
+ if splitIndex > len(msg)-1 {
|
|
|
+ return nil, nil, errors.TraceNew("split index out of range")
|
|
|
+ }
|
|
|
+
|
|
|
+ frag1 := make([]byte, splitIndex+5)
|
|
|
+ frag2 := make([]byte, len(msg)-splitIndex+5)
|
|
|
+
|
|
|
+ frag1[0] = byte(contentType)
|
|
|
+ binary.BigEndian.PutUint16(frag1[1:3], version)
|
|
|
+ binary.BigEndian.PutUint16(frag1[3:5], uint16(splitIndex))
|
|
|
+ copy(frag1[5:], msg[:splitIndex])
|
|
|
+
|
|
|
+ frag2[0] = byte(contentType)
|
|
|
+ binary.BigEndian.PutUint16(frag2[1:3], version)
|
|
|
+ binary.BigEndian.PutUint16(frag2[3:5], uint16(len(msg)-splitIndex))
|
|
|
+ copy(frag2[5:], msg[splitIndex:])
|
|
|
+
|
|
|
+ return frag1, frag2, nil
|
|
|
+}
|