瀏覽代碼

Add authorized access control component

Rod Hynes 8 年之前
父節點
當前提交
aeca9dd9f9

+ 289 - 0
psiphon/common/accesscontrol/accesscontrol.go

@@ -0,0 +1,289 @@
+/*
+ * Copyright (c) 2018, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+// Package accesscontrol implements an access control authorization scheme
+// based on digital signatures.
+//
+// Authorizations for specified access types are issued by an entity that
+// digitally signs each authorization. The digital signature is verified
+// by service providers before granting the specified access type. Each
+// authorization includes an expiry date and a unique ID that may be used
+// to mitigate malicious reuse/sharing of authorizations.
+//
+// In a typical deployment, the signing keys will be present on issuing
+// entities which are distinct from service providers. Only verification
+// keys will be deployed to service providers.
+//
+// An authorization is encoded in JSON:
+//
+// {
+//   "Authorization" : {
+// 	 "ID" : <derived unique ID>,
+// 	 "AccessType" : <access type name; e.g., "my-access">,
+// 	 "Expires" : <RFC3339-encoded UTC time value>
+//   },
+//   "SigningKeyID" : <unique key ID>,
+//   "Signature" : <Ed25519 digital signature>
+// }
+//
+package accesscontrol
+
+import (
+	"crypto/rand"
+	"crypto/sha256"
+	"crypto/subtle"
+	"encoding/json"
+	"errors"
+	"io"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ed25519"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/hkdf"
+)
+
+const (
+	keyIDLength              = 32
+	authorizationIDKeyLength = 32
+	authorizationIDLength    = 32
+)
+
+// SigningKey is the private key used to sign newly issued
+// authorizations for the specified access type. The key ID
+// is included in authorizations and identifies the
+// corresponding verification keys.
+//
+// AuthorizationIDKey is used to produce a unique
+// authentication ID that cannot be mapped back to its seed
+// value.
+type SigningKey struct {
+	ID                 []byte
+	AccessType         string
+	AuthorizationIDKey []byte
+	PrivateKey         []byte
+}
+
+// VerificationKey is the public key used to verify signed
+// authentications issued for the specified access type. The
+// authorization references the expected public key by ID.
+type VerificationKey struct {
+	ID         []byte
+	AccessType string
+	PublicKey  []byte
+}
+
+// NewKeyPair generates a new authorization signing key pair.
+func NewKeyPair(
+	accessType string) (*SigningKey, *VerificationKey, error) {
+
+	ID, err := common.MakeSecureRandomBytes(keyIDLength)
+	if err != nil {
+		return nil, nil, common.ContextError(err)
+	}
+
+	authorizationIDKey, err := common.MakeSecureRandomBytes(authorizationIDKeyLength)
+	if err != nil {
+		return nil, nil, common.ContextError(err)
+	}
+
+	publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		return nil, nil, common.ContextError(err)
+	}
+
+	signingKey := &SigningKey{
+		ID:                 ID,
+		AccessType:         accessType,
+		AuthorizationIDKey: authorizationIDKey,
+		PrivateKey:         privateKey,
+	}
+
+	verificationKey := &VerificationKey{
+		ID:         ID,
+		AccessType: accessType,
+		PublicKey:  publicKey,
+	}
+
+	return signingKey, verificationKey, nil
+}
+
+// Authorization describes an authorization, with a unique ID,
+// granting access to a specified access type, and expiring at
+// the specified time.
+//
+// An Authorization is embedded within a digitally signed
+// object. This wrapping object adds a signature and a signing
+// key ID.
+type Authorization struct {
+	ID         []byte
+	AccessType string
+	Expires    time.Time
+}
+
+type signedAuthorization struct {
+	Authorization json.RawMessage
+	SigningKeyID  []byte
+	Signature     []byte
+}
+
+// IssueAuthorization issues an authorization signed with the
+// specified signing key.
+//
+// seedAuthorizationID should be a value that uniquely identifies
+// the purchase, subscription, or transaction that backs the
+// authorization; a distinct unique authorization ID will be derived
+// from the seed without revealing the original value. The authorization
+// ID is to be used to mitigate malicious authorization reuse/sharing.
+//
+// The return value is a serialized JSON representation of the
+// signed authorization that can be passed to VerifyAuthorization.
+func IssueAuthorization(
+	signingKey *SigningKey,
+	seedAuthorizationID []byte,
+	expires time.Time) ([]byte, error) {
+
+	if len(signingKey.ID) != keyIDLength ||
+		len(signingKey.AccessType) < 1 ||
+		len(signingKey.AuthorizationIDKey) != authorizationIDKeyLength ||
+		len(signingKey.PrivateKey) != ed25519.PrivateKeySize {
+		return nil, common.ContextError(errors.New("invalid signing key"))
+	}
+
+	hkdf := hkdf.New(sha256.New, signingKey.AuthorizationIDKey, nil, seedAuthorizationID)
+	ID := make([]byte, authorizationIDLength)
+	_, err := io.ReadFull(hkdf, ID)
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	auth := Authorization{
+		ID:         ID,
+		AccessType: signingKey.AccessType,
+		Expires:    expires.UTC(),
+	}
+
+	authJSON, err := json.Marshal(auth)
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	signature := ed25519.Sign(signingKey.PrivateKey, authJSON)
+
+	signedAuth := signedAuthorization{
+		Authorization: authJSON,
+		SigningKeyID:  signingKey.ID,
+		Signature:     signature,
+	}
+
+	signedAuthJSON, err := json.Marshal(signedAuth)
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	return signedAuthJSON, nil
+}
+
+// VerificationKeyRing is a set of verification keys to be deployed
+// to a service provider for verifying access authorizations.
+type VerificationKeyRing struct {
+	Keys []*VerificationKey
+}
+
+// ValidateKeyRing checks that a verification key ring is correctly
+// configured.
+func ValidateKeyRing(keyRing *VerificationKeyRing) error {
+	for _, key := range keyRing.Keys {
+		if len(key.ID) != keyIDLength ||
+			len(key.AccessType) < 1 ||
+			len(key.PublicKey) != ed25519.PublicKeySize {
+			return common.ContextError(errors.New("invalid verification key"))
+		}
+	}
+	return nil
+}
+
+// VerifyAuthorization verifies the signed authorization and, when
+// verified, returns the embedded Authorization struct with the
+// access control information.
+//
+// The key ID in the signed authorization is used to select the
+// appropriate verification key from the key ring.
+//
+// Assumes that ValidateKeyRing has been called.
+func VerifyAuthorization(
+	keyRing *VerificationKeyRing,
+	signedAuthorizationJSON []byte) (*Authorization, error) {
+
+	var signedAuth signedAuthorization
+
+	err := json.Unmarshal(signedAuthorizationJSON, &signedAuth)
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	if len(signedAuth.SigningKeyID) != keyIDLength {
+		return nil, common.ContextError(errors.New("invalid key ID length"))
+	}
+
+	if len(signedAuth.Signature) != ed25519.SignatureSize {
+		return nil, common.ContextError(errors.New("invalid signature length"))
+	}
+
+	var verificationKey *VerificationKey
+
+	for _, key := range keyRing.Keys {
+		if subtle.ConstantTimeCompare(signedAuth.SigningKeyID, key.ID) == 1 {
+			verificationKey = key
+		}
+	}
+
+	if verificationKey == nil {
+		return nil, common.ContextError(errors.New("invalid key ID"))
+	}
+
+	if !ed25519.Verify(
+		verificationKey.PublicKey, signedAuth.Authorization, signedAuth.Signature) {
+		return nil, common.ContextError(errors.New("invalid signature"))
+	}
+
+	var auth Authorization
+
+	err = json.Unmarshal(signedAuth.Authorization, &auth)
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	if len(auth.ID) == 0 {
+		return nil, common.ContextError(errors.New("invalid authentication ID"))
+	}
+
+	if auth.AccessType != verificationKey.AccessType {
+		return nil, common.ContextError(errors.New("invalid access type"))
+	}
+
+	if auth.Expires.IsZero() {
+		return nil, common.ContextError(errors.New("invalid expiry"))
+	}
+
+	if auth.Expires.Before(time.Now().UTC()) {
+		return nil, common.ContextError(errors.New("expired authentication"))
+	}
+
+	return &auth, nil
+}

+ 160 - 0
psiphon/common/accesscontrol/accesscontrol_test.go

@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) 2018, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package accesscontrol
+
+import (
+	"encoding/json"
+	"testing"
+	"time"
+)
+
+func TestAuthorization(t *testing.T) {
+
+	correctAccess := "access1"
+	otherAccess := "access2"
+
+	correctSigningKey, correctVerificationKey, err := NewKeyPair(correctAccess)
+	if err != nil {
+		t.Fatalf("NewKeyPair failed: %s", err)
+	}
+
+	otherSigningKey, otherVerificationKey, err := NewKeyPair(otherAccess)
+	if err != nil {
+		t.Fatalf("NewKeyPair failed: %s", err)
+	}
+
+	invalidSigningKey, _, err := NewKeyPair(correctAccess)
+	if err != nil {
+		t.Fatalf("NewKeyPair failed: %s", err)
+	}
+
+	keyRing := &VerificationKeyRing{
+		Keys: []*VerificationKey{correctVerificationKey, otherVerificationKey},
+	}
+
+	// Test: valid key ring
+
+	err = ValidateKeyRing(keyRing)
+	if err != nil {
+		t.Fatalf("ValidateKeyRing failed: %s", err)
+	}
+
+	// Test: invalid key ring
+
+	invalidKeyRing := &VerificationKeyRing{
+		Keys: []*VerificationKey{&VerificationKey{}},
+	}
+
+	err = ValidateKeyRing(invalidKeyRing)
+	if err == nil {
+		t.Fatalf("ValidateKeyRing unexpected success")
+	}
+
+	// Test: valid authorization
+
+	id := []byte("0000000000000001")
+
+	expires := time.Now().Add(10 * time.Second)
+
+	auth, err := IssueAuthorization(correctSigningKey, id, expires)
+	if err != nil {
+		t.Fatalf("IssueAuthorization failed: %s", err)
+	}
+
+	verifiedAuth, err := VerifyAuthorization(keyRing, auth)
+	if err != nil {
+		t.Fatalf("VerifyAuthorization failed: %s", err)
+	}
+
+	if verifiedAuth.AccessType != correctAccess {
+		t.Fatalf("unexpected access type: %s", verifiedAuth.AccessType)
+	}
+
+	// Test: expired authorization
+
+	expires = time.Now().Add(-10 * time.Second)
+
+	auth, err = IssueAuthorization(correctSigningKey, id, expires)
+	if err != nil {
+		t.Fatalf("IssueAuthorization failed: %s", err)
+	}
+
+	verifiedAuth, err = VerifyAuthorization(keyRing, auth)
+	// TODO: check error message?
+	if err == nil {
+		t.Fatalf("VerifyAuthorization unexpected success")
+	}
+
+	// Test: authorization signed with key not in key ring
+
+	expires = time.Now().Add(10 * time.Second)
+
+	auth, err = IssueAuthorization(invalidSigningKey, id, expires)
+	if err != nil {
+		t.Fatalf("IssueAuthorization failed: %s", err)
+	}
+
+	verifiedAuth, err = VerifyAuthorization(keyRing, auth)
+	// TODO: check error message?
+	if err == nil {
+		t.Fatalf("VerifyAuthorization unexpected success")
+	}
+
+	// Test: authorization signed with valid key, but hacked access type
+
+	expires = time.Now().Add(10 * time.Second)
+
+	auth, err = IssueAuthorization(otherSigningKey, id, expires)
+	if err != nil {
+		t.Fatalf("IssueAuthorization failed: %s", err)
+	}
+
+	var hackSignedAuth signedAuthorization
+	err = json.Unmarshal(auth, &hackSignedAuth)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %s", err)
+	}
+
+	var hackAuth Authorization
+	err = json.Unmarshal(hackSignedAuth.Authorization, &hackAuth)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %s", err)
+	}
+
+	hackAuth.AccessType = correctAccess
+
+	auth, err = json.Marshal(hackAuth)
+	if err != nil {
+		t.Fatalf("Marshall failed: %s", err)
+	}
+
+	hackSignedAuth.Authorization = auth
+
+	signedAuth, err := json.Marshal(hackSignedAuth)
+	if err != nil {
+		t.Fatalf("Marshall failed: %s", err)
+	}
+
+	verifiedAuth, err = VerifyAuthorization(keyRing, signedAuth)
+	// TODO: check error message?
+	if err == nil {
+		t.Fatalf("VerifyAuthorization unexpected success")
+	}
+}

+ 23 - 3
psiphon/common/osl/osl.go

@@ -1297,9 +1297,17 @@ func NewRegistryStreamer(
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
-	if name, ok := token.(string); !ok || name != "FileSpecs" {
+
+	name, ok := token.(string)
+
+	if !ok {
+		return nil, common.ContextError(
+			fmt.Errorf("unexpected token type: %T", token))
+	}
+
+	if name != "FileSpecs" {
 		return nil, common.ContextError(
-			fmt.Errorf("unexpected name: %s", name))
+			fmt.Errorf("unexpected field name: %s", name))
 	}
 
 	err = expectJSONDelimiter(jsonDecoder, "[")
@@ -1367,10 +1375,19 @@ func expectJSONDelimiter(jsonDecoder *json.Decoder, delimiter string) error {
 	if err != nil {
 		return common.ContextError(err)
 	}
-	if delim, ok := token.(json.Delim); !ok || delim.String() != delimiter {
+
+	delim, ok := token.(json.Delim)
+
+	if !ok {
+		return common.ContextError(
+			fmt.Errorf("unexpected token type: %T", token))
+	}
+
+	if delim.String() != delimiter {
 		return common.ContextError(
 			fmt.Errorf("unexpected delimiter: %s", delim.String()))
 	}
+
 	return nil
 }
 
@@ -1446,6 +1463,9 @@ func newSeededKeyMaterialReader(seed []byte) (io.Reader, error) {
 // deriveKeyHKDF implements HKDF-Expand as defined in https://tools.ietf.org/html/rfc5869
 // where masterKey = PRK, context = info, and L = 32; SHA-256 is used so HashLen = 32
 func deriveKeyHKDF(masterKey []byte, context ...[]byte) []byte {
+
+	// TODO: use golang.org/x/crypto/hkdf?
+
 	mac := hmac.New(sha256.New, masterKey)
 	for _, item := range context {
 		mac.Write([]byte(item))

+ 11 - 8
psiphon/common/protocol/protocol.go

@@ -56,6 +56,8 @@ const (
 	PSIPHON_WEB_API_PROTOCOL = "web"
 
 	PACKET_TUNNEL_CHANNEL_TYPE = "tun@psiphon.ca"
+
+	PSIPHON_API_HANDSHAKE_AUTHORIZATIONS = "authorizations"
 )
 
 var SupportedTunnelProtocols = []string{
@@ -125,14 +127,15 @@ func UseClientTunnelProtocol(
 }
 
 type HandshakeResponse struct {
-	SSHSessionID         string              `json:"ssh_session_id"`
-	Homepages            []string            `json:"homepages"`
-	UpgradeClientVersion string              `json:"upgrade_client_version"`
-	PageViewRegexes      []map[string]string `json:"page_view_regexes"`
-	HttpsRequestRegexes  []map[string]string `json:"https_request_regexes"`
-	EncodedServerList    []string            `json:"encoded_server_list"`
-	ClientRegion         string              `json:"client_region"`
-	ServerTimestamp      string              `json:"server_timestamp"`
+	SSHSessionID          string              `json:"ssh_session_id"`
+	Homepages             []string            `json:"homepages"`
+	UpgradeClientVersion  string              `json:"upgrade_client_version"`
+	PageViewRegexes       []map[string]string `json:"page_view_regexes"`
+	HttpsRequestRegexes   []map[string]string `json:"https_request_regexes"`
+	EncodedServerList     []string            `json:"encoded_server_list"`
+	ClientRegion          string              `json:"client_region"`
+	ServerTimestamp       string              `json:"server_timestamp"`
+	AuthorizedAccessTypes []string            `json:"authorized_access_types"`
 }
 
 type ConnectedResponse struct {

+ 11 - 0
psiphon/common/utils.go

@@ -48,6 +48,17 @@ func Contains(list []string, target string) bool {
 	return false
 }
 
+// ContainsAny returns true if any string in targets
+// is present ini he list.
+func ContainsAny(list, targets []string) bool {
+	for _, target := range targets {
+		if Contains(list, target) {
+			return true
+		}
+	}
+	return false
+}
+
 // ContainsInt returns true if the target int is
 // in the list.
 func ContainsInt(list []int, target int) bool {

+ 4 - 0
psiphon/config.go

@@ -499,6 +499,10 @@ type Config struct {
 	// ID is automatically generated. Supply a session ID when a single client session
 	// will cross multiple Controller instances.
 	SessionID string
+
+	// Authorizations is a list of encoded, signed access control authorizations that
+	// the client has obtained and will present to the server.
+	Authorizations []json.RawMessage
 }
 
 // DownloadURL specifies a URL for downloading resources along with parameters

+ 8 - 1
psiphon/notice.go

@@ -704,13 +704,20 @@ func NoticeSLOKSeeded(slokID string, duplicate bool) {
 		"duplicate", duplicate)
 }
 
-// NoticeServerTimestamp reports server side timestamp as seen in the handshake
+// NoticeServerTimestamp reports server side timestamp as seen in the handshake.
 func NoticeServerTimestamp(timestamp string) {
 	singletonNoticeLogger.outputNotice(
 		"ServerTimestamp", 0,
 		"timestamp", timestamp)
 }
 
+// NoticeAuthorizedAccessTypes reports the authorized access types the server has accepted.
+func NoticeAuthorizedAccessTypes(authorizedAccessTypes []string) {
+	singletonNoticeLogger.outputNotice(
+		"AuthorizedAccessTypes", 0,
+		"accessTypes", authorizedAccessTypes)
+}
+
 type repetitiveNoticeState struct {
 	message string
 	repeats int

+ 174 - 38
psiphon/server/api.go

@@ -20,6 +20,7 @@
 package server
 
 import (
+	"bytes"
 	"crypto/subtle"
 	"encoding/json"
 	"errors"
@@ -47,8 +48,6 @@ const (
 
 var CLIENT_VERIFICATION_REQUIRED = false
 
-type requestJSONObject map[string]interface{}
-
 // sshAPIRequestHandler routes Psiphon API requests transported as
 // JSON objects via the SSH request mechanism.
 //
@@ -56,21 +55,21 @@ type requestJSONObject map[string]interface{}
 // reused by webServer which offers the Psiphon API via web transport.
 //
 // The API request parameters and event log values follow the legacy
-// psi_web protocol and naming conventions. The API is compatible all
-// tunnel-core clients but are not backwards compatible with older
-// clients.
+// psi_web protocol and naming conventions. The API is compatible with
+// all tunnel-core clients but are not backwards compatible with all
+// legacy clients.
 //
 func sshAPIRequestHandler(
 	support *SupportServices,
 	geoIPData GeoIPData,
+	authorizedAccessTypes []string,
 	name string,
 	requestPayload []byte) ([]byte, error) {
 
 	// Note: for SSH requests, MAX_API_PARAMS_SIZE is implicitly enforced
-	// by max SSH reqest packet size.
+	// by max SSH request packet size.
 
-	var params requestJSONObject
-	err := json.Unmarshal(requestPayload, &params)
+	params, err := requestJSONUnmarshal(requestPayload)
 	if err != nil {
 		return nil, common.ContextError(
 			fmt.Errorf("invalid payload for request name: %s: %s", name, err))
@@ -80,6 +79,7 @@ func sshAPIRequestHandler(
 		support,
 		protocol.PSIPHON_SSH_API_PROTOCOL,
 		geoIPData,
+		authorizedAccessTypes,
 		name,
 		params)
 }
@@ -90,6 +90,7 @@ func dispatchAPIRequestHandler(
 	support *SupportServices,
 	apiProtocol string,
 	geoIPData GeoIPData,
+	authorizedAccessTypes []string,
 	name string,
 	params requestJSONObject) (response []byte, reterr error) {
 
@@ -155,11 +156,11 @@ func dispatchAPIRequestHandler(
 	case protocol.PSIPHON_API_HANDSHAKE_REQUEST_NAME:
 		return handshakeAPIRequestHandler(support, apiProtocol, geoIPData, params)
 	case protocol.PSIPHON_API_CONNECTED_REQUEST_NAME:
-		return connectedAPIRequestHandler(support, geoIPData, params)
+		return connectedAPIRequestHandler(support, geoIPData, authorizedAccessTypes, params)
 	case protocol.PSIPHON_API_STATUS_REQUEST_NAME:
-		return statusAPIRequestHandler(support, geoIPData, params)
+		return statusAPIRequestHandler(support, geoIPData, authorizedAccessTypes, params)
 	case protocol.PSIPHON_API_CLIENT_VERIFICATION_REQUEST_NAME:
-		return clientVerificationAPIRequestHandler(support, geoIPData, params)
+		return clientVerificationAPIRequestHandler(support, geoIPData, authorizedAccessTypes, params)
 	}
 
 	return nil, common.ContextError(fmt.Errorf("invalid request name: %s", name))
@@ -189,19 +190,32 @@ func handshakeAPIRequestHandler(
 	isMobile := isMobileClientPlatform(clientPlatform)
 	normalizedPlatform := normalizeClientPlatform(clientPlatform)
 
+	var authorizations [][]byte
+	if params[protocol.PSIPHON_API_HANDSHAKE_AUTHORIZATIONS] != nil {
+		authorizationsRawJSON, err := getRawJSONArrayRequestParam(params, protocol.PSIPHON_API_HANDSHAKE_AUTHORIZATIONS)
+		if err != nil {
+			return nil, common.ContextError(err)
+		}
+		authorizations = make([][]byte, len(authorizationsRawJSON))
+		for i := 0; i < len(authorizationsRawJSON); i++ {
+			authorizations[i] = authorizationsRawJSON[i]
+		}
+	}
+
 	// Flag the SSH client as having completed its handshake. This
 	// may reselect traffic rules and starts allowing port forwards.
 
 	// TODO: in the case of SSH API requests, the actual sshClient could
 	// be passed in and used here. The session ID lookup is only strictly
 	// necessary to support web API requests.
-	err = support.TunnelServer.SetClientHandshakeState(
+	authorizedAccessTypes, err := support.TunnelServer.SetClientHandshakeState(
 		sessionID,
 		handshakeState{
 			completed:   true,
 			apiProtocol: apiProtocol,
 			apiParams:   copyBaseRequestParams(params),
-		})
+		},
+		authorizations)
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
@@ -211,23 +225,24 @@ func handshakeAPIRequestHandler(
 
 	log.LogRawFieldsWithTimestamp(
 		getRequestLogFields(
-			support,
 			"handshake",
 			geoIPData,
+			authorizedAccessTypes,
 			params,
 			baseRequestParams))
 
 	// Note: no guarantee that PsinetDatabase won't reload between database calls
 	db := support.PsinetDatabase
 	handshakeResponse := protocol.HandshakeResponse{
-		SSHSessionID:         sessionID,
-		Homepages:            db.GetRandomizedHomepages(sponsorID, geoIPData.Country, isMobile),
-		UpgradeClientVersion: db.GetUpgradeClientVersion(clientVersion, normalizedPlatform),
-		PageViewRegexes:      make([]map[string]string, 0),
-		HttpsRequestRegexes:  db.GetHttpsRequestRegexes(sponsorID),
-		EncodedServerList:    db.DiscoverServers(geoIPData.DiscoveryValue),
-		ClientRegion:         geoIPData.Country,
-		ServerTimestamp:      common.GetCurrentTimestamp(),
+		SSHSessionID:          sessionID,
+		Homepages:             db.GetRandomizedHomepages(sponsorID, geoIPData.Country, isMobile),
+		UpgradeClientVersion:  db.GetUpgradeClientVersion(clientVersion, normalizedPlatform),
+		PageViewRegexes:       make([]map[string]string, 0),
+		HttpsRequestRegexes:   db.GetHttpsRequestRegexes(sponsorID),
+		EncodedServerList:     db.DiscoverServers(geoIPData.DiscoveryValue),
+		ClientRegion:          geoIPData.Country,
+		ServerTimestamp:       common.GetCurrentTimestamp(),
+		AuthorizedAccessTypes: authorizedAccessTypes,
 	}
 
 	responsePayload, err := json.Marshal(handshakeResponse)
@@ -252,6 +267,7 @@ var connectedRequestParams = append(
 func connectedAPIRequestHandler(
 	support *SupportServices,
 	geoIPData GeoIPData,
+	authorizedAccessTypes []string,
 	params requestJSONObject) ([]byte, error) {
 
 	err := validateRequestParams(support, params, connectedRequestParams)
@@ -261,9 +277,9 @@ func connectedAPIRequestHandler(
 
 	log.LogRawFieldsWithTimestamp(
 		getRequestLogFields(
-			support,
 			"connected",
 			geoIPData,
+			authorizedAccessTypes,
 			params,
 			connectedRequestParams))
 
@@ -294,6 +310,7 @@ var statusRequestParams = append(
 func statusAPIRequestHandler(
 	support *SupportServices,
 	geoIPData GeoIPData,
+	authorizedAccessTypes []string,
 	params requestJSONObject) ([]byte, error) {
 
 	err := validateRequestParams(support, params, statusRequestParams)
@@ -319,8 +336,14 @@ func statusAPIRequestHandler(
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
+
 	bytesTransferredFields := getRequestLogFields(
-		support, "bytes_transferred", geoIPData, params, statusRequestParams)
+		"bytes_transferred",
+		geoIPData,
+		authorizedAccessTypes,
+		params,
+		statusRequestParams)
+
 	bytesTransferredFields["bytes"] = bytesTransferred
 	logQueue = append(logQueue, bytesTransferredFields)
 
@@ -336,7 +359,11 @@ func statusAPIRequestHandler(
 		for domain, bytes := range hostBytes {
 
 			domainBytesFields := getRequestLogFields(
-				support, "domain_bytes", geoIPData, params, statusRequestParams)
+				"domain_bytes",
+				geoIPData,
+				authorizedAccessTypes,
+				params,
+				statusRequestParams)
 
 			domainBytesFields["domain"] = domain
 			domainBytesFields["bytes"] = bytes
@@ -357,7 +384,11 @@ func statusAPIRequestHandler(
 		for _, tunnelStat := range tunnelStats {
 
 			sessionFields := getRequestLogFields(
-				support, "session", geoIPData, params, statusRequestParams)
+				"session",
+				geoIPData,
+				authorizedAccessTypes,
+				params,
+				statusRequestParams)
 
 			sessionID, err := getStringRequestParam(tunnelStat, "session_id")
 			if err != nil {
@@ -437,7 +468,11 @@ func statusAPIRequestHandler(
 		for _, remoteServerListStat := range remoteServerListStats {
 
 			remoteServerListFields := getRequestLogFields(
-				support, "remote_server_list", geoIPData, params, statusRequestParams)
+				"remote_server_list",
+				geoIPData,
+				authorizedAccessTypes,
+				params,
+				statusRequestParams)
 
 			clientDownloadTimestamp, err := getStringRequestParam(remoteServerListStat, "client_download_timestamp")
 			if err != nil {
@@ -475,6 +510,7 @@ func statusAPIRequestHandler(
 func clientVerificationAPIRequestHandler(
 	support *SupportServices,
 	geoIPData GeoIPData,
+	authorizedAccessTypes []string,
 	params requestJSONObject) ([]byte, error) {
 
 	err := validateRequestParams(support, params, baseRequestParams)
@@ -510,9 +546,9 @@ func clientVerificationAPIRequestHandler(
 		}
 
 		logFields := getRequestLogFields(
-			support,
 			"client_verification",
 			geoIPData,
+			authorizedAccessTypes,
 			params,
 			baseRequestParams)
 
@@ -533,6 +569,91 @@ func clientVerificationAPIRequestHandler(
 	}
 }
 
+type requestJSONObject map[string]interface{}
+
+// requestJSONUnmarshal is equivilent to:
+//
+//   var params requestJSONObject
+//   json.Unmarshal(jsonPayload, &params)
+//
+// ...with the one exception that when the field name is
+// protocol.PSIPHON_API_HANDSHAKE_AUTHORIZATIONS, the value is
+// not fully unmarshaled but instead treated as []json.RawMessage.
+// This leaves the authentications in PSIPHON_API_HANDSHAKE_AUTHORIZATIONS
+// as raw JSON to be unmarshaled in accesscontrol.VerifyAuthorization.
+func requestJSONUnmarshal(jsonPayload []byte) (requestJSONObject, error) {
+
+	expectJSONDelimiter := func(jsonDecoder *json.Decoder, delimiter string) error {
+
+		token, err := jsonDecoder.Token()
+		if err != nil {
+			return err
+		}
+
+		delim, ok := token.(json.Delim)
+		if !ok {
+			return fmt.Errorf("unexpected token type: %T", token)
+		}
+
+		if delim.String() != delimiter {
+			return fmt.Errorf("unexpected delimiter: %s", delim.String())
+		}
+
+		return nil
+	}
+
+	params := make(requestJSONObject)
+
+	jsonDecoder := json.NewDecoder(bytes.NewReader(jsonPayload))
+
+	err := expectJSONDelimiter(jsonDecoder, "{")
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	for jsonDecoder.More() {
+
+		token, err := jsonDecoder.Token()
+		if err != nil {
+			return nil, common.ContextError(err)
+		}
+
+		name, ok := token.(string)
+		if !ok {
+			return nil, common.ContextError(
+				fmt.Errorf("unexpected token type: %T", token))
+		}
+
+		var value interface{}
+
+		if name == protocol.PSIPHON_API_HANDSHAKE_AUTHORIZATIONS {
+
+			var rawJSONArray []json.RawMessage
+			err = jsonDecoder.Decode(&rawJSONArray)
+			if err != nil {
+				return nil, common.ContextError(err)
+			}
+			value = rawJSONArray
+
+		} else {
+
+			err = jsonDecoder.Decode(&value)
+			if err != nil {
+				return nil, common.ContextError(err)
+			}
+		}
+
+		params[name] = value
+	}
+
+	err = expectJSONDelimiter(jsonDecoder, "}")
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	return params, nil
+}
+
 type requestParamSpec struct {
 	name      string
 	validator func(*SupportServices, string) bool
@@ -663,9 +784,9 @@ func validateStringArrayRequestParam(
 // getRequestLogFields makes LogFields to log the API event following
 // the legacy psi_web and current ELK naming conventions.
 func getRequestLogFields(
-	support *SupportServices,
 	eventName string,
 	geoIPData GeoIPData,
+	authorizedAccessTypes []string,
 	params requestJSONObject,
 	expectedParams []requestParamSpec) LogFields {
 
@@ -680,6 +801,10 @@ func getRequestLogFields(
 	logFields["client_city"] = strings.Replace(geoIPData.City, " ", "_", -1)
 	logFields["client_isp"] = strings.Replace(geoIPData.ISP, " ", "_", -1)
 
+	if len(authorizedAccessTypes) > 0 {
+		logFields["authorized_access_types"] = authorizedAccessTypes
+	}
+
 	if params == nil {
 		return logFields
 	}
@@ -718,7 +843,7 @@ func getRequestLogFields(
 				logFields[expectedParam.name] = intValue
 			case "meek_dial_address":
 				host, _, _ := net.SplitHostPort(strValue)
-				if isIPAddress(support, host) {
+				if isIPAddress(nil, host) {
 					logFields["meek_dial_ip_address"] = host
 				} else {
 					logFields["meek_dial_domain"] = host
@@ -825,6 +950,17 @@ func getMapStringInt64RequestParam(params requestJSONObject, name string) (map[s
 	return result, nil
 }
 
+func getRawJSONArrayRequestParam(params requestJSONObject, name string) ([]json.RawMessage, error) {
+	if params[name] == nil {
+		return nil, common.ContextError(fmt.Errorf("missing param: %s", name))
+	}
+	value, ok := params[name].([]json.RawMessage)
+	if !ok {
+		return nil, common.ContextError(fmt.Errorf("invalid param: %s", name))
+	}
+	return value, nil
+}
+
 // Normalize reported client platform. Android clients, for example, report
 // OS version, rooted status, and Google Play build status in the clientPlatform
 // string along with "Android".
@@ -904,16 +1040,16 @@ func isRegionCode(_ *SupportServices, value string) bool {
 	})
 }
 
-func isDialAddress(support *SupportServices, value string) bool {
+func isDialAddress(_ *SupportServices, value string) bool {
 	// "<host>:<port>", where <host> is a domain or IP address
 	parts := strings.Split(value, ":")
 	if len(parts) != 2 {
 		return false
 	}
-	if !isIPAddress(support, parts[0]) && !isDomain(support, parts[0]) {
+	if !isIPAddress(nil, parts[0]) && !isDomain(nil, parts[0]) {
 		return false
 	}
-	if !isDigits(support, parts[1]) {
+	if !isDigits(nil, parts[1]) {
 		return false
 	}
 	port, err := strconv.Atoi(parts[1])
@@ -956,12 +1092,12 @@ func isDomain(_ *SupportServices, value string) bool {
 	return true
 }
 
-func isHostHeader(support *SupportServices, value string) bool {
+func isHostHeader(_ *SupportServices, value string) bool {
 	// "<host>:<port>", where <host> is a domain or IP address and ":<port>" is optional
 	if strings.Contains(value, ":") {
-		return isDialAddress(support, value)
+		return isDialAddress(nil, value)
 	}
-	return isIPAddress(support, value) || isDomain(support, value)
+	return isIPAddress(nil, value) || isDomain(nil, value)
 }
 
 func isServerEntrySource(_ *SupportServices, value string) bool {
@@ -975,6 +1111,6 @@ func isISO8601Date(_ *SupportServices, value string) bool {
 	return isISO8601DateRegex.Match([]byte(value))
 }
 
-func isLastConnected(support *SupportServices, value string) bool {
-	return value == "None" || value == "Unknown" || isISO8601Date(support, value)
+func isLastConnected(_ *SupportServices, value string) bool {
+	return value == "None" || value == "Unknown" || isISO8601Date(nil, value)
 }

+ 14 - 0
psiphon/server/config.go

@@ -33,6 +33,7 @@ import (
 	"strings"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/accesscontrol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/nacl/box"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
@@ -289,6 +290,13 @@ type Config struct {
 	// every specified number of seconds, to force garbage collection.
 	// The default, 0 is off.
 	PeriodicGarbageCollectionSeconds int
+
+	// AccessControlVerificationKeyRing is the access control authorization
+	// verification key ring used to verify signed authorizations presented
+	// by clients. Verified, active (unexpired) access control types will be
+	// available for matching in the TrafficRulesFilter for the client via
+	// AuthorizedAccessTypes. All other authorizations are ignored.
+	AccessControlVerificationKeyRing accesscontrol.VerificationKeyRing
 }
 
 // RunWebServer indicates whether to run a web server component.
@@ -386,6 +394,12 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 		}
 	}
 
+	err = accesscontrol.ValidateKeyRing(&config.AccessControlVerificationKeyRing)
+	if err != nil {
+		return nil, fmt.Errorf(
+			"AccessControlVerificationKeyRing is invalid: %s", err)
+	}
+
 	return &config, nil
 }
 

+ 127 - 10
psiphon/server/server_test.go

@@ -39,6 +39,7 @@ import (
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/accesscontrol"
 	"golang.org/x/net/proxy"
 )
 
@@ -124,6 +125,8 @@ func TestSSH(t *testing.T) {
 			doHotReload:          false,
 			doDefaultSessionID:   false,
 			denyTrafficRules:     false,
+			requireAuthorization: true,
+			omitAuthorization:    false,
 			doClientVerification: true,
 			doTunneledWebRequest: true,
 			doTunneledNTPRequest: true,
@@ -138,6 +141,8 @@ func TestOSSH(t *testing.T) {
 			doHotReload:          false,
 			doDefaultSessionID:   false,
 			denyTrafficRules:     false,
+			requireAuthorization: true,
+			omitAuthorization:    false,
 			doClientVerification: false,
 			doTunneledWebRequest: true,
 			doTunneledNTPRequest: true,
@@ -152,6 +157,8 @@ func TestUnfrontedMeek(t *testing.T) {
 			doHotReload:          false,
 			doDefaultSessionID:   false,
 			denyTrafficRules:     false,
+			requireAuthorization: true,
+			omitAuthorization:    false,
 			doClientVerification: false,
 			doTunneledWebRequest: true,
 			doTunneledNTPRequest: true,
@@ -166,6 +173,8 @@ func TestUnfrontedMeekHTTPS(t *testing.T) {
 			doHotReload:          false,
 			doDefaultSessionID:   false,
 			denyTrafficRules:     false,
+			requireAuthorization: true,
+			omitAuthorization:    false,
 			doClientVerification: false,
 			doTunneledWebRequest: true,
 			doTunneledNTPRequest: true,
@@ -180,6 +189,8 @@ func TestUnfrontedMeekSessionTicket(t *testing.T) {
 			doHotReload:          false,
 			doDefaultSessionID:   false,
 			denyTrafficRules:     false,
+			requireAuthorization: true,
+			omitAuthorization:    false,
 			doClientVerification: false,
 			doTunneledWebRequest: true,
 			doTunneledNTPRequest: true,
@@ -194,6 +205,8 @@ func TestWebTransportAPIRequests(t *testing.T) {
 			doHotReload:          false,
 			doDefaultSessionID:   false,
 			denyTrafficRules:     false,
+			requireAuthorization: false,
+			omitAuthorization:    true,
 			doClientVerification: true,
 			doTunneledWebRequest: true,
 			doTunneledNTPRequest: true,
@@ -208,6 +221,8 @@ func TestHotReload(t *testing.T) {
 			doHotReload:          true,
 			doDefaultSessionID:   false,
 			denyTrafficRules:     false,
+			requireAuthorization: true,
+			omitAuthorization:    false,
 			doClientVerification: false,
 			doTunneledWebRequest: true,
 			doTunneledNTPRequest: true,
@@ -222,6 +237,8 @@ func TestDefaultSessionID(t *testing.T) {
 			doHotReload:          true,
 			doDefaultSessionID:   true,
 			denyTrafficRules:     false,
+			requireAuthorization: true,
+			omitAuthorization:    false,
 			doClientVerification: false,
 			doTunneledWebRequest: true,
 			doTunneledNTPRequest: true,
@@ -236,6 +253,56 @@ func TestDenyTrafficRules(t *testing.T) {
 			doHotReload:          true,
 			doDefaultSessionID:   false,
 			denyTrafficRules:     true,
+			requireAuthorization: true,
+			omitAuthorization:    false,
+			doClientVerification: false,
+			doTunneledWebRequest: true,
+			doTunneledNTPRequest: true,
+		})
+}
+
+func TestOmitAuthorization(t *testing.T) {
+	runServer(t,
+		&runServerConfig{
+			tunnelProtocol:       "OSSH",
+			enableSSHAPIRequests: true,
+			doHotReload:          true,
+			doDefaultSessionID:   false,
+			denyTrafficRules:     false,
+			requireAuthorization: true,
+			omitAuthorization:    true,
+			doClientVerification: false,
+			doTunneledWebRequest: true,
+			doTunneledNTPRequest: true,
+		})
+}
+
+func TestNoAuthorization(t *testing.T) {
+	runServer(t,
+		&runServerConfig{
+			tunnelProtocol:       "OSSH",
+			enableSSHAPIRequests: true,
+			doHotReload:          true,
+			doDefaultSessionID:   false,
+			denyTrafficRules:     false,
+			requireAuthorization: false,
+			omitAuthorization:    true,
+			doClientVerification: false,
+			doTunneledWebRequest: true,
+			doTunneledNTPRequest: true,
+		})
+}
+
+func TestUnusedAuthorization(t *testing.T) {
+	runServer(t,
+		&runServerConfig{
+			tunnelProtocol:       "OSSH",
+			enableSSHAPIRequests: true,
+			doHotReload:          true,
+			doDefaultSessionID:   false,
+			denyTrafficRules:     false,
+			requireAuthorization: false,
+			omitAuthorization:    false,
 			doClientVerification: false,
 			doTunneledWebRequest: true,
 			doTunneledNTPRequest: true,
@@ -250,6 +317,8 @@ func TestTCPOnlySLOK(t *testing.T) {
 			doHotReload:          false,
 			doDefaultSessionID:   false,
 			denyTrafficRules:     false,
+			requireAuthorization: true,
+			omitAuthorization:    false,
 			doClientVerification: false,
 			doTunneledWebRequest: true,
 			doTunneledNTPRequest: false,
@@ -264,6 +333,8 @@ func TestUDPOnlySLOK(t *testing.T) {
 			doHotReload:          false,
 			doDefaultSessionID:   false,
 			denyTrafficRules:     false,
+			requireAuthorization: true,
+			omitAuthorization:    false,
 			doClientVerification: false,
 			doTunneledWebRequest: false,
 			doTunneledNTPRequest: true,
@@ -276,6 +347,8 @@ type runServerConfig struct {
 	doHotReload          bool
 	doDefaultSessionID   bool
 	denyTrafficRules     bool
+	requireAuthorization bool
+	omitAuthorization    bool
 	doClientVerification bool
 	doTunneledWebRequest bool
 	doTunneledNTPRequest bool
@@ -304,6 +377,29 @@ const dummyClientVerificationPayload = `
 
 func runServer(t *testing.T, runConfig *runServerConfig) {
 
+	// configure authorized access
+
+	accessType := "test-access-type"
+
+	accessControlSigningKey, accessControlVerificationKey, err := accesscontrol.NewKeyPair(accessType)
+	if err != nil {
+		t.Fatalf("error creating access control key pair: %s", err)
+	}
+
+	accessControlVerificationKeyRing := accesscontrol.VerificationKeyRing{
+		Keys: []*accesscontrol.VerificationKey{accessControlVerificationKey},
+	}
+
+	var authorizationID [32]byte
+
+	clientAuthorization, err := accesscontrol.IssueAuthorization(
+		accessControlSigningKey,
+		authorizationID[:],
+		time.Now().Add(1*time.Hour))
+	if err != nil {
+		t.Fatalf("error issuing authorization: %s", err)
+	}
+
 	// create a server
 
 	serverConfigJSON, _, encodedServerEntry, err := GenerateConfig(
@@ -332,7 +428,9 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	// must handshake with specified sponsor ID in order to allow ports for tunneled
 	// requests.
 	trafficRulesFilename := filepath.Join(testDataDirName, "traffic_rules.json")
-	paveTrafficRulesFile(t, trafficRulesFilename, propagationChannelID, runConfig.denyTrafficRules)
+	paveTrafficRulesFile(
+		t, trafficRulesFilename, propagationChannelID, accessType,
+		runConfig.requireAuthorization, runConfig.denyTrafficRules)
 
 	var serverConfig map[string]interface{}
 	json.Unmarshal(serverConfigJSON, &serverConfig)
@@ -343,6 +441,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	serverConfig["LogFilename"] = filepath.Join(testDataDirName, "psiphond.log")
 	serverConfig["LogLevel"] = "debug"
 
+	serverConfig["AccessControlVerificationKeyRing"] = accessControlVerificationKeyRing
+
 	// Set this parameter so at least the semaphore functions are called.
 	// TODO: test that the concurrency limit is correctly enforced.
 	serverConfig["MaxConcurrentSSHHandshakes"] = 1
@@ -400,7 +500,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		propagationChannelID = paveOSLConfigFile(t, oslConfigFilename)
 
 		paveTrafficRulesFile(
-			t, trafficRulesFilename, propagationChannelID, runConfig.denyTrafficRules)
+			t, trafficRulesFilename, propagationChannelID, accessType,
+			runConfig.requireAuthorization, runConfig.denyTrafficRules)
 
 		p, _ := os.FindProcess(os.Getpid())
 		p.Signal(syscall.SIGUSR1)
@@ -450,6 +551,10 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	clientConfig.LocalHttpProxyPort = localHTTPProxyPort
 	clientConfig.EmitSLOKs = true
 
+	if !runConfig.omitAuthorization {
+		clientConfig.Authorizations = []json.RawMessage{json.RawMessage(clientAuthorization)}
+	}
+
 	if runConfig.doClientVerification {
 		clientConfig.ClientPlatform = "Android"
 	}
@@ -554,6 +659,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		waitOnNotification(t, verificationCompleted, timeoutSignal, "verification completed timeout exceeded")
 	}
 
+	expectTrafficFailure := runConfig.denyTrafficRules || (runConfig.omitAuthorization && runConfig.requireAuthorization)
+
 	if runConfig.doTunneledWebRequest {
 
 		// Test: tunneled web site fetch
@@ -562,11 +669,11 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			t, localHTTPProxyPort, mockWebServerURL, mockWebServerExpectedResponse)
 
 		if err == nil {
-			if runConfig.denyTrafficRules {
+			if expectTrafficFailure {
 				t.Fatalf("unexpected tunneled web request success")
 			}
 		} else {
-			if !runConfig.denyTrafficRules {
+			if !expectTrafficFailure {
 				t.Fatalf("tunneled web request failed: %s", err)
 			}
 		}
@@ -581,11 +688,11 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		err = makeTunneledNTPRequest(t, localSOCKSProxyPort, udpgwServerAddress)
 
 		if err == nil {
-			if runConfig.denyTrafficRules {
+			if expectTrafficFailure {
 				t.Fatalf("unexpected tunneled NTP request success")
 			}
 		} else {
-			if !runConfig.denyTrafficRules {
+			if !expectTrafficFailure {
 				t.Fatalf("tunneled NTP request failed: %s", err)
 			}
 		}
@@ -593,7 +700,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 	// Test: await SLOK payload
 
-	if !runConfig.denyTrafficRules {
+	if !expectTrafficFailure {
 
 		time.Sleep(1 * time.Second)
 		waitOnNotification(t, slokSeeded, timeoutSignal, "SLOK seeded timeout exceeded")
@@ -887,7 +994,8 @@ func pavePsinetDatabaseFile(
 }
 
 func paveTrafficRulesFile(
-	t *testing.T, trafficRulesFilename, propagationChannelID string, deny bool) {
+	t *testing.T, trafficRulesFilename, propagationChannelID, accessType string,
+	requireAuthorization, deny bool) {
 
 	allowTCPPorts := fmt.Sprintf("%d", mockWebServerPort)
 	allowUDPPorts := "53, 123"
@@ -897,6 +1005,15 @@ func paveTrafficRulesFile(
 		allowUDPPorts = "0"
 	}
 
+	authorizationFilterFormat := `,
+                    "AuthorizedAccessTypes" : ["%s"]
+	`
+
+	authorizationFilter := ""
+	if requireAuthorization {
+		authorizationFilter = fmt.Sprintf(authorizationFilterFormat, accessType)
+	}
+
 	trafficRulesJSONFormat := `
     {
         "DefaultRules" :  {
@@ -912,7 +1029,7 @@ func paveTrafficRulesFile(
                 "Filter" : {
                     "HandshakeParameters" : {
                         "propagation_channel_id" : ["%s"]
-                    }
+                    }%s
                 },
                 "Rules" : {
                     "RateLimits" : {
@@ -928,7 +1045,7 @@ func paveTrafficRulesFile(
     `
 
 	trafficRulesJSON := fmt.Sprintf(
-		trafficRulesJSONFormat, propagationChannelID, allowTCPPorts, allowUDPPorts)
+		trafficRulesJSONFormat, propagationChannelID, authorizationFilter, allowTCPPorts, allowUDPPorts)
 
 	err := ioutil.WriteFile(trafficRulesFilename, []byte(trafficRulesJSON), 0600)
 	if err != nil {

+ 15 - 0
psiphon/server/trafficRules.go

@@ -83,6 +83,11 @@ type TrafficRulesFilter struct {
 	// a list of values, one of which must be specified to match this
 	// filter. Only scalar string API parameters may be filtered.
 	HandshakeParameters map[string][]string
+
+	// AuthorizedAccessTypes specifies a list of access types, at least
+	// one of which the client must have presented an active authorization
+	// for.
+	AuthorizedAccessTypes []string
 }
 
 // TrafficRules specify the limits placed on client traffic.
@@ -390,6 +395,16 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			}
 		}
 
+		if len(filteredRules.Filter.AuthorizedAccessTypes) > 0 {
+			if !state.completed {
+				continue
+			}
+
+			if !common.ContainsAny(filteredRules.Filter.AuthorizedAccessTypes, state.authorizedAccessTypes) {
+				continue
+			}
+		}
+
 		log.WithContextFields(LogFields{"filter": filteredRules.Filter}).Debug("filter match")
 
 		// This is the first match. Override defaults using provided fields from selected rules, and return result.

+ 158 - 14
psiphon/server/tunnelServer.go

@@ -22,6 +22,7 @@ package server
 import (
 	"context"
 	"crypto/subtle"
+	"encoding/hex"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -35,6 +36,7 @@ import (
 
 	"github.com/Psiphon-Inc/goarista/monotime"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/accesscontrol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
@@ -53,6 +55,7 @@ const (
 	SSH_SEND_OSL_INITIAL_RETRY_DELAY      = 30 * time.Second
 	SSH_SEND_OSL_RETRY_FACTOR             = 2
 	OSL_SESSION_CACHE_TTL                 = 5 * time.Minute
+	MAX_AUTHORIZATIONS                    = 16
 )
 
 // TunnelServer is the main server that accepts Psiphon client
@@ -222,10 +225,16 @@ func (server *TunnelServer) ResetAllClientOSLConfigs() {
 // also triggers an immediate traffic rule re-selection, as the rules selected
 // upon tunnel establishment may no longer apply now that handshake values are
 // set.
+//
+// The authorizations received from the client handshake are verified and the
+// resulting list of authorized access types are applied to the client's tunnel
+// and traffic rules. A list of authorized access types is returned.
 func (server *TunnelServer) SetClientHandshakeState(
-	sessionID string, state handshakeState) error {
+	sessionID string,
+	state handshakeState,
+	authorizations [][]byte) ([]string, error) {
 
-	return server.sshServer.setClientHandshakeState(sessionID, state)
+	return server.sshServer.setClientHandshakeState(sessionID, state, authorizations)
 }
 
 // GetClientHandshaked indicates whether the client has completed a handshake
@@ -264,6 +273,7 @@ type sshServer struct {
 	clients                 map[string]*sshClient
 	oslSessionCacheMutex    sync.Mutex
 	oslSessionCache         *cache.Cache
+	activeAuthorizationIDs  sync.Map
 }
 
 func newSSHServer(
@@ -488,6 +498,13 @@ func (sshServer *sshServer) registerEstablishedClient(client *sshClient) bool {
 	// Call stop() outside the mutex to avoid deadlock.
 	if existingClient != nil {
 		existingClient.stop()
+
+		// Since existingClient.run() isn't guaranteed to have terminated at
+		// this point, synchronously release authorizations for the previous
+		// client here. This ensures that the authorization IDs are not in
+		// use when the reconnecting client submits its authorizations.
+		existingClient.cleanupAuthorizations()
+
 		log.WithContext().Debug(
 			"stopped existing client with duplicate session ID")
 	}
@@ -663,22 +680,24 @@ func (sshServer *sshServer) resetAllClientOSLConfigs() {
 }
 
 func (sshServer *sshServer) setClientHandshakeState(
-	sessionID string, state handshakeState) error {
+	sessionID string,
+	state handshakeState,
+	authorizations [][]byte) ([]string, error) {
 
 	sshServer.clientsMutex.Lock()
 	client := sshServer.clients[sessionID]
 	sshServer.clientsMutex.Unlock()
 
 	if client == nil {
-		return common.ContextError(errors.New("unknown session ID"))
+		return nil, common.ContextError(errors.New("unknown session ID"))
 	}
 
-	err := client.setHandshakeState(state)
+	authorizedAccessTypes, err := client.setHandshakeState(state, authorizations)
 	if err != nil {
-		return common.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
-	return nil
+	return authorizedAccessTypes, nil
 }
 
 func (sshServer *sshServer) getClientHandshaked(
@@ -821,6 +840,8 @@ type sshClient struct {
 	runCtx                               context.Context
 	stopRunning                          context.CancelFunc
 	tcpPortForwardDialingAvailableSignal context.CancelFunc
+	releaseAuthorizations                func()
+	stopTimer                            *time.Timer
 }
 
 type trafficState struct {
@@ -848,9 +869,10 @@ type qualityMetrics struct {
 }
 
 type handshakeState struct {
-	completed   bool
-	apiProtocol string
-	apiParams   requestJSONObject
+	completed             bool
+	apiProtocol           string
+	apiParams             requestJSONObject
+	authorizedAccessTypes []string
 }
 
 func newSshClient(
@@ -1187,10 +1209,20 @@ func (sshClient *sshClient) runTunnel(
 			if request.Type == "keepalive@openssh.com" {
 				// Keepalive requests have an empty response.
 			} else {
+
 				// All other requests are assumed to be API requests.
+
+				sshClient.Lock()
+				authorizedAccessTypes := sshClient.handshakeState.authorizedAccessTypes
+				sshClient.Unlock()
+
+				// Note: unlock before use is only safe as long as referenced sshClient data,
+				// such as slices in handshakeState, is read-only after initially set.
+
 				responsePayload, err = sshAPIRequestHandler(
 					sshClient.sshServer.support,
 					sshClient.geoIPData,
+					authorizedAccessTypes,
 					request.Type,
 					request.Payload)
 			}
@@ -1506,6 +1538,22 @@ func (sshClient *sshClient) runTunnel(
 	}
 
 	waitGroup.Wait()
+
+	sshClient.cleanupAuthorizations()
+}
+
+func (sshClient *sshClient) cleanupAuthorizations() {
+	sshClient.Lock()
+
+	if sshClient.releaseAuthorizations != nil {
+		sshClient.releaseAuthorizations()
+	}
+
+	if sshClient.stopTimer != nil {
+		sshClient.stopTimer.Stop()
+	}
+
+	sshClient.Unlock()
 }
 
 // setPacketTunnelChannel sets the single packet tunnel channel
@@ -1547,9 +1595,9 @@ func (sshClient *sshClient) logTunnel(additionalMetrics LogFields) {
 	sshClient.Lock()
 
 	logFields := getRequestLogFields(
-		sshClient.sshServer.support,
 		"server_tunnel",
 		sshClient.geoIPData,
+		sshClient.handshakeState.authorizedAccessTypes,
 		sshClient.handshakeState.apiParams,
 		baseRequestParams)
 
@@ -1579,6 +1627,9 @@ func (sshClient *sshClient) logTunnel(additionalMetrics LogFields) {
 
 	sshClient.Unlock()
 
+	// Note: unlock before use is only safe as long as referenced sshClient data,
+	// such as slices in handshakeState, is read-only after initially set.
+
 	log.LogRawFieldsWithTimestamp(logFields)
 }
 
@@ -1673,7 +1724,9 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason s
 // selection. Port forwards are disallowed until a handshake is complete. The
 // handshake parameters are included in the session summary log recorded in
 // sshClient.stop().
-func (sshClient *sshClient) setHandshakeState(state handshakeState) error {
+func (sshClient *sshClient) setHandshakeState(
+	state handshakeState,
+	authorizations [][]byte) ([]string, error) {
 
 	sshClient.Lock()
 	completed := sshClient.handshakeState.completed
@@ -1684,13 +1737,104 @@ func (sshClient *sshClient) setHandshakeState(state handshakeState) error {
 
 	// Client must only perform one handshake
 	if completed {
-		return common.ContextError(errors.New("handshake already completed"))
+		return nil, common.ContextError(errors.New("handshake already completed"))
+	}
+
+	// Verify the authorizations submitted by the client. Verified, active (non-expired)
+	// access types will be available for traffic rules filtering.
+	//
+	// When an authorization is active but expires while the client is connected, the
+	// client is disconnected to ensure the access is revoked. This is implemented by
+	// setting a timer to perform the disconnect at the expiry time of the soonest
+	// expiring authorization.
+	//
+	// sshServer.activeAuthorizationIDs tracks the unique IDs of active authorizations
+	// and is used to detect and prevent multiple malicious clients from reusing a
+	// single authorization (within the scope of this server).
+
+	var authorizationIDs []string
+	var authorizedAccessTypes []string
+	var stopTime time.Time
+
+	for i, authorization := range authorizations {
+
+		// This sanity check mitigates malicious clients causing excess CPU use.
+		if i >= MAX_AUTHORIZATIONS {
+			log.WithContext().Warning("too many authorizations")
+			break
+		}
+
+		verifiedAuthorization, err := accesscontrol.VerifyAuthorization(
+			&sshClient.sshServer.support.Config.AccessControlVerificationKeyRing,
+			authorization)
+
+		if err != nil {
+			log.WithContextFields(
+				LogFields{"error": err}).Warning("verify authorization failed")
+			continue
+		}
+
+		authorizationID := hex.EncodeToString(verifiedAuthorization.ID)
+
+		// A client may reconnect while the server still has an active sshClient for that
+		// client session. In this case, the previous sshClient is closed by the new
+		// client's call to sshServer.registerEstablishedClient.
+		// This is assumed to call sshClient.releaseAuthorizations which will remove
+		// the client's authorization IDs before this check is reached.
+
+		if _, exists := sshClient.sshServer.activeAuthorizationIDs.LoadOrStore(authorizationID, true); exists {
+			log.WithContextFields(
+				LogFields{"ID": verifiedAuthorization.ID}).Warning("duplicate active authorization")
+			continue
+		}
+
+		if common.Contains(authorizedAccessTypes, verifiedAuthorization.AccessType) {
+			log.WithContextFields(
+				LogFields{"accessType": verifiedAuthorization.AccessType}).Warning("duplicate authorization access type")
+			continue
+		}
+
+		authorizationIDs = append(authorizationIDs, authorizationID)
+		authorizedAccessTypes = append(authorizedAccessTypes, verifiedAuthorization.AccessType)
+
+		if stopTime.IsZero() || stopTime.After(verifiedAuthorization.Expires) {
+			stopTime = verifiedAuthorization.Expires
+		}
+	}
+
+	if len(authorizationIDs) > 0 {
+
+		sshClient.Lock()
+
+		// Make the authorizedAccessTypes available for traffic rules filtering.
+
+		sshClient.handshakeState.authorizedAccessTypes = authorizedAccessTypes
+
+		// On exit, sshClient.runTunnel will call releaseAuthorizations, which
+		// will release the authorization IDs so the client can reconnect and
+		// present the same authorizations again. sshClient.runTunnel will
+		// also cancel the stopTimer in case it has not yet fired.
+		// Note: termination of the stopTimer goroutine is not synchronized.
+
+		sshClient.releaseAuthorizations = func() {
+			for _, ID := range authorizationIDs {
+				sshClient.sshServer.activeAuthorizationIDs.Delete(ID)
+			}
+		}
+
+		sshClient.stopTimer = time.AfterFunc(
+			stopTime.Sub(time.Now()),
+			func() {
+				sshClient.stop()
+			})
+
+		sshClient.Unlock()
 	}
 
 	sshClient.setTrafficRules()
 	sshClient.setOSLConfig()
 
-	return nil
+	return authorizedAccessTypes, nil
 }
 
 // getHandshaked returns whether the client has completed a handshake API

+ 10 - 6
psiphon/server/webServer.go

@@ -48,14 +48,14 @@ type webServer struct {
 //
 // The HTTP request handlers are light wrappers around the base Psiphon
 // API request handlers from the SSH API transport. The SSH API transport
-// is preferred by new clients; however the web API transport is still
-// required for untunneled final status requests. The web API transport
-// may be retired once untunneled final status requests are made obsolete
-// (e.g., by server-side bytes transferred stats, by client-side local
-// storage of stats for retry, or some other future development).
+// is preferred by new clients. The web API transport provides support for
+// older clients.
 //
 // The API is compatible with all tunnel-core clients but not backwards
-// compatible with older clients.
+// compatible with all legacy clients.
+//
+// Note: new features, including authorizations, are not supported in the
+// web API transport.
 //
 func RunWebServer(
 	support *SupportServices,
@@ -237,6 +237,7 @@ func (webServer *webServer) handshakeHandler(w http.ResponseWriter, r *http.Requ
 			webServer.support,
 			protocol.PSIPHON_WEB_API_PROTOCOL,
 			webServer.lookupGeoIPData(params),
+			nil,
 			protocol.PSIPHON_API_HANDSHAKE_REQUEST_NAME,
 			params)
 	}
@@ -267,6 +268,7 @@ func (webServer *webServer) connectedHandler(w http.ResponseWriter, r *http.Requ
 			webServer.support,
 			protocol.PSIPHON_WEB_API_PROTOCOL,
 			webServer.lookupGeoIPData(params),
+			nil, // authorizedAccessTypes not logged in web API transport
 			protocol.PSIPHON_API_CONNECTED_REQUEST_NAME,
 			params)
 	}
@@ -291,6 +293,7 @@ func (webServer *webServer) statusHandler(w http.ResponseWriter, r *http.Request
 			webServer.support,
 			protocol.PSIPHON_WEB_API_PROTOCOL,
 			webServer.lookupGeoIPData(params),
+			nil, // authorizedAccessTypes not logged in web API transport
 			protocol.PSIPHON_API_STATUS_REQUEST_NAME,
 			params)
 	}
@@ -315,6 +318,7 @@ func (webServer *webServer) clientVerificationHandler(w http.ResponseWriter, r *
 			webServer.support,
 			protocol.PSIPHON_WEB_API_PROTOCOL,
 			webServer.lookupGeoIPData(params),
+			nil, // authorizedAccessTypes not logged in web API transport
 			protocol.PSIPHON_API_CLIENT_VERIFICATION_REQUEST_NAME,
 			params)
 	}

+ 4 - 14
psiphon/serverApi.go

@@ -121,23 +121,11 @@ func (serverContext *ServerContext) doHandshakeRequest(
 
 	params := serverContext.getBaseParams()
 
-	// *TODO*: this is obsolete?
-	/*
-		serverEntryIpAddresses, err := GetServerEntryIpAddresses()
-		if err != nil {
-			return common.ContextError(err)
-		}
-
-		// Submit a list of known servers -- this will be used for
-		// discovery statistics.
-		for _, ipAddress := range serverEntryIpAddresses {
-			params = append(params, requestParam{"known_server", ipAddress})
-		}
-	*/
-
 	var response []byte
 	if serverContext.psiphonHttpsClient == nil {
 
+		params[protocol.PSIPHON_API_HANDSHAKE_AUTHORIZATIONS] = serverContext.tunnel.config.Authorizations
+
 		request, err := makeSSHAPIRequestPayload(params)
 		if err != nil {
 			return common.ContextError(err)
@@ -241,6 +229,8 @@ func (serverContext *ServerContext) doHandshakeRequest(
 	serverContext.serverHandshakeTimestamp = handshakeResponse.ServerTimestamp
 	NoticeServerTimestamp(serverContext.serverHandshakeTimestamp)
 
+	NoticeAuthorizedAccessTypes(handshakeResponse.AuthorizedAccessTypes)
+
 	return nil
 }