Browse Source

Add time factor to passthrough messages

Rod Hynes 4 years ago
parent
commit
2dcc641138

+ 21 - 6
psiphon/common/obfuscator/history.go

@@ -106,18 +106,32 @@ func NewSeedHistory(config *SeedHistoryConfig) *SeedHistory {
 	}
 }
 
-// AddNew adds a new seed value to the history. If the seed value is already
-// in the history, and an expected case such as a meek retry is ruled out (or
-// strictMode is on), AddNew returns false.
+// AddNew calls AddNewWithTTL using the SeedTTL that was specified in the
+// SeedHistoryConfig.
+func (h *SeedHistory) AddNew(
+	strictMode bool,
+	clientIP string,
+	seedType string,
+	seed []byte) (bool, *common.LogFields) {
+
+	return h.AddNewWithTTL(
+		strictMode, clientIP, seedType, seed, lrucache.DefaultExpiration)
+}
+
+// AddNewWithTTL adds a new seed value to the history, set to expire with the
+// specified TTL. If the seed value is already in the history, and an expected
+// case such as a meek retry is ruled out (or strictMode is on), AddNew
+// returns false.
 //
 // When a duplicate seed is found, a common.LogFields instance is returned,
 // populated with event data. Log fields may be returned in either the false
 // or true case.
-func (h *SeedHistory) AddNew(
+func (h *SeedHistory) AddNewWithTTL(
 	strictMode bool,
 	clientIP string,
 	seedType string,
-	seed []byte) (bool, *common.LogFields) {
+	seed []byte,
+	TTL time.Duration) (bool, *common.LogFields) {
 
 	key := string(seed)
 
@@ -126,8 +140,9 @@ func (h *SeedHistory) AddNew(
 	// an unlikely possibility that this Add and the following Get don't see the
 	// same existing key/value state.
 
-	if h.seedToTime.Add(key, time.Now(), lrucache.DefaultExpiration) == nil {
+	if h.seedToTime.Add(key, time.Now(), TTL) == nil {
 		// Seed was not already in cache
+		// TODO: if TTL < SeedHistory.ClientIPTTL, use the shorter TTL here
 		h.seedToClientIP.Set(key, clientIP, lrucache.DefaultExpiration)
 		return true, nil
 	}

+ 69 - 28
psiphon/common/obfuscator/passthrough.go

@@ -24,7 +24,9 @@ import (
 	"crypto/rand"
 	"crypto/sha256"
 	"crypto/subtle"
+	"encoding/binary"
 	"io"
+	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"golang.org/x/crypto/hkdf"
@@ -33,42 +35,27 @@ import (
 const (
 	TLS_PASSTHROUGH_NONCE_SIZE   = 16
 	TLS_PASSTHROUGH_KEY_SIZE     = 32
+	TLS_PASSTHROUGH_TIME_PERIOD  = 15 * time.Minute
 	TLS_PASSTHROUGH_MESSAGE_SIZE = 32
 )
 
-// DeriveTLSPassthroughKey derives a TLS passthrough key from a master
-// obfuscated key. The resulting key can be cached and passed to
-// VerifyTLSPassthroughMessage.
-func DeriveTLSPassthroughKey(obfuscatedKey string) ([]byte, error) {
-
-	secret := []byte(obfuscatedKey)
-
-	salt := []byte("passthrough-obfuscation-key")
-
-	key := make([]byte, TLS_PASSTHROUGH_KEY_SIZE)
-
-	_, err := io.ReadFull(hkdf.New(sha256.New, secret, salt, nil), key)
-	if err != nil {
-		return nil, errors.Trace(err)
-	}
-
-	return key, nil
-}
-
 // MakeTLSPassthroughMessage generates a unique TLS passthrough message
 // using the passthrough key derived from a master obfuscated key.
 //
 // The passthrough message demonstrates knowledge of the obfuscated key.
-func MakeTLSPassthroughMessage(obfuscatedKey string) ([]byte, error) {
+// When useTimeFactor is set, the message will also reflect the current
+// time period, limiting how long it remains valid.
+//
+// The configurable useTimeFactor enables support for legacy clients and
+// servers which don't use the time factor.
+func MakeTLSPassthroughMessage(
+	useTimeFactor bool, obfuscatedKey string) ([]byte, error) {
 
-	passthroughKey, err := DeriveTLSPassthroughKey(obfuscatedKey)
-	if err != nil {
-		return nil, errors.Trace(err)
-	}
+	passthroughKey := derivePassthroughKey(useTimeFactor, obfuscatedKey)
 
 	message := make([]byte, TLS_PASSTHROUGH_MESSAGE_SIZE)
 
-	_, err = rand.Read(message[0:TLS_PASSTHROUGH_NONCE_SIZE])
+	_, err := rand.Read(message[0:TLS_PASSTHROUGH_NONCE_SIZE])
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
@@ -82,12 +69,22 @@ func MakeTLSPassthroughMessage(obfuscatedKey string) ([]byte, error) {
 
 // VerifyTLSPassthroughMessage checks that the specified passthrough message
 // was generated using the passthrough key.
-func VerifyTLSPassthroughMessage(passthroughKey, message []byte) bool {
-
+//
+// useTimeFactor must be set to the same value used in
+// MakeTLSPassthroughMessage.
+func VerifyTLSPassthroughMessage(
+	useTimeFactor bool, obfuscatedKey string, message []byte) bool {
+
+	// If the message is the wrong length, continue processing with a stub
+	// message of the correct length. This is to avoid leaking the existence of
+	// passthrough via timing differences.
 	if len(message) != TLS_PASSTHROUGH_MESSAGE_SIZE {
-		return false
+		var stub [TLS_PASSTHROUGH_MESSAGE_SIZE]byte
+		message = stub[:]
 	}
 
+	passthroughKey := derivePassthroughKey(useTimeFactor, obfuscatedKey)
+
 	h := hmac.New(sha256.New, passthroughKey)
 	h.Write(message[0:TLS_PASSTHROUGH_NONCE_SIZE])
 
@@ -95,3 +92,47 @@ func VerifyTLSPassthroughMessage(passthroughKey, message []byte) bool {
 		message[TLS_PASSTHROUGH_NONCE_SIZE:],
 		h.Sum(nil)[0:TLS_PASSTHROUGH_MESSAGE_SIZE-TLS_PASSTHROUGH_NONCE_SIZE])
 }
+
+// timePeriodSeconds is variable, to enable overriding the value in
+// TestTLSPassthrough. This value should not be overridden outside of test
+// cases.
+var timePeriodSeconds = int64(TLS_PASSTHROUGH_TIME_PERIOD / time.Second)
+
+func derivePassthroughKey(
+	useTimeFactor bool, obfuscatedKey string) []byte {
+
+	secret := []byte(obfuscatedKey)
+
+	salt := []byte("passthrough-obfuscation-key")
+
+	if useTimeFactor {
+
+		// Include a time factor, so messages created with this key remain valid
+		// only for a limited time period. The current time is rounded, allowing the
+		// client clock to be slightly ahead of or behind of the server clock.
+		//
+		// This time factor mechanism is used in concert with SeedHistory to detect
+		// passthrough message replay. SeedHistory, a history of recent passthrough
+		// messages, is used to detect duplicate passthrough messages. The time
+		// factor bounds the necessary history length: passthrough messages older
+		// than the time period no longer need to be retained in history.
+		//
+		// We _always_ derive the passthrough key for each
+		// MakeTLSPassthroughMessage, even for multiple calls in the same time
+		// factor period, to avoid leaking the presense of passthough via timing
+		// differences at time boundaries. We assume that the server always or never
+		// sets useTimeFactor.
+
+		roundedTimePeriod := (time.Now().Unix() + (timePeriodSeconds / 2)) / timePeriodSeconds
+
+		var timeFactor [8]byte
+		binary.LittleEndian.PutUint64(timeFactor[:], uint64(roundedTimePeriod))
+		salt = append(salt, timeFactor[:]...)
+	}
+
+	key := make([]byte, TLS_PASSTHROUGH_KEY_SIZE)
+
+	_, _ = io.ReadFull(hkdf.New(sha256.New, secret, salt, nil), key)
+
+	return key
+}

+ 92 - 24
psiphon/common/obfuscator/passthrough_test.go

@@ -21,43 +21,111 @@ package obfuscator
 
 import (
 	"bytes"
+	"fmt"
 	"testing"
+	"time"
 )
 
 func TestTLSPassthrough(t *testing.T) {
 
+	// Use artificially low time factor period for test
+	timePeriodSeconds = 2
+
 	correctMasterKey := "correct-master-key"
 	incorrectMasterKey := "incorrect-master-key"
 
-	passthroughKey, err := DeriveTLSPassthroughKey(correctMasterKey)
-	if err != nil {
-		t.Fatalf("DeriveTLSPassthroughKey failed: %s", err)
-	}
+	for _, useTimeFactor := range []bool{false, true} {
 
-	validMessage, err := MakeTLSPassthroughMessage(correctMasterKey)
-	if err != nil {
-		t.Fatalf("MakeTLSPassthroughMessage failed: %s", err)
-	}
+		t.Run(fmt.Sprintf("useTimeFactor: %v", useTimeFactor), func(t *testing.T) {
 
-	if !VerifyTLSPassthroughMessage(passthroughKey, validMessage) {
-		t.Fatalf("unexpected invalid passthrough messages")
-	}
+			// test: valid passthrough message
 
-	anotherValidMessage, err := MakeTLSPassthroughMessage(correctMasterKey)
-	if err != nil {
-		t.Fatalf("MakeTLSPassthroughMessage failed: %s", err)
-	}
+			validMessage, err := MakeTLSPassthroughMessage(useTimeFactor, correctMasterKey)
+			if err != nil {
+				t.Fatalf("MakeTLSPassthroughMessage failed: %s", err)
+			}
 
-	if bytes.Equal(validMessage, anotherValidMessage) {
-		t.Fatalf("unexpected identical passthrough messages")
-	}
+			startTime := time.Now()
 
-	invalidMessage, err := MakeTLSPassthroughMessage(incorrectMasterKey)
-	if err != nil {
-		t.Fatalf("MakeTLSPassthroughMessage failed: %s", err)
-	}
+			if !VerifyTLSPassthroughMessage(useTimeFactor, correctMasterKey, validMessage) {
+				t.Fatalf("unexpected invalid passthrough message")
+			}
+
+			correctElapsedTime := time.Now().Sub(startTime)
+
+			// test: passthrough messages are not identical
+
+			anotherValidMessage, err := MakeTLSPassthroughMessage(useTimeFactor, correctMasterKey)
+			if err != nil {
+				t.Fatalf("MakeTLSPassthroughMessage failed: %s", err)
+			}
+
+			if bytes.Equal(validMessage, anotherValidMessage) {
+				t.Fatalf("unexpected identical passthrough messages")
+			}
+
+			// test: valid passthrough message still valid within time factor period
+
+			time.Sleep(1 * time.Millisecond)
+
+			if !VerifyTLSPassthroughMessage(useTimeFactor, correctMasterKey, validMessage) {
+				t.Fatalf("unexpected invalid delayed passthrough message")
+			}
+
+			// test: valid passthrough message now invalid after time factor period
+
+			time.Sleep(time.Duration(timePeriodSeconds)*time.Second + time.Millisecond)
+
+			verified := VerifyTLSPassthroughMessage(useTimeFactor, correctMasterKey, validMessage)
+
+			if verified && useTimeFactor {
+				t.Fatalf("unexpected replayed passthrough message")
+			}
+
+			// test: invalid passthrough message with incorrect key
+
+			invalidMessage, err := MakeTLSPassthroughMessage(useTimeFactor, incorrectMasterKey)
+			if err != nil {
+				t.Fatalf("MakeTLSPassthroughMessage failed: %s", err)
+			}
+
+			startTime = time.Now()
+
+			if VerifyTLSPassthroughMessage(useTimeFactor, correctMasterKey, invalidMessage) {
+				t.Fatalf("unexpected valid passthrough message")
+			}
+
+			incorrectElapsedTime := time.Now().Sub(startTime)
+
+			// test: valid/invalid elapsed times are nearly identical
+
+			timeDiff := correctElapsedTime - incorrectElapsedTime
+			if timeDiff < 0 {
+				timeDiff = -timeDiff
+			}
+
+			if timeDiff.Microseconds() > 100 {
+				t.Fatalf("unexpected elapsed time difference")
+			}
+
+			// test: invalid message length and elapsed time
+
+			startTime = time.Now()
+
+			if VerifyTLSPassthroughMessage(useTimeFactor, correctMasterKey, invalidMessage[:16]) {
+				t.Fatalf("unexpected valid passthrough message with invalid length")
+			}
+
+			incorrectElapsedTime = time.Now().Sub(startTime)
+
+			timeDiff = correctElapsedTime - incorrectElapsedTime
+			if timeDiff < 0 {
+				timeDiff = -timeDiff
+			}
 
-	if VerifyTLSPassthroughMessage(passthroughKey, invalidMessage) {
-		t.Fatalf("unexpected valid passthrough messages")
+			if timeDiff.Microseconds() > 100 {
+				t.Fatalf("unexpected elapsed time difference")
+			}
+		})
 	}
 }

+ 18 - 5
psiphon/common/protocol/serverEntry.go

@@ -447,13 +447,14 @@ func GetTacticsCapability(protocol string) string {
 
 // hasCapability indicates if the server entry has the specified capability.
 //
-// Any internal "PASSTHROUGH" componant in the server entry's capabilities is
-// ignored. The PASSTHROUGH component is used to mask protocols which are
-// running the passthrough mechanism from older clients which do not implement
-// the passthrough message. Older clients will treat these capabilities as
-// unknown protocols and skip them.
+// Any internal "PASSTHROUGH-v2 or "PASSTHROUGH" componant in the server
+// entry's capabilities is ignored. These PASSTHROUGH components are used to
+// mask protocols which are running the passthrough mechanisms from older
+// clients which do not implement the passthrough messages. Older clients will
+// treat these capabilities as unknown protocols and skip them.
 func (serverEntry *ServerEntry) hasCapability(requiredCapability string) bool {
 	for _, capability := range serverEntry.Capabilities {
+		capability = strings.ReplaceAll(capability, "-PASSTHROUGH-v2", "")
 		capability = strings.ReplaceAll(capability, "-PASSTHROUGH", "")
 		if capability == requiredCapability {
 			return true
@@ -469,6 +470,18 @@ func (serverEntry *ServerEntry) SupportsProtocol(protocol string) bool {
 	return serverEntry.hasCapability(requiredCapability)
 }
 
+// ProtocolUsesLegacyPassthrough indicates whether the ServerEntry supports
+// the specified protocol using legacy passthrough messages.
+func (serverEntry *ServerEntry) ProtocolUsesLegacyPassthrough(protocol string) bool {
+	legacyCapability := GetCapability(protocol) + "-PASSTHROUGH"
+	for _, capability := range serverEntry.Capabilities {
+		if capability == legacyCapability {
+			return true
+		}
+	}
+	return false
+}
+
 // ConditionallyEnabledComponents defines an interface which can be queried to
 // determine which conditionally compiled protocol components are present.
 type ConditionallyEnabledComponents interface {

+ 1 - 0
psiphon/dialParameters.go

@@ -812,6 +812,7 @@ func MakeDialParameters(
 			QUICClientHelloSeed:           dialParams.QUICClientHelloSeed,
 			UseHTTPS:                      usingTLS,
 			TLSProfile:                    dialParams.TLSProfile,
+			LegacyPassthrough:             serverEntry.ProtocolUsesLegacyPassthrough(dialParams.TunnelProtocol),
 			NoDefaultTLSSessionID:         dialParams.NoDefaultTLSSessionID,
 			RandomizedTLSProfileSeed:      dialParams.RandomizedTLSProfileSeed,
 			UseObfuscatedSessionTickets:   dialParams.TunnelProtocol == protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET,

+ 5 - 1
psiphon/meekConn.go

@@ -125,6 +125,10 @@ type MeekConfig struct {
 	// underlying TLS connections created by this meek connection.
 	TLSProfile string
 
+	// LegacyPassthrough indicates that the server expects a legacy passthrough
+	// message.
+	LegacyPassthrough bool
+
 	// NoDefaultTLSSessionID specifies the value for
 	// CustomTLSConfig.NoDefaultTLSSessionID for all underlying TLS connections
 	// created by this meek connection.
@@ -137,7 +141,6 @@ type MeekConfig struct {
 
 	// UseObfuscatedSessionTickets indicates whether to use obfuscated session
 	// tickets. Assumes UseHTTPS is true. Ignored for MeekModePlaintextRoundTrip.
-	//
 	UseObfuscatedSessionTickets bool
 
 	// SNIServerName is the value to place in the TLS/QUIC SNI server_name field
@@ -435,6 +438,7 @@ func DialMeek(
 			// clients don't know which servers are configured to use it).
 
 			passthroughMessage, err := obfuscator.MakeTLSPassthroughMessage(
+				!meekConfig.LegacyPassthrough,
 				meekConfig.MeekObfuscatedKey)
 			if err != nil {
 				return nil, errors.Trace(err)

+ 18 - 1
psiphon/server/config.go

@@ -170,6 +170,11 @@ type Config struct {
 	// "UNFRONTED-MEEK-HTTPS-OSSH", "UNFRONTED-MEEK-SESSION-TICKET-OSSH".
 	TunnelProtocolPassthroughAddresses map[string]string
 
+	// LegacyPassthrough indicates whether to expect legacy passthrough messages
+	// from clients attempting to connect. This should be set for existing/legacy
+	// passthrough servers only.
+	LegacyPassthrough bool
+
 	// SSHPrivateKey is the SSH host key. The same key is used for
 	// all protocols, run by this server instance, which use SSH.
 	SSHPrivateKey string
@@ -662,6 +667,8 @@ type GenerateConfigParams struct {
 	TacticsConfigFilename       string
 	TacticsRequestPublicKey     string
 	TacticsRequestObfuscatedKey string
+	Passthrough                 bool
+	LegacyPassthrough           bool
 }
 
 // GenerateConfig creates a new Psiphon server config. It returns JSON encoded
@@ -859,6 +866,7 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, []byt
 		OSLConfigFilename:              params.OSLConfigFilename,
 		TacticsConfigFilename:          params.TacticsConfigFilename,
 		MarionetteFormat:               params.MarionetteFormat,
+		LegacyPassthrough:              params.LegacyPassthrough,
 	}
 
 	encodedConfig, err := json.MarshalIndent(config, "\n", "    ")
@@ -952,7 +960,16 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, []byt
 	}
 
 	for tunnelProtocol := range params.TunnelProtocolPorts {
-		capabilities = append(capabilities, protocol.GetCapability(tunnelProtocol))
+
+		capability := protocol.GetCapability(tunnelProtocol)
+		if params.Passthrough && protocol.TunnelProtocolSupportsPassthrough(tunnelProtocol) {
+			if !params.LegacyPassthrough {
+				capability += "-PASSTHROUGH-v2"
+			} else {
+				capability += "-PASSTHROUGH"
+			}
+		}
+		capabilities = append(capabilities, capability)
 
 		if params.TacticsRequestPublicKey != "" && params.TacticsRequestObfuscatedKey != "" &&
 			protocol.TunnelProtocolUsesMeek(tunnelProtocol) {

+ 17 - 8
psiphon/server/meek.go

@@ -1105,13 +1105,14 @@ func (server *MeekServer) makeMeekTLSConfig(
 
 		config.PassthroughAddress = server.passthroughAddress
 
-		passthroughKey, err := obfuscator.DeriveTLSPassthroughKey(
-			server.support.Config.MeekObfuscatedKey)
-		if err != nil {
-			return nil, errors.Trace(err)
-		}
+		config.PassthroughVerifyMessage = func(
+			message []byte) bool {
 
-		config.PassthroughKey = passthroughKey
+			return obfuscator.VerifyTLSPassthroughMessage(
+				!server.support.Config.LegacyPassthrough,
+				server.support.Config.MeekObfuscatedKey,
+				message)
+		}
 
 		config.PassthroughLogInvalidMessage = func(
 			clientIP string) {
@@ -1129,14 +1130,22 @@ func (server *MeekServer) makeMeekTLSConfig(
 			clientIP string,
 			clientRandom []byte) bool {
 
+			// Use a custom, shorter TTL based on the validity period of the
+			// passthrough message.
+			TTL := obfuscator.TLS_PASSTHROUGH_TIME_PERIOD
+			if server.support.Config.LegacyPassthrough {
+				TTL = obfuscator.HISTORY_SEED_TTL
+			}
+
 			// strictMode is true as, unlike with meek cookies, legitimate meek clients
 			// never retry TLS connections using a previous random value.
 
-			ok, logFields := server.obfuscatorSeedHistory.AddNew(
+			ok, logFields := server.obfuscatorSeedHistory.AddNewWithTTL(
 				true,
 				clientIP,
 				"client-random",
-				clientRandom)
+				clientRandom,
+				TTL)
 
 			if logFields != nil {
 				logIrregularTunnel(

+ 10 - 0
psiphon/server/passthrough_test.go

@@ -44,6 +44,14 @@ import (
 )
 
 func TestPassthrough(t *testing.T) {
+	testPassthrough(t, false)
+}
+
+func TestLegacyPassthrough(t *testing.T) {
+	testPassthrough(t, true)
+}
+
+func testPassthrough(t *testing.T, legacy bool) {
 
 	psiphon.SetEmitDiagnosticNotices(true, true)
 
@@ -92,6 +100,8 @@ func TestPassthrough(t *testing.T) {
 		EnableSSHAPIRequests: true,
 		WebServerPort:        8000,
 		TunnelProtocolPorts:  map[string]int{tunnelProtocol: 4000},
+		Passthrough:          true,
+		LegacyPassthrough:    legacy,
 	}
 
 	serverConfigJSON, _, _, _, encodedServerEntry, err := GenerateConfig(generateConfigParams)

+ 4 - 3
vendor/github.com/Psiphon-Labs/tls-tris/common.go

@@ -678,9 +678,10 @@ type Config struct {
 	PassthroughAddress string
 
 	// [Psiphon]
-	// PassthroughKey must be set, to a value generated by
-	// obfuscator.DerivePassthroughKey, when passthrough mode is enabled.
-	PassthroughKey []byte
+	// PassthroughVerifyMessage must be set when passthrough mode is enabled. The
+	// function must return true for valid passthrough messages and false
+	// otherwise.
+	PassthroughVerifyMessage func([]byte) bool
 
 	// [Psiphon]
 	// PassthroughHistoryAddNew must be set when passthough mode is enabled. The

+ 1 - 4
vendor/github.com/Psiphon-Labs/tls-tris/handshake_server.go

@@ -17,8 +17,6 @@ import (
 	"net"
 	"sync/atomic"
 	"time"
-
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator"
 )
 
 // serverHandshakeState contains details of a server handshake in progress.
@@ -105,8 +103,7 @@ func (c *Conn) serverHandshake() error {
 		clientIP, _, _ := net.SplitHostPort(clientAddr)
 
 		if !doPassthrough {
-			if !obfuscator.VerifyTLSPassthroughMessage(
-				c.config.PassthroughKey, hs.clientHello.random) {
+			if !c.config.PassthroughVerifyMessage(hs.clientHello.random) {
 
 				c.config.PassthroughLogInvalidMessage(clientIP)