Эх сурвалжийг харах

Add StreamingReadAuthenticatedDataPackage

Rod Hynes 8 жил өмнө
parent
commit
9fe15b6af3

+ 382 - 6
psiphon/common/authPackage.go

@@ -21,6 +21,7 @@ package common
 
 import (
 	"bytes"
+	"compress/zlib"
 	"crypto"
 	"crypto/rand"
 	"crypto/rsa"
@@ -29,6 +30,11 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"errors"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"os"
+	"sync"
 )
 
 // AuthenticatedDataPackage is a JSON record containing some Psiphon data
@@ -63,9 +69,8 @@ func GenerateAuthenticatedDataPackageKeys() (string, string, error) {
 }
 
 func sha256sum(data string) []byte {
-	hash := sha256.New()
-	hash.Write([]byte(data))
-	return hash.Sum(nil)
+	digest := sha256.Sum256([]byte(data))
+	return digest[:]
 }
 
 // WriteAuthenticatedDataPackage creates an AuthenticatedDataPackage
@@ -103,17 +108,22 @@ func WriteAuthenticatedDataPackage(
 		return nil, ContextError(err)
 	}
 
-	return packageJSON, nil
+	return Compress(packageJSON), nil
 }
 
 // ReadAuthenticatedDataPackage extracts and verifies authenticated
 // data from an AuthenticatedDataPackage. The package must have been
 // signed with the given key.
 func ReadAuthenticatedDataPackage(
-	packageJSON []byte, signingPublicKey string) (string, error) {
+	compressedPackage []byte, signingPublicKey string) (string, error) {
+
+	packageJSON, err := Decompress(compressedPackage)
+	if err != nil {
+		return "", ContextError(err)
+	}
 
 	var authenticatedDataPackage *AuthenticatedDataPackage
-	err := json.Unmarshal(packageJSON, &authenticatedDataPackage)
+	err = json.Unmarshal(packageJSON, &authenticatedDataPackage)
 	if err != nil {
 		return "", ContextError(err)
 	}
@@ -149,3 +159,369 @@ func ReadAuthenticatedDataPackage(
 
 	return authenticatedDataPackage.Data, nil
 }
+
+// StreamingReadAuthenticatedDataPackage extracts and verifies authenticated
+// data from an AuthenticatedDataPackage stored in the specified file. The
+// package must have been signed with the given key.
+// StreamingReadAuthenticatedDataPackage does not load the entire package nor
+// the entire data into memory. It streams the package while verifying, and
+// returns an io.ReadCloser that the caller may use to stream the authenticated
+// data payload. The caller _must_ close the io.Closer to free resources and
+// close the underlying file.
+func StreamingReadAuthenticatedDataPackage(
+	packageFileName string, signingPublicKey string) (io.ReadCloser, error) {
+
+	file, err := os.Open(packageFileName)
+	if err != nil {
+		return nil, ContextError(err)
+	}
+
+	closeOnError := file
+	defer func() {
+		if closeOnError != nil {
+			closeOnError.Close()
+		}
+	}()
+
+	var payload io.ReadCloser
+
+	// The file is streamed in 2 passes. The first pass verifies the package
+	// signature. No payload data should be accepted/processed until the signature
+	// check is complete. The second pass repositions to the data payload and returns
+	// a reader the caller will use to stream the authenticated payload.
+	//
+	// Note: No exclusive file lock is held between passes, so it's possible to
+	// verify the data in one pass, and read different data in the second pass.
+	// For Psiphon's use cases, this will not happen in practise -- the packageFileName
+	// will not change while StreamingReadAuthenticatedDataPackage is running -- unless
+	// the client host is compromised; a compromised client host is outside of our threat
+	// model.
+
+	for pass := 0; pass < 2; pass++ {
+
+		_, err = file.Seek(0, 0)
+		if err != nil {
+			return nil, ContextError(err)
+		}
+
+		decompressor, err := zlib.NewReader(file)
+		if err != nil {
+			return nil, ContextError(err)
+		}
+		// TODO: need to Close decompressor to ensure zlib checksum is verified?
+
+		hash := sha256.New()
+
+		var jsonData io.Reader
+		var jsonSigningPublicKey []byte
+		var jsonSignature []byte
+
+		jsonReadBase64Value := func(value io.Reader) ([]byte, error) {
+			base64Value, err := ioutil.ReadAll(value)
+			if err != nil {
+				return nil, ContextError(err)
+			}
+			decodedValue, err := base64.StdEncoding.DecodeString(string(base64Value))
+			if err != nil {
+				return nil, ContextError(err)
+			}
+			return decodedValue, nil
+		}
+
+		jsonHandler := func(key string, value io.Reader) (bool, error) {
+			switch key {
+
+			case "data":
+				if pass == 0 {
+
+					_, err := io.Copy(hash, value)
+					if err != nil {
+						return false, ContextError(err)
+					}
+					return true, nil
+
+				} else { // pass == 1
+
+					jsonData = value
+
+					// The JSON stream parser must halt at this position,
+					// leaving the reader to be returned to the caller positioned
+					// at the start of the data payload.
+					return false, nil
+				}
+
+			case "signingPublicKeyDigest":
+				jsonSigningPublicKey, err = jsonReadBase64Value(value)
+				if err != nil {
+					return false, ContextError(err)
+				}
+				return true, nil
+
+			case "signature":
+				jsonSignature, err = jsonReadBase64Value(value)
+				if err != nil {
+					return false, ContextError(err)
+				}
+				return true, nil
+			}
+
+			return false, ContextError(fmt.Errorf("unexpected key '%s'", key))
+		}
+
+		jsonStreamer := &limitedJSONStreamer{
+			reader:  decompressor,
+			handler: jsonHandler,
+		}
+
+		err = jsonStreamer.Stream()
+		if err != nil {
+			return nil, ContextError(err)
+		}
+
+		if pass == 0 {
+
+			if jsonSigningPublicKey == nil || jsonSignature == nil {
+				return nil, ContextError(errors.New("missing expected field"))
+			}
+
+			derEncodedPublicKey, err := base64.StdEncoding.DecodeString(signingPublicKey)
+			if err != nil {
+				return nil, ContextError(err)
+			}
+			publicKey, err := x509.ParsePKIXPublicKey(derEncodedPublicKey)
+			if err != nil {
+				return nil, ContextError(err)
+			}
+			rsaPublicKey, ok := publicKey.(*rsa.PublicKey)
+			if !ok {
+				return nil, ContextError(errors.New("unexpected signing public key type"))
+			}
+
+			if 0 != bytes.Compare(jsonSigningPublicKey, sha256sum(signingPublicKey)) {
+				return nil, ContextError(errors.New("unexpected signing public key digest"))
+			}
+
+			err = rsa.VerifyPKCS1v15(
+				rsaPublicKey,
+				crypto.SHA256,
+				hash.Sum(nil),
+				jsonSignature)
+			if err != nil {
+				return nil, ContextError(err)
+			}
+
+		} else { // pass == 1
+
+			if jsonData == nil {
+				return nil, ContextError(errors.New("missing expected field"))
+			}
+
+			payload = struct {
+				io.Reader
+				io.Closer
+			}{
+				jsonData,
+				file,
+			}
+		}
+	}
+
+	closeOnError = nil
+
+	return payload, nil
+}
+
+// limitedJSONStreamer is a streaming JSON parser that supports just the
+// JSON required for the AuthenticatedDataPackage format and expected data payloads.
+//
+// Unlike other common streaming JSON parsers, limitedJSONStreamer streams the JSON
+// _values_, as the AuthenticatedDataPackage "data" value may be too large to fit into
+// memory.
+//
+// limitedJSONStreamer is not intended for use outside of AuthenticatedDataPackage
+// and supports only a small subset of JSON: one object with string values only,
+// no escaped characters, no nested objects, no arrays, no numbers, etc.
+//
+// limitedJSONStreamer does support any JSON spec (http://www.json.org/) format
+// for its limited subset. So, for example, any whitespace/formatting should be
+// supported and the creator of AuthenticatedDataPackage should be able to use
+// any valid JSON that results in a AuthenticatedDataPackage object.
+//
+// For each key/value pair, handler is invoked with the key name and a reader
+// to stream the value. The handler _must_ read value to EOF (or return an error).
+type limitedJSONStreamer struct {
+	reader  io.Reader
+	handler func(key string, value io.Reader) (bool, error)
+}
+
+const (
+	stateJSONSeekingObjectStart = iota
+	stateJSONSeekingKeyStart
+	stateJSONSeekingKeyEnd
+	stateJSONSeekingColon
+	stateJSONSeekingStringValueStart
+	stateJSONSeekingStringValueEnd
+	stateJSONSeekingNextPair
+	stateJSONObjectEnd
+)
+
+func (streamer *limitedJSONStreamer) Stream() error {
+
+	// TODO: validate that strings are valid Unicode?
+
+	isWhitespace := func(b byte) bool {
+		return b == ' ' || b == '\t' || b == '\r' || b == '\n'
+	}
+
+	nextByte := make([]byte, 1)
+	keyBuffer := new(bytes.Buffer)
+	state := stateJSONSeekingObjectStart
+
+	for {
+		n, readErr := streamer.reader.Read(nextByte)
+
+		if n > 0 {
+
+			b := nextByte[0]
+
+			switch state {
+
+			case stateJSONSeekingObjectStart:
+				if b == '{' {
+					state = stateJSONSeekingKeyStart
+				} else if !isWhitespace(b) {
+					return ContextError(fmt.Errorf("unexpected character %#U while seeking object start", b))
+				}
+
+			case stateJSONSeekingKeyStart:
+				if b == '"' {
+					state = stateJSONSeekingKeyEnd
+					keyBuffer.Reset()
+				} else if !isWhitespace(b) {
+					return ContextError(fmt.Errorf("unexpected character %#U while seeking key start", b))
+				}
+
+			case stateJSONSeekingKeyEnd:
+				if b == '\\' {
+					return ContextError(errors.New("unsupported escaped character"))
+				} else if b == '"' {
+					state = stateJSONSeekingColon
+				} else {
+					keyBuffer.WriteByte(b)
+				}
+
+			case stateJSONSeekingColon:
+				if b == ':' {
+					state = stateJSONSeekingStringValueStart
+				} else if !isWhitespace(b) {
+					return ContextError(fmt.Errorf("unexpected character %#U while seeking colon", b))
+				}
+
+			case stateJSONSeekingStringValueStart:
+				if b == '"' {
+					state = stateJSONSeekingStringValueEnd
+
+					key := string(keyBuffer.Bytes())
+
+					// Wrap the main reader in a reader that will read up to the end
+					// of the value and then EOF. The handler is expected to consume
+					// the full value, and then stream parsing will resume after the
+					// end of the value.
+					valueStreamer := &limitedJSONValueStreamer{
+						reader: streamer.reader,
+					}
+
+					continueStreaming, err := streamer.handler(key, valueStreamer)
+					if err != nil {
+						return ContextError(err)
+					}
+
+					// The handler may request that streaming halt at this point; no
+					// further changes are made to streamer.reader, leaving the value
+					// exactly where the hander leaves it.
+					if !continueStreaming {
+						return nil
+					}
+
+					state = stateJSONSeekingNextPair
+
+				} else if !isWhitespace(b) {
+					return ContextError(fmt.Errorf("unexpected character %#U while seeking value start", b))
+				}
+
+			case stateJSONSeekingNextPair:
+				if b == ',' {
+					state = stateJSONSeekingKeyStart
+				} else if b == '}' {
+					state = stateJSONObjectEnd
+				} else if !isWhitespace(b) {
+					return ContextError(fmt.Errorf("unexpected character %#U while seeking next name/value pair", b))
+				}
+
+			case stateJSONObjectEnd:
+				if !isWhitespace(b) {
+					return ContextError(fmt.Errorf("unexpected character %#U after object end", b))
+				}
+
+			default:
+				return ContextError(errors.New("unexpected state"))
+
+			}
+		}
+
+		if readErr != nil {
+			if readErr == io.EOF {
+				if state != stateJSONObjectEnd {
+					return ContextError(errors.New("unexpected EOF before object end"))
+				}
+				return nil
+			}
+			return ContextError(readErr)
+		}
+	}
+}
+
+// limitedJSONValueStreamer wraps the limitedJSONStreamer reader
+// with a reader that reads to the end of a string value and then
+// terminates with EOF.
+type limitedJSONValueStreamer struct {
+	mutex  sync.Mutex
+	eof    bool
+	reader io.Reader
+}
+
+// Read implements the io.Reader interface.
+func (streamer *limitedJSONValueStreamer) Read(p []byte) (int, error) {
+	streamer.mutex.Lock()
+	defer streamer.mutex.Unlock()
+
+	if streamer.eof {
+		return 0, io.EOF
+	}
+
+	var i int
+	var err error
+
+	for i = 0; i < len(p); i++ {
+
+		var n int
+		n, err = streamer.reader.Read(p[i : i+1])
+
+		if n == 1 {
+			if p[i] == '"' {
+				n = 0
+				streamer.eof = true
+				err = io.EOF
+			} else if p[i] == '\\' {
+				n = 0
+				err = ContextError(errors.New("unsupported escaped character"))
+			}
+		}
+
+		if err != nil {
+			break
+		}
+	}
+
+	return i, err
+}

+ 146 - 36
psiphon/common/authPackage_test.go

@@ -20,35 +20,67 @@
 package common
 
 import (
+	"encoding/base64"
 	"encoding/json"
+	"io"
+	"io/ioutil"
+	"math/rand"
+	"os"
 	"testing"
 )
 
 func TestAuthenticatedPackage(t *testing.T) {
 
-	var signingPublicKey, signingPrivateKey string
-
-	t.Run("generate package keys", func(t *testing.T) {
-		var err error
-		signingPublicKey, signingPrivateKey, err = GenerateAuthenticatedDataPackageKeys()
-		if err != nil {
-			t.Fatalf("GenerateAuthenticatedDataPackageKeys failed: %s", err)
-		}
-	})
+	signingPublicKey, signingPrivateKey, err := GenerateAuthenticatedDataPackageKeys()
+	if err != nil {
+		t.Fatalf("GenerateAuthenticatedDataPackageKeys failed: %s", err)
+	}
 
 	expectedContent := "TestAuthenticatedPackage"
-	var packagePayload []byte
-
-	t.Run("write package", func(t *testing.T) {
-		var err error
-		packagePayload, err = WriteAuthenticatedDataPackage(
-			expectedContent,
-			signingPublicKey,
-			signingPrivateKey)
-		if err != nil {
-			t.Fatalf("WriteAuthenticatedDataPackage failed: %s", err)
-		}
-	})
+
+	packagePayload, err := WriteAuthenticatedDataPackage(
+		expectedContent,
+		signingPublicKey,
+		signingPrivateKey)
+	if err != nil {
+		t.Fatalf("WriteAuthenticatedDataPackage failed: %s", err)
+	}
+
+	tempFileName, err := makeTempFile(packagePayload)
+	if err != nil {
+		t.Fatalf("makeTempFile failed: %s", err)
+	}
+	defer os.Remove(tempFileName)
+
+	wrongSigningPublicKey, _, err := GenerateAuthenticatedDataPackageKeys()
+	if err != nil {
+		t.Fatalf("GenerateAuthenticatedDataPackageKeys failed: %s", err)
+	}
+
+	packageJSON, err := Decompress(packagePayload)
+	if err != nil {
+		t.Fatalf("Uncompress failed: %s", err)
+	}
+
+	var authDataPackage AuthenticatedDataPackage
+	err = json.Unmarshal(packageJSON, &authDataPackage)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %s", err)
+	}
+	authDataPackage.Data = "TamperedData"
+
+	tamperedPackageJSON, err := json.Marshal(&authDataPackage)
+	if err != nil {
+		t.Fatalf("Marshal failed: %s", err)
+	}
+
+	tamperedPackagePayload := Compress(tamperedPackageJSON)
+
+	tamperedTempFileName, err := makeTempFile(tamperedPackagePayload)
+	if err != nil {
+		t.Fatalf("makeTempFile failed: %s", err)
+	}
+	defer os.Remove(tempFileName)
 
 	t.Run("read package: success", func(t *testing.T) {
 		content, err := ReadAuthenticatedDataPackage(
@@ -63,11 +95,24 @@ func TestAuthenticatedPackage(t *testing.T) {
 		}
 	})
 
-	t.Run("read package: wrong signing key", func(t *testing.T) {
-		wrongSigningPublicKey, _, err := GenerateAuthenticatedDataPackageKeys()
+	t.Run("streaming read package: success", func(t *testing.T) {
+		contentReader, err := StreamingReadAuthenticatedDataPackage(
+			tempFileName, signingPublicKey)
 		if err != nil {
-			t.Fatalf("GenerateAuthenticatedDataPackageKeys failed: %s", err)
+			t.Fatalf("StreamingReadAuthenticatedDataPackage failed: %s", err)
 		}
+		content, err := ioutil.ReadAll(contentReader)
+		if err != nil {
+			t.Fatalf("ReadAll failed: %s", err)
+		}
+		if string(content) != expectedContent {
+			t.Fatalf(
+				"unexpected package content: expected %s got %s",
+				expectedContent, content)
+		}
+	})
+
+	t.Run("read package: wrong signing key", func(t *testing.T) {
 		_, err = ReadAuthenticatedDataPackage(
 			packagePayload, wrongSigningPublicKey)
 		if err == nil {
@@ -75,24 +120,89 @@ func TestAuthenticatedPackage(t *testing.T) {
 		}
 	})
 
-	t.Run("read package: tampered data", func(t *testing.T) {
-
-		var authDataPackage AuthenticatedDataPackage
-		err := json.Unmarshal(packagePayload, &authDataPackage)
-		if err != nil {
-			t.Fatalf("Unmarshal failed: %s", err)
-		}
-		authDataPackage.Data = "TamperedData"
-
-		tamperedPackagePayload, err := json.Marshal(&authDataPackage)
-		if err != nil {
-			t.Fatalf("Marshal failed: %s", err)
+	t.Run("streaming read package: wrong signing key", func(t *testing.T) {
+		_, err = StreamingReadAuthenticatedDataPackage(
+			tempFileName, wrongSigningPublicKey)
+		if err == nil {
+			t.Fatalf("StreamingReadAuthenticatedDataPackage unexpectedly succeeded")
 		}
+	})
 
+	t.Run("read package: tampered data", func(t *testing.T) {
 		_, err = ReadAuthenticatedDataPackage(
 			tamperedPackagePayload, signingPublicKey)
 		if err == nil {
 			t.Fatalf("ReadAuthenticatedDataPackage unexpectedly succeeded")
 		}
 	})
+
+	t.Run("streaming read package: tampered data", func(t *testing.T) {
+		_, err = StreamingReadAuthenticatedDataPackage(
+			tamperedTempFileName, signingPublicKey)
+		if err == nil {
+			t.Fatalf("StreamingReadAuthenticatedDataPackage unexpectedly succeeded")
+		}
+	})
+}
+
+func BenchmarkAuthenticatedPackage(b *testing.B) {
+
+	signingPublicKey, signingPrivateKey, err := GenerateAuthenticatedDataPackageKeys()
+	if err != nil {
+		b.Fatalf("GenerateAuthenticatedDataPackageKeys failed: %s", err)
+	}
+
+	data := make([]byte, 104857600)
+	rand.Read(data)
+
+	packagePayload, err := WriteAuthenticatedDataPackage(
+		base64.StdEncoding.EncodeToString(data),
+		signingPublicKey,
+		signingPrivateKey)
+	if err != nil {
+		b.Fatalf("WriteAuthenticatedDataPackage failed: %s", err)
+	}
+
+	tempFileName, err := makeTempFile(packagePayload)
+	if err != nil {
+		b.Fatalf("makeTempFile failed: %s", err)
+	}
+	defer os.Remove(tempFileName)
+
+	b.Run("read package", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			_, err := ReadAuthenticatedDataPackage(
+				packagePayload, signingPublicKey)
+			if err != nil {
+				b.Fatalf("ReadAuthenticatedDataPackage failed: %s", err)
+			}
+		}
+	})
+
+	b.Run("streaming read package", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			contentReader, err := StreamingReadAuthenticatedDataPackage(
+				tempFileName, signingPublicKey)
+			if err != nil {
+				b.Fatalf("StreamingReadAuthenticatedDataPackage failed: %s", err)
+			}
+			_, err = io.Copy(ioutil.Discard, contentReader)
+			if err != nil {
+				b.Fatalf("Read failed: %s", err)
+			}
+		}
+	})
+}
+
+func makeTempFile(data []byte) (string, error) {
+	file, err := ioutil.TempFile("", "authPackage_test")
+	if err != nil {
+		return "", ContextError(err)
+	}
+	defer file.Close()
+	_, err = file.Write(data)
+	if err != nil {
+		return "", ContextError(err)
+	}
+	return file.Name(), nil
 }