Просмотр исходного кода

Merge pull request #545 from mirokuratczyk/master

Return base64-encoded auth ID
Rod Hynes 6 лет назад
Родитель
Сommit
950cc55778

+ 10 - 8
psiphon/common/accesscontrol/accesscontrol.go

@@ -164,23 +164,25 @@ func ValidateSigningKey(signingKey *SigningKey) error {
 // from the seed without revealing the original value. The authorization
 // from the seed without revealing the original value. The authorization
 // ID is to be used to mitigate malicious authorization reuse/sharing.
 // ID is to be used to mitigate malicious authorization reuse/sharing.
 //
 //
-// The return value is a base64-encoded, serialized JSON representation
-// of the signed authorization that can be passed to VerifyAuthorization.
+// The first return value is a base64-encoded, serialized JSON representation
+// of the signed authorization that can be passed to VerifyAuthorization. The
+// second return value is the unique ID of the signed authorization returned in
+// the first value.
 func IssueAuthorization(
 func IssueAuthorization(
 	signingKey *SigningKey,
 	signingKey *SigningKey,
 	seedAuthorizationID []byte,
 	seedAuthorizationID []byte,
-	expires time.Time) (string, error) {
+	expires time.Time) (string, []byte, error) {
 
 
 	err := ValidateSigningKey(signingKey)
 	err := ValidateSigningKey(signingKey)
 	if err != nil {
 	if err != nil {
-		return "", errors.Trace(err)
+		return "", nil, errors.Trace(err)
 	}
 	}
 
 
 	hkdf := hkdf.New(sha256.New, signingKey.AuthorizationIDKey, nil, seedAuthorizationID)
 	hkdf := hkdf.New(sha256.New, signingKey.AuthorizationIDKey, nil, seedAuthorizationID)
 	ID := make([]byte, authorizationIDLength)
 	ID := make([]byte, authorizationIDLength)
 	_, err = io.ReadFull(hkdf, ID)
 	_, err = io.ReadFull(hkdf, ID)
 	if err != nil {
 	if err != nil {
-		return "", errors.Trace(err)
+		return "", nil, errors.Trace(err)
 	}
 	}
 
 
 	auth := Authorization{
 	auth := Authorization{
@@ -191,7 +193,7 @@ func IssueAuthorization(
 
 
 	authJSON, err := json.Marshal(auth)
 	authJSON, err := json.Marshal(auth)
 	if err != nil {
 	if err != nil {
-		return "", errors.Trace(err)
+		return "", nil, errors.Trace(err)
 	}
 	}
 
 
 	signature := ed25519.Sign(signingKey.PrivateKey, authJSON)
 	signature := ed25519.Sign(signingKey.PrivateKey, authJSON)
@@ -204,12 +206,12 @@ func IssueAuthorization(
 
 
 	signedAuthJSON, err := json.Marshal(signedAuth)
 	signedAuthJSON, err := json.Marshal(signedAuth)
 	if err != nil {
 	if err != nil {
-		return "", errors.Trace(err)
+		return "", nil, errors.Trace(err)
 	}
 	}
 
 
 	encodedSignedAuth := base64.StdEncoding.EncodeToString(signedAuthJSON)
 	encodedSignedAuth := base64.StdEncoding.EncodeToString(signedAuthJSON)
 
 
-	return encodedSignedAuth, nil
+	return encodedSignedAuth, ID, nil
 }
 }
 
 
 // VerificationKeyRing is a set of verification keys to be deployed
 // VerificationKeyRing is a set of verification keys to be deployed

+ 36 - 4
psiphon/common/accesscontrol/accesscontrol_test.go

@@ -89,11 +89,43 @@ func TestAuthorization(t *testing.T) {
 
 
 	expires := time.Now().Add(10 * time.Second)
 	expires := time.Now().Add(10 * time.Second)
 
 
-	auth, err := IssueAuthorization(correctSigningKey, id, expires)
+	auth, issuedID, err := IssueAuthorization(correctSigningKey, id, expires)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("IssueAuthorization failed: %s", err)
 		t.Fatalf("IssueAuthorization failed: %s", err)
 	}
 	}
 
 
+	// Decode the signed authorization and check that the auth ID in the JSON
+	// matches the one returned by IssueAuthorization.
+
+	decodedAuthorization, err := base64.StdEncoding.DecodeString(auth)
+	if err != nil {
+		t.Fatalf("DecodeString failed: %s", err)
+	}
+
+	type partialSignedAuthorization struct {
+		Authorization json.RawMessage
+	}
+	var partialSignedAuth partialSignedAuthorization
+	err = json.Unmarshal(decodedAuthorization, &partialSignedAuth)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %s", err)
+	}
+
+	var unmarshaledAuth map[string]interface{}
+	err = json.Unmarshal(partialSignedAuth.Authorization, &unmarshaledAuth)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %s", err)
+	}
+
+	authID, ok := unmarshaledAuth["ID"].(string)
+	if !ok {
+		t.Fatalf("Failed to find auth ID in unmarshaled auth: %s", unmarshaledAuth)
+	}
+
+	if string(authID) != base64.StdEncoding.EncodeToString(issuedID) {
+		t.Fatalf("Expected auth ID in signed auth (%s) to match that returned by IssueAuthorization (%s)", string(authID), base64.StdEncoding.EncodeToString(issuedID))
+	}
+
 	fmt.Printf("encoded authorization length: %d\n", len(auth))
 	fmt.Printf("encoded authorization length: %d\n", len(auth))
 
 
 	verifiedAuth, err := VerifyAuthorization(keyRing, auth)
 	verifiedAuth, err := VerifyAuthorization(keyRing, auth)
@@ -109,7 +141,7 @@ func TestAuthorization(t *testing.T) {
 
 
 	expires = time.Now().Add(-10 * time.Second)
 	expires = time.Now().Add(-10 * time.Second)
 
 
-	auth, err = IssueAuthorization(correctSigningKey, id, expires)
+	auth, _, err = IssueAuthorization(correctSigningKey, id, expires)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("IssueAuthorization failed: %s", err)
 		t.Fatalf("IssueAuthorization failed: %s", err)
 	}
 	}
@@ -124,7 +156,7 @@ func TestAuthorization(t *testing.T) {
 
 
 	expires = time.Now().Add(10 * time.Second)
 	expires = time.Now().Add(10 * time.Second)
 
 
-	auth, err = IssueAuthorization(invalidSigningKey, id, expires)
+	auth, _, err = IssueAuthorization(invalidSigningKey, id, expires)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("IssueAuthorization failed: %s", err)
 		t.Fatalf("IssueAuthorization failed: %s", err)
 	}
 	}
@@ -139,7 +171,7 @@ func TestAuthorization(t *testing.T) {
 
 
 	expires = time.Now().Add(10 * time.Second)
 	expires = time.Now().Add(10 * time.Second)
 
 
-	auth, err = IssueAuthorization(otherSigningKey, id, expires)
+	auth, _, err = IssueAuthorization(otherSigningKey, id, expires)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("IssueAuthorization failed: %s", err)
 		t.Fatalf("IssueAuthorization failed: %s", err)
 	}
 	}

+ 1 - 1
psiphon/server/server_test.go

@@ -562,7 +562,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 
 	var authorizationID [32]byte
 	var authorizationID [32]byte
 
 
-	clientAuthorization, err := accesscontrol.IssueAuthorization(
+	clientAuthorization, _, err := accesscontrol.IssueAuthorization(
 		accessControlSigningKey,
 		accessControlSigningKey,
 		authorizationID[:],
 		authorizationID[:],
 		time.Now().Add(1*time.Hour))
 		time.Now().Add(1*time.Hour))