Просмотр исходного кода

Re-key ObfuscatedPacketConn

- Single server-side packet conn instance can exceed 2^38 key stream limit.

- Revert 297430f.
Rod Hynes 7 лет назад
Родитель
Сommit
44ca901ad7
3 измененных файлов с 116 добавлено и 92 удалено
  1. 66 38
      psiphon/common/quic/obfuscator.go
  2. 34 6
      psiphon/common/quic/obfuscator_test.go
  3. 16 48
      psiphon/common/quic/quic.go

+ 66 - 38
psiphon/common/quic/obfuscator.go

@@ -42,6 +42,7 @@ const (
 	MAX_OBFUSCATED_QUIC_IPV6_PACKET_SIZE = 1352
 	MAX_OBFUSCATED_QUIC_IPV6_PACKET_SIZE = 1352
 	MAX_PADDING                          = 64
 	MAX_PADDING                          = 64
 	NONCE_SIZE                           = 12
 	NONCE_SIZE                           = 12
+	RANDOM_STREAM_LIMIT                  = 1 << 38
 )
 )
 
 
 // ObfuscatedPacketConn wraps a QUIC net.PacketConn with an obfuscation layer
 // ObfuscatedPacketConn wraps a QUIC net.PacketConn with an obfuscation layer
@@ -70,6 +71,7 @@ type ObfuscatedPacketConn struct {
 
 
 	randomStreamMutex sync.Mutex
 	randomStreamMutex sync.Mutex
 	randomStream      *chacha20.Cipher
 	randomStream      *chacha20.Cipher
+	randomStreamCount int64
 }
 }
 
 
 type peerMode struct {
 type peerMode struct {
@@ -82,15 +84,16 @@ func (p *peerMode) isStale() bool {
 }
 }
 
 
 // NewObfuscatedPacketConnPacketConn creates a new ObfuscatedPacketConn.
 // NewObfuscatedPacketConnPacketConn creates a new ObfuscatedPacketConn.
-func NewObfuscatedPacketConnPacketConn(
+func NewObfuscatedPacketConn(
 	conn net.PacketConn,
 	conn net.PacketConn,
 	isServer bool,
 	isServer bool,
 	obfuscationKey string) (*ObfuscatedPacketConn, error) {
 	obfuscationKey string) (*ObfuscatedPacketConn, error) {
 
 
 	packetConn := &ObfuscatedPacketConn{
 	packetConn := &ObfuscatedPacketConn{
-		PacketConn: conn,
-		isServer:   isServer,
-		peerModes:  make(map[string]*peerMode),
+		PacketConn:        conn,
+		isServer:          isServer,
+		peerModes:         make(map[string]*peerMode),
+		randomStreamCount: RANDOM_STREAM_LIMIT,
 	}
 	}
 
 
 	secret := []byte(obfuscationKey)
 	secret := []byte(obfuscationKey)
@@ -101,26 +104,6 @@ func NewObfuscatedPacketConnPacketConn(
 		return nil, common.ContextError(err)
 		return nil, common.ContextError(err)
 	}
 	}
 
 
-	// Use a stream cipher to generate randomness for padding. This mitigates
-	// issues using a high volume (multiple per packet) of crypto/rand.Read
-	// calls, which use getrandom via a syscall; under high load with many
-	// clients, we observed very long syscall durations, perhaps due to lock
-	// contention. Using a userspace random stream avoids frequent syscall
-	// context switches as well as spinlock overhead.
-
-	var randomStreamKey [32]byte
-	_, err = rand.Read(randomStreamKey[:])
-	if err != nil {
-		return nil, common.ContextError(err)
-	}
-	var randomKeyNonce [NONCE_SIZE]byte
-	packetConn.randomStream, err = chacha20.NewCipher(
-		randomStreamKey[:],
-		randomKeyNonce[:])
-	if err != nil {
-		return nil, common.ContextError(err)
-	}
-
 	if isServer {
 	if isServer {
 
 
 		packetConn.runWaitGroup = new(sync.WaitGroup)
 		packetConn.runWaitGroup = new(sync.WaitGroup)
@@ -318,9 +301,10 @@ func (conn *ObfuscatedPacketConn) WriteTo(p []byte, addr net.Addr) (int, error)
 
 
 		nonce := buffer[0:NONCE_SIZE]
 		nonce := buffer[0:NONCE_SIZE]
 		for {
 		for {
-			conn.randomStreamMutex.Lock()
-			conn.randomStream.KeyStream(nonce)
-			conn.randomStreamMutex.Unlock()
+			err := conn.getRandomBytes(nonce)
+			if err != nil {
+				return 0, common.ContextError(err)
+			}
 
 
 			// Don't use a random nonce that looks like QUIC, or the
 			// Don't use a random nonce that looks like QUIC, or the
 			// peer will not treat this packet as obfuscated.
 			// peer will not treat this packet as obfuscated.
@@ -340,14 +324,18 @@ func (conn *ObfuscatedPacketConn) WriteTo(p []byte, addr net.Addr) (int, error)
 			maxPaddingLen = MAX_PADDING
 			maxPaddingLen = MAX_PADDING
 		}
 		}
 
 
-		paddingLen := conn.getPaddingLen(maxPaddingLen)
+		paddingLen, err := conn.getRandomPaddingLen(maxPaddingLen)
+		if err != nil {
+			return 0, common.ContextError(err)
+		}
 
 
 		buffer[NONCE_SIZE] = uint8(paddingLen)
 		buffer[NONCE_SIZE] = uint8(paddingLen)
 
 
 		padding := buffer[(NONCE_SIZE + 1) : (NONCE_SIZE+1)+paddingLen]
 		padding := buffer[(NONCE_SIZE + 1) : (NONCE_SIZE+1)+paddingLen]
-		conn.randomStreamMutex.Lock()
-		conn.randomStream.KeyStream(padding)
-		conn.randomStreamMutex.Unlock()
+		err = conn.getRandomBytes(padding)
+		if err != nil {
+			return 0, common.ContextError(err)
+		}
 
 
 		copy(buffer[(NONCE_SIZE+1)+paddingLen:], p)
 		copy(buffer[(NONCE_SIZE+1)+paddingLen:], p)
 		dataLen := (NONCE_SIZE + 1) + paddingLen + n
 		dataLen := (NONCE_SIZE + 1) + paddingLen + n
@@ -367,7 +355,46 @@ func (conn *ObfuscatedPacketConn) WriteTo(p []byte, addr net.Addr) (int, error)
 	return n, err
 	return n, err
 }
 }
 
 
-func (conn *ObfuscatedPacketConn) getPaddingLen(maxPadding int) int {
+func (conn *ObfuscatedPacketConn) getRandomBytes(b []byte) error {
+	conn.randomStreamMutex.Lock()
+	defer conn.randomStreamMutex.Unlock()
+
+	// Use a stream cipher to generate randomness for padding. This mitigates
+	// issues using a high volume (multiple per packet) of crypto/rand.Read
+	// calls, which use getrandom via a syscall; under high load with many
+	// clients, we observed very long syscall durations, perhaps due to lock
+	// contention. Using a userspace random stream avoids frequent syscall
+	// context switches as well as spinlock overhead.
+
+	if conn.randomStreamCount+int64(len(b)) >= RANDOM_STREAM_LIMIT {
+
+		// Re-key before reaching the 2^38 chacha20 key stream limit.
+
+		var randomStreamKey [32]byte
+		_, err := rand.Read(randomStreamKey[:])
+		if err != nil {
+			return common.ContextError(err)
+		}
+		var randomKeyNonce [NONCE_SIZE]byte
+
+		conn.randomStream, err = chacha20.NewCipher(
+			randomStreamKey[:],
+			randomKeyNonce[:])
+		if err != nil {
+			return common.ContextError(err)
+		}
+
+		conn.randomStreamCount = 0
+	}
+
+	conn.randomStream.KeyStream(b)
+
+	conn.randomStreamCount += int64(len(b))
+
+	return nil
+}
+
+func (conn *ObfuscatedPacketConn) getRandomPaddingLen(maxPadding int) (int, error) {
 
 
 	// Selects uniformly from [0, maxPadding], using the ObfuscatedPacketConn's
 	// Selects uniformly from [0, maxPadding], using the ObfuscatedPacketConn's
 	// random stream.
 	// random stream.
@@ -375,11 +402,11 @@ func (conn *ObfuscatedPacketConn) getPaddingLen(maxPadding int) int {
 	maxRand := 255
 	maxRand := 255
 
 
 	if maxPadding < 0 || maxPadding > maxRand {
 	if maxPadding < 0 || maxPadding > maxRand {
-		panic(fmt.Sprintf("unexpected max padding: %d", maxPadding))
+		return 0, common.ContextError(fmt.Errorf("unexpected max padding: %d", maxPadding))
 	}
 	}
 
 
 	if maxPadding == 0 {
 	if maxPadding == 0 {
-		return 0
+		return 0, nil
 	}
 	}
 
 
 	upperBound := maxPadding
 	upperBound := maxPadding
@@ -389,13 +416,14 @@ func (conn *ObfuscatedPacketConn) getPaddingLen(maxPadding int) int {
 
 
 	for {
 	for {
 		var value [1]byte
 		var value [1]byte
-		conn.randomStreamMutex.Lock()
-		conn.randomStream.KeyStream(value[:])
-		conn.randomStreamMutex.Unlock()
+		err := conn.getRandomBytes(value[:])
+		if err != nil {
+			return 0, common.ContextError(err)
+		}
 
 
 		padding := int(value[0])
 		padding := int(value[0])
 		if padding <= upperBound {
 		if padding <= upperBound {
-			return padding % (maxPadding + 1)
+			return padding % (maxPadding + 1), nil
 		}
 		}
 	}
 	}
 }
 }

+ 34 - 6
psiphon/common/quic/obfuscator_test.go

@@ -27,9 +27,9 @@ import (
 
 
 func TestPaddingLen(t *testing.T) {
 func TestPaddingLen(t *testing.T) {
 
 
-	c, err := NewObfuscatedPacketConnPacketConn(nil, false, "key")
+	c, err := NewObfuscatedPacketConn(nil, false, "key")
 	if err != nil {
 	if err != nil {
-		t.Fatalf("NewObfuscatedPacketConnPacketConn failed: %s", err)
+		t.Fatalf("NewObfuscatedPacketConn failed: %s", err)
 	}
 	}
 
 
 	for max := 0; max <= 255; max++ {
 	for max := 0; max <= 255; max++ {
@@ -38,7 +38,10 @@ func TestPaddingLen(t *testing.T) {
 		repeats := 200000
 		repeats := 200000
 
 
 		for r := 0; r < repeats; r++ {
 		for r := 0; r < repeats; r++ {
-			padding := c.getPaddingLen(max)
+			padding, err := c.getRandomPaddingLen(max)
+			if err != nil {
+				t.Fatalf("getRandomPaddingLen failed: %s", err)
+			}
 			if padding < 0 || padding > max {
 			if padding < 0 || padding > max {
 				t.Fatalf("unexpected padding: max = %d, padding = %d", max, padding)
 				t.Fatalf("unexpected padding: max = %d, padding = %d", max, padding)
 			}
 			}
@@ -56,16 +59,41 @@ func TestPaddingLen(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func Disabled_TestPaddingLenLimit(t *testing.T) {
+
+	// This test takes up to ~2 minutes to complete, so it's disabled by default.
+
+	c, err := NewObfuscatedPacketConn(nil, false, "key")
+	if err != nil {
+		t.Fatalf("NewObfuscatedPacketConn failed: %s", err)
+	}
+
+	var b [2 * 1024 * 1024 * 1024]byte
+	n := int64(0)
+
+	for {
+		err := c.getRandomBytes(b[:])
+		if err != nil {
+			t.Fatalf("getRandomBytes failed: %s", err)
+		}
+		n += int64(len(b))
+		if n > (1<<38)+1 {
+			// We're past the chacha20 key stream limit.
+			break
+		}
+	}
+}
+
 func BenchmarkPaddingLen(b *testing.B) {
 func BenchmarkPaddingLen(b *testing.B) {
 
 
-	c, err := NewObfuscatedPacketConnPacketConn(nil, false, "key")
+	c, err := NewObfuscatedPacketConn(nil, false, "key")
 	if err != nil {
 	if err != nil {
-		b.Fatalf("NewObfuscatedPacketConnPacketConn failed: %s", err)
+		b.Fatalf("NewObfuscatedPacketConn failed: %s", err)
 	}
 	}
 
 
 	b.Run("getPaddingLen", func(b *testing.B) {
 	b.Run("getPaddingLen", func(b *testing.B) {
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			_ = c.getPaddingLen(n % MAX_PADDING)
+			_, _ = c.getRandomPaddingLen(n % MAX_PADDING)
 		}
 		}
 	})
 	})
 
 

+ 16 - 48
psiphon/common/quic/quic.go

@@ -43,7 +43,6 @@ package quic
 import (
 import (
 	"context"
 	"context"
 	"crypto/tls"
 	"crypto/tls"
-	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net"
 	"net"
@@ -113,15 +112,15 @@ func Listen(
 	}
 	}
 
 
 	var packetConn net.PacketConn
 	var packetConn net.PacketConn
-	packetConn, err = NewObfuscatedPacketConnPacketConn(
+	packetConn, err = NewObfuscatedPacketConn(
 		udpConn, true, obfuscationKey)
 		udpConn, true, obfuscationKey)
 	if err != nil {
 	if err != nil {
 		return nil, common.ContextError(err)
 		return nil, common.ContextError(err)
 	}
 	}
 
 
 	// This wrapping must be outermost to ensure that all
 	// This wrapping must be outermost to ensure that all
-	// ReadFrom/WriteTo calls are intercepted.
-	packetConn = newWorkaroundPacketConn(logger, packetConn)
+	// ReadFrom errors are intercepted and logged.
+	packetConn = newLoggingPacketConn(logger, packetConn)
 
 
 	quicListener, err := quic_go.Listen(
 	quicListener, err := quic_go.Listen(
 		packetConn, tlsConfig, quicConfig)
 		packetConn, tlsConfig, quicConfig)
@@ -199,7 +198,7 @@ func Dial(
 
 
 	if negotiateQUICVersion == protocol.QUIC_VERSION_OBFUSCATED {
 	if negotiateQUICVersion == protocol.QUIC_VERSION_OBFUSCATED {
 		var err error
 		var err error
-		packetConn, err = NewObfuscatedPacketConnPacketConn(
+		packetConn, err = NewObfuscatedPacketConn(
 			packetConn, false, obfuscationKey)
 			packetConn, false, obfuscationKey)
 		if err != nil {
 		if err != nil {
 			return nil, common.ContextError(err)
 			return nil, common.ContextError(err)
@@ -421,7 +420,7 @@ func isErrorIndicatingClosed(err error) bool {
 	return false
 	return false
 }
 }
 
 
-// workaroundPacketConn is a workaround for issues in the quic-go server (as of
+// loggingPacketConn is a workaround for issues in the quic-go server (as of
 // revision ffdfa1).
 // revision ffdfa1).
 //
 //
 // 1. quic-go will shutdown the QUIC server on any error returned from
 // 1. quic-go will shutdown the QUIC server on any error returned from
@@ -441,45 +440,33 @@ func isErrorIndicatingClosed(err error) bool {
 //    packetHandlerMap and its mutex are used by all client sessions, this
 //    packetHandlerMap and its mutex are used by all client sessions, this
 //    effectively hangs the entire server.
 //    effectively hangs the entire server.
 //
 //
-// 3. In certain cases, quic-go appears to get into a state where it
-//    calls WriteTo in an unconstrained loop, far exceeding the expected
-//    rate for normal outbound traffic. This state pegs the psiphond
-//    CPU and, in the case of obfuscated QUIC, exhausts the 2^38 byte
-//    random padding key stream. To mitigate this, we rate limit
-//    workaroundPacketConn when an excessive WriteTo call rate is
-//    detected.
-//
-// workaroundPacketConn checks PacketConn ReadFrom errors and returns any usable
-// values or loops and calls ReadFrom again. In practise, due to the nature of
-// UDP sockets, ReadFrom errors are exceptional as they will most likely not
+// loggingPacketConn checks PacketConn ReadFrom errors and returns any usable
+// values or loops and calls ReadFrom again. In practice, due to the nature of
+// UDP sockets, ReadFrom errors are exceptional as they will most likely not
 // occur due to network transmission failures. ObfuscatedPacketConn returns
 // occur due to network transmission failures. ObfuscatedPacketConn returns
 // errors that could be due to network transmission failures that corrupt
 // errors that could be due to network transmission failures that corrupt
-// packets; these are marked as net.Error.Temporary() and workaroundPacketConn
+// packets; these are marked as net.Error.Temporary() and loggingPacketConn
 // logs these at debug level.
 // logs these at debug level.
 //
 //
-// workaroundPacketConn assumes specific quic-go behavior and will break other
-// use cases, such as setting deadlines and expecting net.Error.Timeout()
+// loggingPacketConn assumes quic-go revision ffdfa1 behavior and will break
+// other behavior, such as setting deadlines and expecting net.Error.Timeout()
 // errors from ReadFrom.
 // errors from ReadFrom.
-type workaroundPacketConn struct {
+type loggingPacketConn struct {
 	net.PacketConn
 	net.PacketConn
 	logger common.Logger
 	logger common.Logger
-
-	mutex      sync.Mutex
-	currentMS  time.Time
-	callsPerMS int
 }
 }
 
 
-func newWorkaroundPacketConn(
+func newLoggingPacketConn(
 	logger common.Logger,
 	logger common.Logger,
-	packetConn net.PacketConn) *workaroundPacketConn {
+	packetConn net.PacketConn) *loggingPacketConn {
 
 
-	return &workaroundPacketConn{
+	return &loggingPacketConn{
 		PacketConn: packetConn,
 		PacketConn: packetConn,
 		logger:     logger,
 		logger:     logger,
 	}
 	}
 }
 }
 
 
-func (conn *workaroundPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
+func (conn *loggingPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
 
 
 	for {
 	for {
 		n, addr, err := conn.PacketConn.ReadFrom(p)
 		n, addr, err := conn.PacketConn.ReadFrom(p)
@@ -501,22 +488,3 @@ func (conn *workaroundPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
 		}
 		}
 	}
 	}
 }
 }
-
-func (conn *workaroundPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
-
-	conn.mutex.Lock()
-	currentMS := time.Now().Round(time.Millisecond)
-	if currentMS != conn.currentMS {
-		conn.currentMS = currentMS
-		conn.callsPerMS = 0
-	} else {
-		if conn.callsPerMS >= 1000 {
-			conn.mutex.Unlock()
-			return 0, common.ContextError(errors.New("rate limit exceeded"))
-		}
-		conn.callsPerMS += 1
-	}
-	conn.mutex.Unlock()
-
-	return conn.PacketConn.WriteTo(p, addr)
-}