Просмотр исходного кода

Enhanced SLOK delivery
- UpdateProgress sends signal when any seed spec target is achieved
- psiphond receives signal, issues SLOKs, and sends any newly
issued SLOKs to client
- payload is delivered with new server-to-client SSH request
instead of using "status" response
- SLOKs only sent to clients reporting support for new
server-to-client SSH request mechanism; this is done via
a new ClientCapabilities included in the SSH password
payload

Rod Hynes 9 лет назад
Родитель
Сommit
eefc457f24

+ 3 - 0
ConsoleClient/main.go

@@ -68,16 +68,19 @@ func main() {
 	// Handle required config file parameter
 
 	if configFilename == "" {
+		psiphon.SetEmitDiagnosticNotices(true)
 		psiphon.NoticeError("configuration file is required")
 		os.Exit(1)
 	}
 	configFileContents, err := ioutil.ReadFile(configFilename)
 	if err != nil {
+		psiphon.SetEmitDiagnosticNotices(true)
 		psiphon.NoticeError("error loading configuration file: %s", err)
 		os.Exit(1)
 	}
 	config, err := psiphon.LoadConfig(configFileContents)
 	if err != nil {
+		psiphon.SetEmitDiagnosticNotices(true)
 		psiphon.NoticeError("error processing configuration file: %s", err)
 		os.Exit(1)
 	}

+ 100 - 42
psiphon/common/osl/osl.go

@@ -140,8 +140,7 @@ type Scheme struct {
 	// The following fields are ephemeral state.
 
 	epoch                 time.Time
-	subnetLookups         map[*SeedSpec]common.SubnetLookup
-	subnetLookup          common.SubnetLookup
+	subnetLookups         []common.SubnetLookup
 	derivedSLOKCacheMutex sync.RWMutex
 	derivedSLOKCache      map[slokReference]*SLOK
 }
@@ -186,10 +185,12 @@ type KeySplit struct {
 type ClientSeedState struct {
 	scheme               *Scheme
 	propagationChannelID string
-	progressSLOKTime     int64
-	progress             map[*SeedSpec]*TrafficValues
+	signalIssueSLOKs     chan struct{}
+	progress             []*TrafficValues
 	mutex                sync.Mutex
+	progressSLOKTime     int64
 	issuedSLOKs          map[string]*SLOK
+	payloadSLOKs         []*SLOK
 }
 
 // ClientSeedPortForward map a client port forward, which is relaying
@@ -199,8 +200,8 @@ type ClientSeedState struct {
 // and duration count towards the progress of these SeedSpecs and
 // associated SLOKs.
 type ClientSeedPortForward struct {
-	state    *ClientSeedState
-	progress []*TrafficValues
+	state           *ClientSeedState
+	progressIndexes []int
 }
 
 // slokReference uniquely identifies a SLOK by specifying all the fields
@@ -286,14 +287,14 @@ func loadConfig(configJSON []byte) (*Config, error) {
 		previousEpoch = epoch
 
 		scheme.epoch = epoch
-		scheme.subnetLookups = make(map[*SeedSpec]common.SubnetLookup)
+		scheme.subnetLookups = make([]common.SubnetLookup, len(scheme.SeedSpecs))
 		scheme.derivedSLOKCache = make(map[slokReference]*SLOK)
 
 		if len(scheme.MasterKey) != KEY_LENGTH_BYTES {
 			return nil, common.ContextError(errors.New("invalid master key"))
 		}
 
-		for _, seedSpec := range scheme.SeedSpecs {
+		for index, seedSpec := range scheme.SeedSpecs {
 			if len(seedSpec.ID) != KEY_LENGTH_BYTES {
 				return nil, common.ContextError(errors.New("invalid seed spec ID"))
 			}
@@ -304,7 +305,7 @@ func loadConfig(configJSON []byte) (*Config, error) {
 				return nil, common.ContextError(fmt.Errorf("invalid upstream subnets: %s", err))
 			}
 
-			scheme.subnetLookups[seedSpec] = subnetLookup
+			scheme.subnetLookups[index] = subnetLookup
 		}
 
 		if !isValidShamirSplit(len(scheme.SeedSpecs), scheme.SeedSpecThreshold) {
@@ -328,8 +329,16 @@ func loadConfig(configJSON []byte) (*Config, error) {
 // NewClientSeedState creates a new client seed state to track
 // client progress towards seeding SLOKs. psiphond maintains one
 // ClientSeedState for each connected client.
+//
+// A signal is sent on signalIssueSLOKs when sufficient progress
+// has been made that a new SLOK *may* be issued. psiphond will
+// receive the signal and then call GetClientSeedPayload/IssueSLOKs
+// to issue SLOKs, generate payload, and send to the client. The
+// sender will not block sending to signalIssueSLOKs; the channel
+// should be appropriately buffered.
 func (config *Config) NewClientSeedState(
-	clientRegion, propagationChannelID string) *ClientSeedState {
+	clientRegion, propagationChannelID string,
+	signalIssueSLOKs chan struct{}) *ClientSeedState {
 
 	config.ReloadableFile.RLock()
 	defer config.ReloadableFile.RUnlock()
@@ -344,19 +353,21 @@ func (config *Config) NewClientSeedState(
 			(len(scheme.Regions) == 0 || common.Contains(scheme.Regions, clientRegion)) {
 
 			// Empty progress is initialized up front for all seed specs. Once
-			// created, the progress map structure is read-only (the map, not the
+			// created, the progress structure is read-only (the slice, not the
 			// TrafficValue fields); this permits lock-free operation.
-			progress := make(map[*SeedSpec]*TrafficValues)
-			for _, seedSpec := range scheme.SeedSpecs {
-				progress[seedSpec] = &TrafficValues{}
+			progress := make([]*TrafficValues, len(scheme.SeedSpecs))
+			for index := 0; index < len(scheme.SeedSpecs); index++ {
+				progress[index] = &TrafficValues{}
 			}
 
 			return &ClientSeedState{
 				scheme:               scheme,
 				propagationChannelID: propagationChannelID,
+				signalIssueSLOKs:     signalIssueSLOKs,
 				progressSLOKTime:     getSLOKTime(scheme.SeedPeriodNanoseconds),
 				progress:             progress,
 				issuedSLOKs:          make(map[string]*SLOK),
+				payloadSLOKs:         nil,
 			}
 		}
 	}
@@ -382,28 +393,39 @@ func (state *ClientSeedState) NewClientSeedPortForward(
 		return nil
 	}
 
-	var progress []*TrafficValues
+	var progressIndexes []int
 
 	// Determine which seed spec subnets contain upstreamIPAddress
 	// and point to the progress for each. When progress is reported,
 	// it is added directly to all of these TrafficValues instances.
+	// Assumes state.progress entries correspond 1-to-1 with
+	// state.scheme.subnetLookups.
 	// Note: this implementation assumes a small number of seed specs.
 	// For larger numbers, instead of N SubnetLookups, create a single
 	// SubnetLookup which returns, for a given IP address, all matching
 	// subnets and associated seed specs.
-	for seedSpec, subnetLookup := range state.scheme.subnetLookups {
+	for index, subnetLookup := range state.scheme.subnetLookups {
 		if subnetLookup.ContainsIPAddress(upstreamIPAddress) {
-			progress = append(progress, state.progress[seedSpec])
+			progressIndexes = append(progressIndexes, index)
 		}
 	}
 
-	if progress == nil {
+	if progressIndexes == nil {
 		return nil
 	}
 
 	return &ClientSeedPortForward{
-		state:    state,
-		progress: progress,
+		state:           state,
+		progressIndexes: progressIndexes,
+	}
+}
+
+func (state *ClientSeedState) sendIssueSLOKsSignal() {
+	if state.signalIssueSLOKs != nil {
+		select {
+		case state.signalIssueSLOKs <- *new(struct{}):
+		default:
+		}
 	}
 }
 
@@ -419,9 +441,9 @@ func (state *ClientSeedState) NewClientSeedPortForward(
 func (portForward *ClientSeedPortForward) UpdateProgress(
 	bytesRead, bytesWritten int64, durationNanoseconds int64) {
 
-	// Concurrency: access to ClientSeedState is unsynchronized to read-only
-	// fields or atomic, except in the case of a time period rollover, in which
-	// case a mutex is acquired.
+	// Concurrency: non-blocking -- access to ClientSeedState is unsynchronized
+	// to read-only fields, atomic, or channels, except in the case of a time
+	// period rollover, in which case a mutex is acquired.
 
 	slokTime := getSLOKTime(portForward.state.scheme.SeedPeriodNanoseconds)
 
@@ -436,6 +458,12 @@ func (portForward *ClientSeedPortForward) UpdateProgress(
 		portForward.state.mutex.Lock()
 		portForward.state.issueSLOKs()
 		portForward.state.mutex.Unlock()
+
+		// Call to issueSLOKs may have issued new SLOKs. Note that
+		// this will only happen if the time period rolls over with
+		// sufficient progress pending while the signalIssueSLOKs
+		// receiver did not call IssueSLOKs soon enough.
+		portForward.state.sendIssueSLOKsSignal()
 	}
 
 	// Add directly to the permanent TrafficValues progress accumulators
@@ -444,23 +472,40 @@ func (portForward *ClientSeedPortForward) UpdateProgress(
 	// goroutine may be invoking issueSLOKs, which zeros all the accumulators.
 	// As a consequence, progress may be dropped at the exact time of
 	// time period rollover.
-	for _, progress := range portForward.progress {
+	for _, progressIndex := range portForward.progressIndexes {
+
+		seedSpec := portForward.state.scheme.SeedSpecs[progressIndex]
+		progress := portForward.state.progress[progressIndex]
+
+		alreadyExceedsTargets := progress.exceeds(&seedSpec.Targets)
+
 		atomic.AddInt64(&progress.BytesRead, bytesRead)
 		atomic.AddInt64(&progress.BytesWritten, bytesWritten)
 		atomic.AddInt64(&progress.PortForwardDurationNanoseconds, durationNanoseconds)
+
+		// With the target newly met for a SeedSpec, a new
+		// SLOK *may* be issued.
+		if !alreadyExceedsTargets && progress.exceeds(&seedSpec.Targets) {
+			portForward.state.sendIssueSLOKsSignal()
+		}
 	}
 }
 
-// IssueSLOKs checks client progress against each candidate seed spec
+func (lhs *TrafficValues) exceeds(rhs *TrafficValues) bool {
+	return atomic.LoadInt64(&lhs.BytesRead) >= atomic.LoadInt64(&rhs.BytesRead) &&
+		atomic.LoadInt64(&lhs.BytesWritten) >= atomic.LoadInt64(&rhs.BytesWritten) &&
+		atomic.LoadInt64(&lhs.PortForwardDurationNanoseconds) >=
+			atomic.LoadInt64(&rhs.PortForwardDurationNanoseconds)
+}
+
+// issueSLOKs checks client progress against each candidate seed spec
 // and seeds SLOKs when the client traffic levels are achieved. After
 // checking progress, and if the SLOK time period has changed since
 // progress was last recorded, progress is reset. Partial, insufficient
 // progress is intentionally dropped when the time period rolls over.
 // Derived SLOKs are cached to avoid redundant CPU intensive operations.
 // All issued SLOKs are retained in the client state for the duration
-// of the client's session. As there is no mechanism for the client to
-// explicitly acknowledge recieved SLOKs, it is intended that SLOKs
-// will be resent to the client.
+// of the client's session.
 func (state *ClientSeedState) issueSLOKs() {
 
 	// Concurrency: the caller must lock state.mutex.
@@ -471,12 +516,11 @@ func (state *ClientSeedState) issueSLOKs() {
 
 	progressSLOKTime := time.Unix(0, state.progressSLOKTime)
 
-	for seedSpec, progress := range state.progress {
+	for index, progress := range state.progress {
+
+		seedSpec := state.scheme.SeedSpecs[index]
 
-		if atomic.LoadInt64(&progress.BytesRead) >= seedSpec.Targets.BytesRead &&
-			atomic.LoadInt64(&progress.BytesWritten) >= seedSpec.Targets.BytesWritten &&
-			atomic.LoadInt64(&progress.PortForwardDurationNanoseconds) >=
-				seedSpec.Targets.PortForwardDurationNanoseconds {
+		if progress.exceeds(&seedSpec.Targets) {
 
 			ref := &slokReference{
 				PropagationChannelID: state.propagationChannelID,
@@ -494,7 +538,12 @@ func (state *ClientSeedState) issueSLOKs() {
 				state.scheme.derivedSLOKCacheMutex.Unlock()
 			}
 
-			state.issuedSLOKs[string(slok.ID)] = slok
+			// Previously issued SLOKs are not re-added to
+			// the payload.
+			if state.issuedSLOKs[string(slok.ID)] == nil {
+				state.issuedSLOKs[string(slok.ID)] = slok
+				state.payloadSLOKs = append(state.payloadSLOKs, slok)
+			}
 		}
 	}
 
@@ -543,8 +592,8 @@ func deriveSLOK(
 }
 
 // GetSeedPayload issues any pending SLOKs and returns the accumulated
-// SLOKs for a given client. psiphond will periodically call this and
-// return the SLOKs in API request responses.
+// SLOKs for a given client. psiphond will calls this when it receives
+// signalIssueSLOKs which is the trigger to check for new SLOKs.
 // Note: caller must not modify the SLOKs in SeedPayload.SLOKs
 // as these are shared data.
 func (state *ClientSeedState) GetSeedPayload() *SeedPayload {
@@ -552,17 +601,15 @@ func (state *ClientSeedState) GetSeedPayload() *SeedPayload {
 	state.mutex.Lock()
 	defer state.mutex.Unlock()
 
-	state.issueSLOKs()
-
 	if state.scheme == nil {
 		return &SeedPayload{}
 	}
 
-	sloks := make([]*SLOK, len(state.issuedSLOKs))
-	index := 0
-	for _, slok := range state.issuedSLOKs {
+	state.issueSLOKs()
+
+	sloks := make([]*SLOK, len(state.payloadSLOKs))
+	for index, slok := range state.payloadSLOKs {
 		sloks[index] = slok
-		index++
 	}
 
 	return &SeedPayload{
@@ -570,6 +617,17 @@ func (state *ClientSeedState) GetSeedPayload() *SeedPayload {
 	}
 }
 
+// ClearSeedPayload resets the accumulated SLOK payload (but not SLOK
+// progress). psiphond calls this after the client has acknowledged
+// receipt of a payload.
+func (state *ClientSeedState) ClearSeedPayload() {
+
+	state.mutex.Lock()
+	defer state.mutex.Unlock()
+
+	state.payloadSLOKs = nil
+}
+
 // PaveFile describes an OSL data file to be paved to an out-of-band
 // distribution drop site. There are two types of files: a directory,
 // which describes how to assemble keys for OSLs, and the encrypted

+ 30 - 3
psiphon/common/osl/osl_test.go

@@ -169,7 +169,7 @@ func TestOSL(t *testing.T) {
 
 	t.Run("ineligible client, sufficient transfer", func(t *testing.T) {
 
-		clientSeedState := config.NewClientSeedState("US", "C5E8D2EDFD093B50D8D65CF59D0263CA")
+		clientSeedState := config.NewClientSeedState("US", "C5E8D2EDFD093B50D8D65CF59D0263CA", nil)
 
 		seedPortForward := clientSeedState.NewClientSeedPortForward(net.ParseIP("192.168.0.1"))
 
@@ -179,7 +179,8 @@ func TestOSL(t *testing.T) {
 	})
 
 	// This clientSeedState is used across multiple tests.
-	clientSeedState := config.NewClientSeedState("US", "2995DB0C968C59C4F23E87988D9C0D41")
+	signalIssueSLOKs := make(chan struct{}, 1)
+	clientSeedState := config.NewClientSeedState("US", "2995DB0C968C59C4F23E87988D9C0D41", signalIssueSLOKs)
 
 	t.Run("eligible client, no transfer", func(t *testing.T) {
 
@@ -224,6 +225,12 @@ func TestOSL(t *testing.T) {
 
 		clientSeedPortForward.UpdateProgress(5, 5, 5)
 
+		select {
+		case <-signalIssueSLOKs:
+		default:
+			t.Fatalf("expected issue SLOKs signal")
+		}
+
 		if len(clientSeedState.GetSeedPayload().SLOKs) != 1 {
 			t.Fatalf("expected 1 SLOKs, got %d", len(clientSeedState.GetSeedPayload().SLOKs))
 		}
@@ -237,6 +244,12 @@ func TestOSL(t *testing.T) {
 
 		clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1")).UpdateProgress(5, 5, 5)
 
+		select {
+		case <-signalIssueSLOKs:
+		default:
+			t.Fatalf("expected issue SLOKs signal")
+		}
+
 		// Expect 2 SLOKS: 1 new, and 1 remaining in payload.
 		if len(clientSeedState.GetSeedPayload().SLOKs) != 2 {
 			t.Fatalf("expected 2 SLOKs, got %d", len(clientSeedState.GetSeedPayload().SLOKs))
@@ -251,17 +264,31 @@ func TestOSL(t *testing.T) {
 
 		clientSeedState.NewClientSeedPortForward(net.ParseIP("10.0.0.1")).UpdateProgress(5, 5, 5)
 
+		select {
+		case <-signalIssueSLOKs:
+		default:
+			t.Fatalf("expected issue SLOKs signal")
+		}
+
 		// Expect 4 SLOKS: 2 new, and 2 remaining in payload.
 		if len(clientSeedState.GetSeedPayload().SLOKs) != 4 {
 			t.Fatalf("expected 4 SLOKs, got %d", len(clientSeedState.GetSeedPayload().SLOKs))
 		}
 	})
 
+	t.Run("clear payload", func(t *testing.T) {
+		clientSeedState.ClearSeedPayload()
+
+		if len(clientSeedState.GetSeedPayload().SLOKs) != 0 {
+			t.Fatalf("expected 0 SLOKs, got %d", len(clientSeedState.GetSeedPayload().SLOKs))
+		}
+	})
+
 	t.Run("no transfer required", func(t *testing.T) {
 
 		rolloverToNextSLOKTime()
 
-		clientSeedState := config.NewClientSeedState("US", "36F1CF2DF1250BF0C7BA0629CE3DC657")
+		clientSeedState := config.NewClientSeedState("US", "36F1CF2DF1250BF0C7BA0629CE3DC657", nil)
 
 		if len(clientSeedState.GetSeedPayload().SLOKs) != 1 {
 			t.Fatalf("expected 1 SLOKs, got %d", len(clientSeedState.GetSeedPayload().SLOKs))

+ 10 - 1
psiphon/common/protocol/protocol.go

@@ -39,10 +39,13 @@ const (
 	CAPABILITY_SSH_API_REQUESTS            = "ssh-api-requests"
 	CAPABILITY_UNTUNNELED_WEB_API_REQUESTS = "handshake"
 
+	CLIENT_CAPABILITY_SERVER_REQUESTS = "server-requests"
+
 	PSIPHON_API_HANDSHAKE_REQUEST_NAME           = "psiphon-handshake"
 	PSIPHON_API_CONNECTED_REQUEST_NAME           = "psiphon-connected"
 	PSIPHON_API_STATUS_REQUEST_NAME              = "psiphon-status"
 	PSIPHON_API_CLIENT_VERIFICATION_REQUEST_NAME = "psiphon-client-verification"
+	PSIPHON_API_OSL_REQUEST_NAME                 = "psiphon-osl"
 
 	PSIPHON_API_CLIENT_SESSION_ID_LENGTH = 16
 
@@ -98,6 +101,12 @@ type ConnectedResponse struct {
 	ConnectedTimestamp string `json:"connected_timestamp"`
 }
 
-type StatusResponse struct {
+type OSLRequest struct {
 	SeedPayload *osl.SeedPayload `json:"seed_payload"`
 }
+
+type SSHPasswordPayload struct {
+	SessionId          string   `json:"SessionId"`
+	SshPassword        string   `json:"SshPassword"`
+	ClientCapabilities []string `json:"ClientCapabilities"`
+}

+ 1 - 20
psiphon/server/api.go

@@ -369,30 +369,11 @@ func statusAPIRequestHandler(
 		}
 	}
 
-	// Note: ignoring param format errors as params have been validated
-	sessionID, _ := getStringRequestParam(params, "client_session_id")
-
-	// TODO: in the case of SSH API requests, the actual sshClient could
-	// be passed in and used directly.
-	seedPayload, err := support.TunnelServer.GetClientSeedPayload(sessionID)
-	if err != nil {
-		return nil, common.ContextError(err)
-	}
-
-	statusResponse := protocol.StatusResponse{
-		SeedPayload: seedPayload,
-	}
-
-	responsePayload, err := json.Marshal(statusResponse)
-	if err != nil {
-		return nil, common.ContextError(err)
-	}
-
 	for _, logItem := range logQueue {
 		log.LogRawFieldsWithTimestamp(logItem)
 	}
 
-	return responsePayload, nil
+	return make([]byte, 0), nil
 }
 
 // clientVerificationAPIRequestHandler implements the

+ 106 - 29
psiphon/server/tunnelServer.go

@@ -45,6 +45,8 @@ const (
 	SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT = 30 * time.Second
 	SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT      = 30 * time.Second
 	SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE  = 8192
+	SSH_SEND_OSL_INITIAL_RETRY_DELAY       = 30 * time.Second
+	SSH_SEND_OSL_RETRY_FACTOR              = 2
 )
 
 // TunnelServer is the main server that accepts Psiphon client
@@ -217,14 +219,6 @@ func (server *TunnelServer) SetClientHandshakeState(
 	return server.sshServer.setClientHandshakeState(sessionID, state)
 }
 
-// GetClientSeedPayload gets the current OSL seed payload for the specified
-// client session. Any seeded SLOKs are issued and included in the payload.
-func (server *TunnelServer) GetClientSeedPayload(
-	sessionID string) (*osl.SeedPayload, error) {
-
-	return server.sshServer.getClientSeedPayload(sessionID)
-}
-
 // SetEstablishTunnels sets whether new tunnels may be established or not.
 // When not establishing, incoming connections are immediately closed.
 func (server *TunnelServer) SetEstablishTunnels(establish bool) {
@@ -558,20 +552,6 @@ func (sshServer *sshServer) setClientHandshakeState(
 	return nil
 }
 
-func (sshServer *sshServer) getClientSeedPayload(
-	sessionID string) (*osl.SeedPayload, error) {
-
-	sshServer.clientsMutex.Lock()
-	client := sshServer.clients[sessionID]
-	sshServer.clientsMutex.Unlock()
-
-	if client == nil {
-		return nil, common.ContextError(errors.New("unknown session ID"))
-	}
-
-	return client.getClientSeedPayload(), nil
-}
-
 func (sshServer *sshServer) stopClients() {
 
 	sshServer.clientsMutex.Lock()
@@ -607,6 +587,7 @@ type sshClient struct {
 	throttledConn           *common.ThrottledConn
 	geoIPData               GeoIPData
 	sessionID               string
+	supportsServerRequests  bool
 	handshakeState          handshakeState
 	udpChannel              ssh.Channel
 	trafficRules            TrafficRules
@@ -616,6 +597,7 @@ type sshClient struct {
 	channelHandlerWaitGroup *sync.WaitGroup
 	tcpPortForwardLRU       *common.LRUConns
 	oslClientSeedState      *osl.ClientSeedState
+	signalIssueSLOKs        chan struct{}
 	stopBroadcast           chan struct{}
 }
 
@@ -653,6 +635,7 @@ func newSshClient(
 		geoIPData:               geoIPData,
 		channelHandlerWaitGroup: new(sync.WaitGroup),
 		tcpPortForwardLRU:       common.NewLRUConns(),
+		signalIssueSLOKs:        make(chan struct{}, 1),
 		stopBroadcast:           make(chan struct{}),
 	}
 }
@@ -784,10 +767,7 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 	expectedSessionIDLength := 2 * protocol.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
 	expectedSSHPasswordLength := 2 * SSH_PASSWORD_BYTE_LENGTH
 
-	var sshPasswordPayload struct {
-		SessionId   string `json:"SessionId"`
-		SshPassword string `json:"SshPassword"`
-	}
+	var sshPasswordPayload protocol.SSHPasswordPayload
 	err := json.Unmarshal(password, &sshPasswordPayload)
 	if err != nil {
 
@@ -820,8 +800,12 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 
 	sessionID := sshPasswordPayload.SessionId
 
+	supportsServerRequests := common.Contains(
+		sshPasswordPayload.ClientCapabilities, protocol.CLIENT_CAPABILITY_SERVER_REQUESTS)
+
 	sshClient.Lock()
 	sshClient.sessionID = sessionID
+	sshClient.supportsServerRequests = supportsServerRequests
 	geoIPData := sshClient.geoIPData
 	sshClient.Unlock()
 
@@ -920,7 +904,10 @@ func (sshClient *sshClient) stop() {
 func (sshClient *sshClient) runTunnel(
 	channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
 
+	stopBroadcast := make(chan struct{})
+
 	requestsWaitGroup := new(sync.WaitGroup)
+
 	requestsWaitGroup.Add(1)
 	go func() {
 		defer requestsWaitGroup.Done()
@@ -956,6 +943,14 @@ func (sshClient *sshClient) runTunnel(
 		}
 	}()
 
+	if sshClient.supportsServerRequests {
+		requestsWaitGroup.Add(1)
+		go func() {
+			defer requestsWaitGroup.Done()
+			sshClient.runOSLSender(stopBroadcast)
+		}()
+	}
+
 	for newChannel := range channels {
 
 		if newChannel.ChannelType() != "direct-tcpip" {
@@ -968,9 +963,83 @@ func (sshClient *sshClient) runTunnel(
 		go sshClient.handleNewPortForwardChannel(newChannel)
 	}
 
+	close(stopBroadcast)
+
 	requestsWaitGroup.Wait()
 }
 
+func (sshClient *sshClient) runOSLSender(stopBroadcast <-chan struct{}) {
+
+	for {
+		// Await a signal that there are SLOKs to send
+		// TODO: use reflect.SelectCase, and optionally await timer here?
+		select {
+		case <-sshClient.signalIssueSLOKs:
+		case <-stopBroadcast:
+			return
+		}
+
+		retryDelay := SSH_SEND_OSL_INITIAL_RETRY_DELAY
+		for {
+			err := sshClient.sendOSLRequest()
+			if err == nil {
+				break
+			}
+			log.WithContextFields(LogFields{"error": err}).Warning("sendOSLRequest failed")
+
+			// If the request failed, retry after a delay (with exponential backoff)
+			// or when signaled that there are additional SLOKs to send
+			retryTimer := time.NewTimer(retryDelay)
+			select {
+			case <-retryTimer.C:
+			case <-sshClient.signalIssueSLOKs:
+			case <-stopBroadcast:
+				retryTimer.Stop()
+				return
+			}
+			retryTimer.Stop()
+			retryDelay *= SSH_SEND_OSL_RETRY_FACTOR
+		}
+	}
+}
+
+// sendOSLRequest will invoke osl.GetSeedPayload to issue SLOKs and
+// generate a payload, and send an OSL request to the client when
+// there are new SLOKs in the payload.
+func (sshClient *sshClient) sendOSLRequest() error {
+
+	seedPayload := sshClient.getOSLSeedPayload()
+
+	// Don't send when no SLOKs. This will happen when signalIssueSLOKs
+	// is received but no new SLOKs are issued.
+	if len(seedPayload.SLOKs) == 0 {
+		return nil
+	}
+
+	oslRequest := protocol.OSLRequest{
+		SeedPayload: seedPayload,
+	}
+	requestPayload, err := json.Marshal(oslRequest)
+	if err != nil {
+		return common.ContextError(err)
+	}
+
+	ok, _, err := sshClient.sshConn.SendRequest(
+		protocol.PSIPHON_API_OSL_REQUEST_NAME,
+		true,
+		requestPayload)
+	if err != nil {
+		return common.ContextError(err)
+	}
+	if !ok {
+		return common.ContextError(errors.New("client rejected request"))
+	}
+
+	sshClient.clearOSLSeedPayload()
+
+	return nil
+}
+
 func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, logMessage string) {
 
 	// Note: Debug level, as logMessage may contain user traffic destination address information
@@ -1083,7 +1152,8 @@ func (sshClient *sshClient) setOSLConfig() {
 
 	sshClient.oslClientSeedState = sshClient.sshServer.support.OSLConfig.NewClientSeedState(
 		sshClient.geoIPData.Country,
-		propagationChannelID)
+		propagationChannelID,
+		sshClient.signalIssueSLOKs)
 }
 
 // newClientSeedPortForward will return nil when no seeding is
@@ -1100,9 +1170,9 @@ func (sshClient *sshClient) newClientSeedPortForward(ipAddress net.IP) *osl.Clie
 	return sshClient.oslClientSeedState.NewClientSeedPortForward(ipAddress)
 }
 
-// getClientSeedPayload returns a payload containing all seeded SLOKs for
+// getOSLSeedPayload returns a payload containing all seeded SLOKs for
 // this client's session.
-func (sshClient *sshClient) getClientSeedPayload() *osl.SeedPayload {
+func (sshClient *sshClient) getOSLSeedPayload() *osl.SeedPayload {
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
@@ -1114,6 +1184,13 @@ func (sshClient *sshClient) getClientSeedPayload() *osl.SeedPayload {
 	return sshClient.oslClientSeedState.GetSeedPayload()
 }
 
+func (sshClient *sshClient) clearOSLSeedPayload() {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	sshClient.oslClientSeedState.ClearSeedPayload()
+}
+
 func (sshClient *sshClient) rateLimits() common.RateLimits {
 	sshClient.Lock()
 	defer sshClient.Unlock()

+ 35 - 29
psiphon/serverApi.go

@@ -322,7 +322,6 @@ func (serverContext *ServerContext) DoStatusRequest(tunnel *Tunnel) error {
 		return common.ContextError(err)
 	}
 
-	var response []byte
 	if serverContext.psiphonHttpsClient == nil {
 
 		rawMessage := json.RawMessage(statusPayload)
@@ -332,14 +331,14 @@ func (serverContext *ServerContext) DoStatusRequest(tunnel *Tunnel) error {
 		request, err = makeSSHAPIRequestPayload(params)
 
 		if err == nil {
-			response, err = serverContext.tunnel.SendAPIRequest(
+			_, err = serverContext.tunnel.SendAPIRequest(
 				protocol.PSIPHON_API_STATUS_REQUEST_NAME, request)
 		}
 
 	} else {
 
 		// Legacy web service API request
-		response, err = serverContext.doPostRequest(
+		_, err = serverContext.doPostRequest(
 			makeRequestUrl(serverContext.tunnel, "", "status", params),
 			"application/json",
 			bytes.NewReader(statusPayload))
@@ -357,32 +356,6 @@ func (serverContext *ServerContext) DoStatusRequest(tunnel *Tunnel) error {
 
 	confirmStatusRequestPayload(statusPayloadInfo)
 
-	if len(response) > 0 {
-
-		var statusResponse protocol.StatusResponse
-		err = json.Unmarshal(response, &statusResponse)
-		if err != nil {
-			return common.ContextError(err)
-		}
-
-		for _, slok := range statusResponse.SeedPayload.SLOKs {
-			duplicate, err := SetSLOK(slok.ID, slok.Key)
-			if err != nil {
-
-				NoticeAlert("SetSLOK failed: %s", common.ContextError(err))
-
-				// Proceed with next SLOK. Also, no immediate retry.
-				// For an ongoing session, another status request will occur within
-				// PSIPHON_API_STATUS_REQUEST_PERIOD_MIN/MAX and the server will
-				// resend the same SLOKs, giving another opportunity to store.
-			}
-
-			if tunnel.config.ReportSLOKs {
-				NoticeSLOKSeeded(base64.StdEncoding.EncodeToString(slok.ID), duplicate)
-			}
-		}
-	}
-
 	return nil
 }
 
@@ -932,3 +905,36 @@ func makePsiphonHttpsClient(tunnel *Tunnel) (httpsClient *http.Client, err error
 		Timeout:   timeout,
 	}, nil
 }
+
+func HandleServerRequest(tunnel *Tunnel, name string, payload []byte) error {
+
+	switch name {
+	case protocol.PSIPHON_API_OSL_REQUEST_NAME:
+		return HandleOSLRequest(tunnel, payload)
+	}
+
+	return common.ContextError(fmt.Errorf("invalid request name: %s", name))
+}
+
+func HandleOSLRequest(tunnel *Tunnel, payload []byte) error {
+
+	var oslRequest protocol.OSLRequest
+	err := json.Unmarshal(payload, &oslRequest)
+	if err != nil {
+		return common.ContextError(err)
+	}
+
+	for _, slok := range oslRequest.SeedPayload.SLOKs {
+		duplicate, err := SetSLOK(slok.ID, slok.Key)
+		if err != nil {
+			// TODO: return error to trigger retry?
+			NoticeAlert("SetSLOK failed: %s", common.ContextError(err))
+		}
+
+		if tunnel.config.ReportSLOKs {
+			NoticeSLOKSeeded(base64.StdEncoding.EncodeToString(slok.ID), duplicate)
+		}
+	}
+
+	return nil
+}

+ 63 - 30
psiphon/tunnel.go

@@ -77,6 +77,7 @@ type Tunnel struct {
 	protocol                     string
 	conn                         *common.ActivityMonitoredConn
 	sshClient                    *ssh.Client
+	sshServerRequests            <-chan *ssh.Request
 	operateWaitGroup             *sync.WaitGroup
 	shutdownOperateBroadcast     chan struct{}
 	signalPortForwardFailure     chan struct{}
@@ -129,7 +130,7 @@ func EstablishTunnel(
 
 	// Build transport layers and establish SSH connection. Note that
 	// dialConn and monitoredConn are the same network connection.
-	dialConn, monitoredConn, sshClient, dialStats, err := dialSsh(
+	dialResult, err := dialSsh(
 		config, pendingConns, serverEntry, selectedProtocol, sessionId)
 	if err != nil {
 		return nil, common.ContextError(err)
@@ -138,9 +139,9 @@ func EstablishTunnel(
 	// Cleanup on error
 	defer func() {
 		if err != nil {
-			sshClient.Close()
-			monitoredConn.Close()
-			pendingConns.Remove(dialConn)
+			dialResult.sshClient.Close()
+			dialResult.monitoredConn.Close()
+			pendingConns.Remove(dialResult.dialConn)
 		}
 	}()
 
@@ -152,14 +153,15 @@ func EstablishTunnel(
 		isClosed:                 false,
 		serverEntry:              serverEntry,
 		protocol:                 selectedProtocol,
-		conn:                     monitoredConn,
-		sshClient:                sshClient,
+		conn:                     dialResult.monitoredConn,
+		sshClient:                dialResult.sshClient,
+		sshServerRequests:        dialResult.sshRequests,
 		operateWaitGroup:         new(sync.WaitGroup),
 		shutdownOperateBroadcast: make(chan struct{}),
 		// A buffer allows at least one signal to be sent even when the receiver is
 		// not listening. Senders should not block.
 		signalPortForwardFailure: make(chan struct{}, 1),
-		dialStats:                dialStats,
+		dialStats:                dialResult.dialStats,
 		// Buffer allows SetClientVerificationPayload to submit one new payload
 		// without blocking or dropping it.
 		newClientVerificationPayload: make(chan string, 1),
@@ -191,7 +193,7 @@ func EstablishTunnel(
 	tunnel.establishedTime = monotime.Now()
 
 	// Now that network operations are complete, cancel interruptibility
-	pendingConns.Remove(dialConn)
+	pendingConns.Remove(dialResult.dialConn)
 
 	// Spawn the operateTunnel goroutine, which monitors the tunnel and handles periodic stats updates.
 	tunnel.operateWaitGroup.Add(1)
@@ -522,6 +524,14 @@ func initMeekConfig(
 	}, nil
 }
 
+type dialResult struct {
+	dialConn      net.Conn
+	monitoredConn *common.ActivityMonitoredConn
+	sshClient     *ssh.Client
+	sshRequests   <-chan *ssh.Request
+	dialStats     *TunnelDialStats
+}
+
 // dialSsh is a helper that builds the transport layers and establishes the SSH connection.
 // When additional dial configuration is used, DialStats are recorded and returned.
 //
@@ -534,7 +544,7 @@ func dialSsh(
 	pendingConns *common.Conns,
 	serverEntry *ServerEntry,
 	selectedProtocol,
-	sessionId string) (net.Conn, *common.ActivityMonitoredConn, *ssh.Client, *TunnelDialStats, error) {
+	sessionId string) (*dialResult, error) {
 
 	// The meek protocols tunnel obfuscated SSH. Obfuscated SSH is layered on top of SSH.
 	// So depending on which protocol is used, multiple layers are initialized.
@@ -556,7 +566,7 @@ func dialSsh(
 		useObfuscatedSsh = true
 		meekConfig, err = initMeekConfig(config, serverEntry, selectedProtocol, sessionId)
 		if err != nil {
-			return nil, nil, nil, nil, common.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 	}
 
@@ -596,12 +606,12 @@ func dialSsh(
 	if meekConfig != nil {
 		dialConn, err = DialMeek(meekConfig, dialConfig)
 		if err != nil {
-			return nil, nil, nil, nil, common.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 	} else {
 		dialConn, err = DialTCP(directTCPDialAddress, dialConfig)
 		if err != nil {
-			return nil, nil, nil, nil, common.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 	}
 
@@ -617,7 +627,7 @@ func dialSsh(
 	// Activity monitoring is used to measure tunnel duration
 	monitoredConn, err := common.NewActivityMonitoredConn(dialConn, 0, false, nil, nil)
 	if err != nil {
-		return nil, nil, nil, nil, common.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 
 	// Apply throttling (if configured)
@@ -629,14 +639,14 @@ func dialSsh(
 		sshConn, err = NewObfuscatedSshConn(
 			OBFUSCATION_CONN_MODE_CLIENT, throttledConn, serverEntry.SshObfuscatedKey)
 		if err != nil {
-			return nil, nil, nil, nil, common.ContextError(err)
+			return nil, common.ContextError(err)
 		}
 	}
 
 	// Now establish the SSH session over the conn transport
 	expectedPublicKey, err := base64.StdEncoding.DecodeString(serverEntry.SshHostKey)
 	if err != nil {
-		return nil, nil, nil, nil, common.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	sshCertChecker := &ssh.CertChecker{
 		HostKeyFallback: func(addr string, remote net.Addr, publicKey ssh.PublicKey) error {
@@ -646,18 +656,21 @@ func dialSsh(
 			return nil
 		},
 	}
-	sshPasswordPayload, err := json.Marshal(
-		struct {
-			SessionId   string `json:"SessionId"`
-			SshPassword string `json:"SshPassword"`
-		}{sessionId, serverEntry.SshPassword})
+
+	sshPasswordPayload := &protocol.SSHPasswordPayload{
+		SessionId:          sessionId,
+		SshPassword:        serverEntry.SshPassword,
+		ClientCapabilities: []string{protocol.CLIENT_CAPABILITY_SERVER_REQUESTS},
+	}
+
+	payload, err := json.Marshal(sshPasswordPayload)
 	if err != nil {
-		return nil, nil, nil, nil, common.ContextError(err)
+		return nil, common.ContextError(err)
 	}
 	sshClientConfig := &ssh.ClientConfig{
 		User: serverEntry.SshUsername,
 		Auth: []ssh.AuthMethod{
-			ssh.Password(string(sshPasswordPayload)),
+			ssh.Password(string(payload)),
 		},
 		HostKeyCallback: sshCertChecker.CheckHostKey,
 	}
@@ -674,13 +687,14 @@ func dialSsh(
 	// TODO: adjust the timeout to account for time-elapsed-from-start
 
 	type sshNewClientResult struct {
-		sshClient *ssh.Client
-		err       error
+		sshClient   *ssh.Client
+		sshRequests <-chan *ssh.Request
+		err         error
 	}
 	resultChannel := make(chan *sshNewClientResult, 2)
 	if *config.TunnelConnectTimeoutSeconds > 0 {
 		time.AfterFunc(time.Duration(*config.TunnelConnectTimeoutSeconds)*time.Second, func() {
-			resultChannel <- &sshNewClientResult{nil, errors.New("ssh dial timeout")}
+			resultChannel <- &sshNewClientResult{nil, nil, errors.New("ssh dial timeout")}
 		})
 	}
 
@@ -688,17 +702,18 @@ func dialSsh(
 		// The following is adapted from ssh.Dial(), here using a custom conn
 		// The sshAddress is passed through to host key verification callbacks; we don't use it.
 		sshAddress := ""
-		sshClientConn, sshChans, sshReqs, err := ssh.NewClientConn(sshConn, sshAddress, sshClientConfig)
+		sshClientConn, sshChannels, sshRequests, err := ssh.NewClientConn(
+			sshConn, sshAddress, sshClientConfig)
 		var sshClient *ssh.Client
 		if err == nil {
-			sshClient = ssh.NewClient(sshClientConn, sshChans, sshReqs)
+			sshClient = ssh.NewClient(sshClientConn, sshChannels, nil)
 		}
-		resultChannel <- &sshNewClientResult{sshClient, err}
+		resultChannel <- &sshNewClientResult{sshClient, sshRequests, err}
 	}()
 
 	result := <-resultChannel
 	if result.err != nil {
-		return nil, nil, nil, nil, common.ContextError(result.err)
+		return nil, common.ContextError(result.err)
 	}
 
 	var dialStats *TunnelDialStats
@@ -737,7 +752,13 @@ func dialSsh(
 	// but should not be used to perform I/O as that would interfere with SSH
 	// (and also bypasses throttling).
 
-	return dialConn, monitoredConn, result.sshClient, dialStats, nil
+	return &dialResult{
+			dialConn:      dialConn,
+			monitoredConn: monitoredConn,
+			sshClient:     result.sshClient,
+			sshRequests:   result.sshRequests,
+			dialStats:     dialStats},
+		nil
 }
 
 func makeRandomPeriod(min, max time.Duration) time.Duration {
@@ -978,6 +999,18 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 
 		case err = <-sshKeepAliveError:
 
+		case serverRequest := <-tunnel.sshServerRequests:
+			if serverRequest != nil {
+				err := HandleServerRequest(tunnel, serverRequest.Type, serverRequest.Payload)
+				if err == nil {
+					serverRequest.Reply(true, nil)
+				} else {
+					NoticeAlert("HandleServerRequest for %s failed: %s", serverRequest.Type, err)
+					serverRequest.Reply(false, nil)
+
+				}
+			}
+
 		case <-tunnel.shutdownOperateBroadcast:
 			shutdown = true
 		}