Browse Source

Merge pull request #372 from rod-hynes/master

psiphond bug fixes
Rod Hynes 8 years ago
parent
commit
050365a999
3 changed files with 185 additions and 65 deletions
  1. 2 2
      psiphon/pluginProtocol.go
  2. 55 11
      psiphon/server/api.go
  3. 128 52
      psiphon/server/tunnelServer.go

+ 2 - 2
psiphon/pluginProtocol.go

@@ -54,8 +54,8 @@ type PluginProtocolDialer func(
 
 // RegisterPluginProtocol sets the current plugin protocol
 // dialer.
-func RegisterPluginProtocol(protcolDialer PluginProtocolDialer) {
-	registeredPluginProtocolDialer.Store(protcolDialer)
+func RegisterPluginProtocol(protocolDialer PluginProtocolDialer) {
+	registeredPluginProtocolDialer.Store(protocolDialer)
 }
 
 // DialPluginProtocol uses the current plugin protocol dialer,

+ 55 - 11
psiphon/server/api.go

@@ -108,6 +108,49 @@ func dispatchAPIRequestHandler(
 		}
 	}()
 
+	// Before invoking the handlers, enforce some preconditions:
+	//
+	// - A handshake request must preceed any other requests.
+	// - When the handshake results in a traffic rules state where
+	//   the client is immediately exhausted, no requests
+	//   may succeed. This case ensures that blocked clients do
+	//   not log "connected", etc.
+	//
+	// Only one handshake request may be made. There is no check here
+	// to enforce that handshakeAPIRequestHandler will be called at
+	// most once. The SetHandshakeState call in handshakeAPIRequestHandler
+	// enforces that only a single handshake is made; enforcing that there
+	// ensures no race condition even if concurrent requests are
+	// in flight.
+
+	if name != protocol.PSIPHON_API_HANDSHAKE_REQUEST_NAME {
+
+		// TODO: same session-ID-lookup TODO in handshakeAPIRequestHandler
+		// applies here.
+
+		sessionID, err := getStringRequestParam(params, "client_session_id")
+		if err == nil {
+			// Note: follows/duplicates baseRequestParams validation
+			if !isHexDigits(support, sessionID) {
+				err = errors.New("invalid param: client_session_id")
+			}
+		}
+		if err != nil {
+			return nil, common.ContextError(err)
+		}
+
+		completed, exhausted, err := support.TunnelServer.GetClientHandshaked(sessionID)
+		if err != nil {
+			return nil, common.ContextError(err)
+		}
+		if !completed {
+			return nil, common.ContextError(errors.New("handshake not completed"))
+		}
+		if exhausted {
+			return nil, common.ContextError(errors.New("exhausted after handshake"))
+		}
+	}
+
 	switch name {
 	case protocol.PSIPHON_API_HANDSHAKE_REQUEST_NAME:
 		return handshakeAPIRequestHandler(support, apiProtocol, geoIPData, params)
@@ -139,16 +182,6 @@ func handshakeAPIRequestHandler(
 		return nil, common.ContextError(err)
 	}
 
-	log.LogRawFieldsWithTimestamp(
-		getRequestLogFields(
-			support,
-			"handshake",
-			geoIPData,
-			params,
-			baseRequestParams))
-
-	// Note: ignoring param format errors as params have been validated
-
 	sessionID, _ := getStringRequestParam(params, "client_session_id")
 	sponsorID, _ := getStringRequestParam(params, "sponsor_id")
 	clientVersion, _ := getStringRequestParam(params, "client_version")
@@ -173,6 +206,17 @@ func handshakeAPIRequestHandler(
 		return nil, common.ContextError(err)
 	}
 
+	// The log comes _after_ SetClientHandshakeState, in case that call rejects
+	// the state change (for example, if a second handshake is performed)
+
+	log.LogRawFieldsWithTimestamp(
+		getRequestLogFields(
+			support,
+			"handshake",
+			geoIPData,
+			params,
+			baseRequestParams))
+
 	// Note: no guarantee that PsinetDatabase won't reload between database calls
 	db := support.PsinetDatabase
 	handshakeResponse := protocol.HandshakeResponse{
@@ -509,7 +553,7 @@ const (
 // is specified, in which case an array of string is expected.
 var baseRequestParams = []requestParamSpec{
 	requestParamSpec{"server_secret", isServerSecret, requestParamNotLogged},
-	requestParamSpec{"client_session_id", isHexDigits, requestParamOptional | requestParamNotLogged},
+	requestParamSpec{"client_session_id", isHexDigits, requestParamNotLogged},
 	requestParamSpec{"propagation_channel_id", isHexDigits, 0},
 	requestParamSpec{"sponsor_id", isHexDigits, 0},
 	requestParamSpec{"client_version", isIntString, 0},

+ 128 - 52
psiphon/server/tunnelServer.go

@@ -225,6 +225,14 @@ func (server *TunnelServer) SetClientHandshakeState(
 	return server.sshServer.setClientHandshakeState(sessionID, state)
 }
 
+// GetClientHandshaked indicates whether the client has completed a handshake
+// and whether its traffic rules are immediately exhausted.
+func (server *TunnelServer) GetClientHandshaked(
+	sessionID string) (bool, bool, error) {
+
+	return server.sshServer.getClientHandshaked(sessionID)
+}
+
 // SetEstablishTunnels sets whether new tunnels may be established or not.
 // When not establishing, incoming connections are immediately closed.
 func (server *TunnelServer) SetEstablishTunnels(establish bool) {
@@ -660,6 +668,22 @@ func (sshServer *sshServer) setClientHandshakeState(
 	return nil
 }
 
+func (sshServer *sshServer) getClientHandshaked(
+	sessionID string) (bool, bool, error) {
+
+	sshServer.clientsMutex.Lock()
+	client := sshServer.clients[sessionID]
+	sshServer.clientsMutex.Unlock()
+
+	if client == nil {
+		return false, false, common.ContextError(errors.New("unknown session ID"))
+	}
+
+	completed, exhausted := client.getHandshaked()
+
+	return completed, exhausted, nil
+}
+
 func (sshServer *sshServer) stopClients() {
 
 	sshServer.clientsMutex.Lock()
@@ -1497,6 +1521,40 @@ func (sshClient *sshClient) setHandshakeState(state handshakeState) error {
 	return nil
 }
 
+// getHandshaked returns whether the client has completed a handshake API
+// request and whether the traffic rules that were selected after the
+// handshake immediately exhaust the client.
+//
+// When the client is immediately exhausted it will be closed; but this
+// takes effect asynchronously. The "exhausted" return value is used to
+// prevent API requests by clients that will close.
+func (sshClient *sshClient) getHandshaked() (bool, bool) {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	completed := sshClient.handshakeState.completed
+
+	exhausted := false
+
+	// Notes:
+	// - "Immediately exhausted" is when CloseAfterExhausted is set and
+	//   either ReadUnthrottledBytes or WriteUnthrottledBytes starts from
+	//   0, so no bytes would be read or written. This check does not
+	//   examine whether 0 bytes _remain_ in the ThrottledConn.
+	// - This check is made against the current traffic rules, which
+	//   could have changed in a hot reload since the handshake.
+
+	if completed &&
+		*sshClient.trafficRules.RateLimits.CloseAfterExhausted == true &&
+		(*sshClient.trafficRules.RateLimits.ReadUnthrottledBytes == 0 ||
+			*sshClient.trafficRules.RateLimits.WriteUnthrottledBytes == 0) {
+
+		exhausted = true
+	}
+
+	return completed, exhausted
+}
+
 // setTrafficRules resets the client's traffic rules based on the latest server config
 // and client properties. As sshClient.trafficRules may be reset by a concurrent
 // goroutine, trafficRules must only be accessed within the sshClient mutex.
@@ -1697,28 +1755,6 @@ func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
 	return false
 }
 
-func (sshClient *sshClient) isAtPortForwardLimit(
-	portForwardType int) bool {
-
-	sshClient.Lock()
-	defer sshClient.Unlock()
-
-	var max int
-	var state *trafficState
-	if portForwardType == portForwardTypeTCP {
-		max = *sshClient.trafficRules.MaxTCPPortForwardCount
-		state = &sshClient.tcpTrafficState
-	} else {
-		max = *sshClient.trafficRules.MaxUDPPortForwardCount
-		state = &sshClient.udpTrafficState
-	}
-
-	if max > 0 && state.concurrentPortForwardCount >= int64(max) {
-		return true
-	}
-	return false
-}
-
 func (sshClient *sshClient) getTCPPortForwardQueueSize() int {
 
 	sshClient.Lock()
@@ -1757,6 +1793,55 @@ func (sshClient *sshClient) abortedTCPPortForward() {
 	sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
 }
 
+func (sshClient *sshClient) allocatePortForward(portForwardType int) bool {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	// Check if at port forward limit. The subsequent counter
+	// changes must be atomic with the limit check to ensure
+	// the counter never exceeds the limit in the case of
+	// concurrent allocations.
+
+	var max int
+	var state *trafficState
+	if portForwardType == portForwardTypeTCP {
+		max = *sshClient.trafficRules.MaxTCPPortForwardCount
+		state = &sshClient.tcpTrafficState
+	} else {
+		max = *sshClient.trafficRules.MaxUDPPortForwardCount
+		state = &sshClient.udpTrafficState
+	}
+
+	if max > 0 && state.concurrentPortForwardCount >= int64(max) {
+		return false
+	}
+
+	// Update port forward counters.
+
+	if portForwardType == portForwardTypeTCP {
+
+		// Assumes TCP port forwards called dialingTCPPortForward
+		state.concurrentDialingPortForwardCount -= 1
+
+		if sshClient.tcpPortForwardDialingAvailableSignal != nil {
+
+			max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
+			if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
+				sshClient.tcpPortForwardDialingAvailableSignal()
+			}
+		}
+	}
+
+	state.concurrentPortForwardCount += 1
+	if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
+		state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
+	}
+	state.totalPortForwardCount += 1
+
+	return true
+}
+
 // establishedPortForward increments the concurrent port
 // forward counter. closedPortForward decrements it, so it
 // must always be called for each establishedPortForward
@@ -1774,6 +1859,8 @@ func (sshClient *sshClient) abortedTCPPortForward() {
 func (sshClient *sshClient) establishedPortForward(
 	portForwardType int, portForwardLRU *common.LRUConns) {
 
+	// Do not lock sshClient here.
+
 	var state *trafficState
 	if portForwardType == portForwardTypeTCP {
 		state = &sshClient.tcpTrafficState
@@ -1791,44 +1878,33 @@ func (sshClient *sshClient) establishedPortForward(
 	// waits for a LRU handler to be interrupted and signal
 	// availability.
 	//
-	// Note: the port forward limit can change via a traffic
-	// rules hot reload; the condition variable handles this
-	// case whereas a channel-based semaphore would not.
+	// Notes:
+	//
+	// - the port forward limit can change via a traffic
+	//   rules hot reload; the condition variable handles
+	//   this case whereas a channel-based semaphore would
+	//   not.
+	//
+	// - if a number of goroutines exceeding the total limit
+	//   arrive here all concurrently, some CloseOldest() calls
+	//   will have no effect as there can be less existing port
+	//   forwards than new ones. In this case, the new port
+	//   forward will be delayed. This is highly unlikely in
+	//   practise since UDP calls to establishedPortForward are
+	//   serialized and TCP calls are limited by the dial
+	//   queue/count.
+
+	if !sshClient.allocatePortForward(portForwardType) {
 
-	if sshClient.isAtPortForwardLimit(portForwardType) {
 		portForwardLRU.CloseOldest()
 		log.WithContext().Debug("closed LRU port forward")
+
 		state.availablePortForwardCond.L.Lock()
-		for sshClient.isAtPortForwardLimit(portForwardType) {
+		for !sshClient.allocatePortForward(portForwardType) {
 			state.availablePortForwardCond.Wait()
 		}
 		state.availablePortForwardCond.L.Unlock()
 	}
-
-	sshClient.Lock()
-
-	if portForwardType == portForwardTypeTCP {
-
-		// Assumes TCP port forwards called dialingTCPPortForward
-		state.concurrentDialingPortForwardCount -= 1
-
-		if sshClient.tcpPortForwardDialingAvailableSignal != nil {
-
-			max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
-			if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
-				sshClient.tcpPortForwardDialingAvailableSignal()
-			}
-		}
-
-	}
-
-	state.concurrentPortForwardCount += 1
-	if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
-		state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
-	}
-	state.totalPortForwardCount += 1
-
-	sshClient.Unlock()
 }
 
 func (sshClient *sshClient) closedPortForward(