Просмотр исходного кода

OSL download optimizations

- Option to omit MD5Sum fields for schemes where
  OSL context changes daily. In this case, the MD5Sum
  is counterproductive since it causes the OSL registry
  file to change frequently and doesn't save on OSL file
  downloads.

- Fully deterministic OSL registry generation to ensure
  repeated paves with the same config and parameters will
  not cause registry to change and be downloaded again.

- Option to omit empty OSL files from the pave/registry.
Rod Hynes 8 лет назад
Родитель
Сommit
bd04746e32

+ 146 - 53
psiphon/common/osl/osl.go

@@ -30,6 +30,8 @@
 package osl
 
 import (
+	"crypto/aes"
+	"crypto/cipher"
 	"crypto/hmac"
 	"crypto/md5"
 	"crypto/sha256"
@@ -49,9 +51,9 @@ import (
 	"sync/atomic"
 	"time"
 
-	"github.com/Psiphon-Inc/sss"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/nacl/secretbox"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/sss"
 )
 
 const (
@@ -780,7 +782,8 @@ type PaveLogInfo struct {
 // epoch to endTime, and a pave file for each OSL. paveServerEntries is
 // a map from hex-encoded OSL IDs to server entries to pave into that OSL.
 // When entries are found, OSL will contain those entries, newline
-// separated. Otherwise the OSL will still be issued, but be empty.
+// separated. Otherwise the OSL will still be issued, but be empty (unless
+// the scheme is in skipEmptyOSLsSchemes).
 //
 // As OSLs outside the epoch-endTime range will no longer appear in
 // the registry, Pave is intended to be used to create the full set
@@ -794,6 +797,8 @@ func (config *Config) Pave(
 	signingPublicKey string,
 	signingPrivateKey string,
 	paveServerEntries map[string][]string,
+	omitMD5SumsSchemes []int,
+	omitEmptyOSLsSchemes []int,
 	logCallback func(*PaveLogInfo)) ([]*PaveFile, error) {
 
 	config.ReloadableFile.RLock()
@@ -806,6 +811,10 @@ func (config *Config) Pave(
 	for schemeIndex, scheme := range config.Schemes {
 		if common.Contains(scheme.PropagationChannelIDs, propagationChannelID) {
 
+			omitMD5Sums := common.ContainsInt(omitMD5SumsSchemes, schemeIndex)
+
+			omitEmptyOSLs := common.ContainsInt(omitEmptyOSLsSchemes, schemeIndex)
+
 			oslDuration := scheme.GetOSLDuration()
 
 			oslTime := scheme.epoch
@@ -821,47 +830,52 @@ func (config *Config) Pave(
 
 				hexEncodedOSLID := hex.EncodeToString(fileSpec.ID)
 
-				registry.FileSpecs = append(registry.FileSpecs, fileSpec)
-
 				serverEntryCount := len(paveServerEntries[hexEncodedOSLID])
 
-				// serverEntries will be "" when nothing is found in paveServerEntries
-				serverEntries := strings.Join(paveServerEntries[hexEncodedOSLID], "\n")
+				if serverEntryCount > 0 || !omitEmptyOSLs {
 
-				serverEntriesPackage, err := common.WriteAuthenticatedDataPackage(
-					serverEntries,
-					signingPublicKey,
-					signingPrivateKey)
-				if err != nil {
-					return nil, common.ContextError(err)
-				}
+					registry.FileSpecs = append(registry.FileSpecs, fileSpec)
 
-				boxedServerEntries, err := box(fileKey, serverEntriesPackage)
-				if err != nil {
-					return nil, common.ContextError(err)
-				}
+					// serverEntries will be "" when nothing is found in paveServerEntries
+					serverEntries := strings.Join(paveServerEntries[hexEncodedOSLID], "\n")
+
+					serverEntriesPackage, err := common.WriteAuthenticatedDataPackage(
+						serverEntries,
+						signingPublicKey,
+						signingPrivateKey)
+					if err != nil {
+						return nil, common.ContextError(err)
+					}
+
+					boxedServerEntries, err := box(fileKey, serverEntriesPackage)
+					if err != nil {
+						return nil, common.ContextError(err)
+					}
 
-				md5sum := md5.Sum(boxedServerEntries)
-				fileSpec.MD5Sum = md5sum[:]
-
-				fileName := fmt.Sprintf(
-					OSL_FILENAME_FORMAT, hexEncodedOSLID)
-
-				paveFiles = append(paveFiles, &PaveFile{
-					Name:     fileName,
-					Contents: boxedServerEntries,
-				})
-
-				if logCallback != nil {
-					logCallback(&PaveLogInfo{
-						FileName:             fileName,
-						SchemeIndex:          schemeIndex,
-						PropagationChannelID: propagationChannelID,
-						OSLID:                hexEncodedOSLID,
-						OSLTime:              oslTime,
-						OSLDuration:          oslDuration,
-						ServerEntryCount:     serverEntryCount,
+					if !omitMD5Sums {
+						md5sum := md5.Sum(boxedServerEntries)
+						fileSpec.MD5Sum = md5sum[:]
+					}
+
+					fileName := fmt.Sprintf(
+						OSL_FILENAME_FORMAT, hexEncodedOSLID)
+
+					paveFiles = append(paveFiles, &PaveFile{
+						Name:     fileName,
+						Contents: boxedServerEntries,
 					})
+
+					if logCallback != nil {
+						logCallback(&PaveLogInfo{
+							FileName:             fileName,
+							SchemeIndex:          schemeIndex,
+							PropagationChannelID: propagationChannelID,
+							OSLID:                hexEncodedOSLID,
+							OSLTime:              oslTime,
+							OSLDuration:          oslDuration,
+							ServerEntryCount:     serverEntryCount,
+						})
+					}
 				}
 
 				oslTime = oslTime.Add(oslDuration)
@@ -936,19 +950,53 @@ func makeOSLFileSpec(
 	firstSLOK := scheme.deriveSLOK(ref)
 	oslID := firstSLOK.ID
 
-	// Note: previously, this was a random key. Now, the file key
+	// Note: previously, fileKey was a random key. Now, the key
 	// is derived from the master key and OSL ID. This deterministic
 	// derivation ensures that repeated paves of the same OSL
 	// with the same ID and same content yields the same MD5Sum
 	// to avoid wasteful downloads.
+	//
+	// Similarly, the shareKeys generated in divideKey and the Shamir
+	// key splitting random polynomials are now both determinisitcally
+	// generated from a seeded CSPRNG. This ensures that the OSL
+	// registry remains identical for repeated paves of the same config
+	// and parameters.
+	//
+	// The split structure is added to the deterministic key
+	// derivation so that changes to the split configuration will not
+	// expose the same key material to different SLOK combinations.
+
+	splitStructure := make([]byte, 16*(1+len(scheme.SeedPeriodKeySplits)))
+	i := 0
+	binary.LittleEndian.PutUint64(splitStructure[i:], uint64(len(scheme.SeedSpecs)))
+	binary.LittleEndian.PutUint64(splitStructure[i+8:], uint64(scheme.SeedSpecThreshold))
+	i += 16
+	for _, keySplit := range scheme.SeedPeriodKeySplits {
+		binary.LittleEndian.PutUint64(splitStructure[i:], uint64(keySplit.Total))
+		binary.LittleEndian.PutUint64(splitStructure[i+8:], uint64(keySplit.Threshold))
+		i += 16
+	}
 
 	fileKey := deriveKeyHKDF(
 		scheme.MasterKey,
+		splitStructure,
 		[]byte("osl-file-key"),
 		oslID)
 
+	splitKeyMaterialSeed := deriveKeyHKDF(
+		scheme.MasterKey,
+		splitStructure,
+		[]byte("osl-file-split-key-material-seed"),
+		oslID)
+
+	keyMaterialReader, err := newSeededKeyMaterialReader(splitKeyMaterialSeed)
+	if err != nil {
+		return nil, nil, common.ContextError(err)
+	}
+
 	keyShares, err := divideKey(
 		scheme,
+		keyMaterialReader,
 		fileKey,
 		scheme.SeedPeriodKeySplits,
 		propagationChannelID,
@@ -968,6 +1016,7 @@ func makeOSLFileSpec(
 // divideKey recursively constructs a KeyShares tree.
 func divideKey(
 	scheme *Scheme,
+	keyMaterialReader io.Reader,
 	key []byte,
 	keySplits []KeySplit,
 	propagationChannelID string,
@@ -976,7 +1025,11 @@ func divideKey(
 	keySplitIndex := len(keySplits) - 1
 	keySplit := keySplits[keySplitIndex]
 
-	shares, err := shamirSplit(key, keySplit.Total, keySplit.Threshold)
+	shares, err := shamirSplit(
+		key,
+		keySplit.Total,
+		keySplit.Threshold,
+		keyMaterialReader)
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
@@ -986,15 +1039,12 @@ func divideKey(
 
 	for _, share := range shares {
 
-		// Note: for a fully deterministic pave, where the OSL registry
-		// is unchanged when no OSLs change, the share key would need
-		// to be derived (e.g., from the master key, OSL ID, key split
-		// index, and share index). However, since the OSL registry file
-		// content is nondeterministic in any case due to aspects of the
-		// Shamir secret splitting algorithm, there's no reason not to
-		// use a random key here.
+		var shareKey [KEY_LENGTH_BYTES]byte
 
-		shareKey, err := common.MakeSecureRandomBytes(KEY_LENGTH_BYTES)
+		n, err := keyMaterialReader.Read(shareKey[:])
+		if err == nil && n != len(shareKey) {
+			err = errors.New("unexpected length")
+		}
 		if err != nil {
 			return nil, common.ContextError(err)
 		}
@@ -1002,7 +1052,8 @@ func divideKey(
 		if keySplitIndex > 0 {
 			keyShare, err := divideKey(
 				scheme,
-				shareKey,
+				keyMaterialReader,
+				shareKey[:],
 				keySplits[0:keySplitIndex],
 				propagationChannelID,
 				nextSLOKTime)
@@ -1013,7 +1064,8 @@ func divideKey(
 		} else {
 			keyShare, err := divideKeyWithSeedSpecSLOKs(
 				scheme,
-				shareKey,
+				keyMaterialReader,
+				shareKey[:],
 				propagationChannelID,
 				nextSLOKTime)
 			if err != nil {
@@ -1023,7 +1075,7 @@ func divideKey(
 
 			*nextSLOKTime = nextSLOKTime.Add(time.Duration(scheme.SeedPeriodNanoseconds))
 		}
-		boxedShare, err := box(shareKey, share)
+		boxedShare, err := box(shareKey[:], share)
 		if err != nil {
 			return nil, common.ContextError(err)
 		}
@@ -1040,6 +1092,7 @@ func divideKey(
 
 func divideKeyWithSeedSpecSLOKs(
 	scheme *Scheme,
+	keyMaterialReader io.Reader,
 	key []byte,
 	propagationChannelID string,
 	nextSLOKTime *time.Time) (*KeyShares, error) {
@@ -1048,7 +1101,10 @@ func divideKeyWithSeedSpecSLOKs(
 	var slokIDs [][]byte
 
 	shares, err := shamirSplit(
-		key, len(scheme.SeedSpecs), scheme.SeedSpecThreshold)
+		key,
+		len(scheme.SeedSpecs),
+		scheme.SeedSpecThreshold,
+		keyMaterialReader)
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
@@ -1352,6 +1408,38 @@ func NewOSLReader(
 		signingPublicKey)
 }
 
+// zeroReader reads an unlimited stream of zeroes.
+type zeroReader struct {
+}
+
+func (z *zeroReader) Read(p []byte) (int, error) {
+	for i := 0; i < len(p); i++ {
+		p[i] = 0
+	}
+	return len(p), nil
+}
+
+// newSeededKeyMaterialReader constructs a CSPRNG using AES-CTR.
+// The seed is the AES key and the IV is fixed and constant.
+// Using same seed will always produce the same output stream.
+// The data stream is intended to be used to determinisically
+// generate key material and is not intended as a general
+// purpose CSPRNG.
+func newSeededKeyMaterialReader(seed []byte) (io.Reader, error) {
+
+	aesCipher, err := aes.NewCipher(seed)
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	var iv [aes.BlockSize]byte
+
+	return &cipher.StreamReader{
+		S: cipher.NewCTR(aesCipher, iv[:]),
+		R: new(zeroReader),
+	}, nil
+}
+
 // deriveKeyHKDF implements HKDF-Expand as defined in https://tools.ietf.org/html/rfc5869
 // where masterKey = PRK, context = info, and L = 32; SHA-256 is used so HashLen = 32
 func deriveKeyHKDF(masterKey []byte, context ...[]byte) []byte {
@@ -1372,7 +1460,11 @@ func isValidShamirSplit(total, threshold int) bool {
 }
 
 // shamirSplit is a helper wrapper for sss.Split
-func shamirSplit(secret []byte, total, threshold int) ([][]byte, error) {
+func shamirSplit(
+	secret []byte,
+	total, threshold int,
+	randReader io.Reader) ([][]byte, error) {
+
 	if !isValidShamirSplit(total, threshold) {
 		return nil, common.ContextError(errors.New("invalid parameters"))
 	}
@@ -1386,7 +1478,8 @@ func shamirSplit(secret []byte, total, threshold int) ([][]byte, error) {
 		return shares, nil
 	}
 
-	shareMap, err := sss.Split(byte(total), byte(threshold), secret)
+	shareMap, err := sss.SplitUsingReader(
+		byte(total), byte(threshold), secret, randReader)
 	if err != nil {
 		return nil, common.ContextError(err)
 	}

+ 34 - 0
psiphon/common/osl/osl_test.go

@@ -358,22 +358,56 @@ func TestOSL(t *testing.T) {
 				}
 			}
 
+			// Note: these options are exercised in remoteServerList_test.go
+			omitMD5SumsSchemes := []int{}
+			omitEmptyOSLsSchemes := []int{}
+
+			firstPaveFiles, err := config.Pave(
+				endTime,
+				propagationChannelID,
+				signingPublicKey,
+				signingPrivateKey,
+				paveServerEntries,
+				omitMD5SumsSchemes,
+				omitEmptyOSLsSchemes,
+				nil)
+			if err != nil {
+				t.Fatalf("Pave failed: %s", err)
+			}
+
 			paveFiles, err := config.Pave(
 				endTime,
 				propagationChannelID,
 				signingPublicKey,
 				signingPrivateKey,
 				paveServerEntries,
+				omitMD5SumsSchemes,
+				omitEmptyOSLsSchemes,
 				nil)
 			if err != nil {
 				t.Fatalf("Pave failed: %s", err)
 			}
 
 			// Check that the paved file name matches the name the client will look for.
+
 			if len(paveFiles) < 1 || paveFiles[len(paveFiles)-1].Name != GetOSLRegistryURL("") {
 				t.Fatalf("invalid registry pave file")
 			}
 
+			// Check that the content of two paves is the same: all the crypto should be
+			// deterministc.
+
+			for index, paveFile := range paveFiles {
+				if paveFile.Name != firstPaveFiles[index].Name {
+					t.Fatalf("Pave name mismatch")
+				}
+				if bytes.Compare(paveFile.Contents, firstPaveFiles[index].Contents) != 0 {
+					t.Fatalf("Pave content mismatch")
+				}
+			}
+
+			// Use the paved content in the following tests.
+
 			pavedRegistries[propagationChannelID] = paveFiles[len(paveFiles)-1].Contents
 
 			pavedOSLFileContents[propagationChannelID] = make(map[string][]byte)

+ 24 - 0
psiphon/common/osl/paver/main.go

@@ -29,6 +29,7 @@ import (
 	"io/ioutil"
 	"os"
 	"path/filepath"
+	"strconv"
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
@@ -63,6 +64,12 @@ func main() {
 	var listScheme int
 	flag.IntVar(&listScheme, "list-scheme", -1, "list current period OSL IDs for specified scheme; no files are written")
 
+	var omitMD5SumsSchemes ints
+	flag.Var(&omitMD5SumsSchemes, "omit-md5sums", "omit MD5Sum fields for specified scheme(s)")
+
+	var omitEmptyOSLsSchemes ints
+	flag.Var(&omitEmptyOSLsSchemes, "omit-empty", "omit empty OSLs for specified scheme(s)")
+
 	flag.Parse()
 
 	// load config
@@ -211,6 +218,8 @@ func main() {
 			signingPublicKey,
 			signingPrivateKey,
 			paveServerEntries,
+			omitMD5SumsSchemes,
+			omitEmptyOSLsSchemes,
 			func(logInfo *osl.PaveLogInfo) {
 				pavedPayloadOSLID[logInfo.OSLID] = true
 				fmt.Printf(
@@ -266,3 +275,18 @@ func main() {
 		os.Exit(1)
 	}
 }
+
+type ints []int
+
+func (i *ints) String() string {
+	return fmt.Sprint(*i)
+}
+
+func (i *ints) Set(strValue string) error {
+	value, err := strconv.Atoi(strValue)
+	if err != nil {
+		return err
+	}
+	*i = append(*i, value)
+	return nil
+}

+ 11 - 0
psiphon/common/utils.go

@@ -48,6 +48,17 @@ func Contains(list []string, target string) bool {
 	return false
 }
 
+// ContainsInt returns true if the target int is
+// in the list.
+func ContainsInt(list []int, target int) bool {
+	for _, listItem := range list {
+		if listItem == target {
+			return true
+		}
+	}
+	return false
+}
+
 // FlipCoin is a helper function that randomly
 // returns true or false. If the underlying random
 // number generator fails, FlipCoin still returns

+ 48 - 17
psiphon/remoteServerList_test.go

@@ -35,6 +35,7 @@ import (
 	"path"
 	"path/filepath"
 	"sync"
+	"syscall"
 	"testing"
 	"time"
 
@@ -47,6 +48,14 @@ import (
 // TODO: TestCommonRemoteServerList (this is currently covered by controller_test.go)
 
 func TestObfuscatedRemoteServerLists(t *testing.T) {
+	testObfuscatedRemoteServerLists(t, false)
+}
+
+func TestObfuscatedRemoteServerListsOmitMD5Sums(t *testing.T) {
+	testObfuscatedRemoteServerLists(t, true)
+}
+
+func testObfuscatedRemoteServerLists(t *testing.T, omitMD5Sums bool) {
 
 	testDataDirName, err := ioutil.TempDir("", "psiphon-remote-server-list-test")
 	if err != nil {
@@ -150,6 +159,12 @@ func TestObfuscatedRemoteServerLists(t *testing.T) {
 		t.Fatalf("error generating package keys: %s", err)
 	}
 
+	var omitMD5SumsSchemes []int
+	if omitMD5Sums {
+		omitMD5SumsSchemes = []int{0}
+	}
+	omitEmptyOSLsSchemes := []int{0}
+
 	// First Pave() call is to get the OSL ID to pave into
 
 	oslID := ""
@@ -160,6 +175,8 @@ func TestObfuscatedRemoteServerLists(t *testing.T) {
 		signingPublicKey,
 		signingPrivateKey,
 		map[string][]string{},
+		omitMD5SumsSchemes,
+		omitEmptyOSLsSchemes,
 		func(logInfo *osl.PaveLogInfo) {
 			oslID = logInfo.OSLID
 		})
@@ -175,6 +192,8 @@ func TestObfuscatedRemoteServerLists(t *testing.T) {
 		map[string][]string{
 			oslID: {string(encodedServerEntry)},
 		},
+		omitMD5SumsSchemes,
+		omitEmptyOSLsSchemes,
 		nil)
 	if err != nil {
 		t.Fatalf("error paving OSL files: %s", err)
@@ -211,9 +230,17 @@ func TestObfuscatedRemoteServerLists(t *testing.T) {
 	//
 
 	// Exercise using multiple download URLs
-	remoteServerListHostAddresses := []string{
-		net.JoinHostPort(serverIPAddress, "8081"),
-		net.JoinHostPort(serverIPAddress, "8082"),
+
+	var remoteServerListListeners [2]net.Listener
+	var remoteServerListHostAddresses [2]string
+
+	for i := 0; i < len(remoteServerListListeners); i++ {
+		remoteServerListListeners[i], err = net.Listen("tcp", net.JoinHostPort(serverIPAddress, "0"))
+		if err != nil {
+			t.Fatalf("net.Listen error: %s", err)
+		}
+		defer remoteServerListListeners[i].Close()
+		remoteServerListHostAddresses[i] = remoteServerListListeners[i].Addr().String()
 	}
 
 	// The common remote server list fetches will 404
@@ -234,7 +261,7 @@ func TestObfuscatedRemoteServerLists(t *testing.T) {
 			obfuscatedServerListRootURLsJSONConfig += ","
 		}
 
-		go func(remoteServerListHostAddress string) {
+		go func(listener net.Listener, remoteServerListHostAddress string) {
 			startTime := time.Now()
 			serveMux := http.NewServeMux()
 			for _, paveFile := range paveFiles {
@@ -250,12 +277,8 @@ func TestObfuscatedRemoteServerLists(t *testing.T) {
 				Addr:    remoteServerListHostAddress,
 				Handler: serveMux,
 			}
-			err := httpServer.ListenAndServe()
-			if err != nil {
-				// TODO: wrong goroutine for t.FatalNow()
-				t.Fatalf("error running remote server list host: %s", err)
-			}
-		}(remoteServerListHostAddresses[i])
+			httpServer.Serve(listener)
+		}(remoteServerListListeners[i], remoteServerListHostAddresses[i])
 	}
 
 	obfuscatedServerListDownloadDirectory := testDataDirName
@@ -272,19 +295,27 @@ func TestObfuscatedRemoteServerLists(t *testing.T) {
 		}
 	}()
 
+	process, err := os.FindProcess(os.Getpid())
+	if err != nil {
+		t.Fatalf("os.FindProcess error: %s", err)
+	}
+	defer process.Signal(syscall.SIGTERM)
+
 	//
 	// disrupt remote server list downloads
 	//
 
-	disruptorProxyAddress := "127.0.0.1:2162"
+	disruptorListener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("net.Listen error: %s", err)
+	}
+	defer disruptorListener.Close()
+
+	disruptorProxyAddress := disruptorListener.Addr().String()
 	disruptorProxyURL := "socks4a://" + disruptorProxyAddress
 
 	go func() {
-		listener, err := socks.ListenSocks("tcp", disruptorProxyAddress)
-		if err != nil {
-			fmt.Printf("disruptor proxy listen error: %s\n", err)
-			return
-		}
+		listener := socks.NewSocksListener(disruptorListener)
 		for {
 			localConn, err := listener.AcceptSocks()
 			if err != nil {
@@ -309,7 +340,7 @@ func TestObfuscatedRemoteServerLists(t *testing.T) {
 					defer waitGroup.Done()
 					io.Copy(remoteConn, localConn)
 				}()
-				if common.Contains(remoteServerListHostAddresses, localConn.Req.Target) {
+				if common.Contains(remoteServerListHostAddresses[:], localConn.Req.Target) {
 					io.CopyN(localConn, remoteConn, 500)
 				} else {
 					io.Copy(localConn, remoteConn)