Преглед изворни кода

Reworked delay/timeouts; add Limited response flag

- Set distinct timeouts for Noise session handshake vs. application request
  round trips. Ensure announce delays and long-polling announce timeouts
  aren't consumed by or impacted by Noise session establishment or blocking
  on waitToShareSession.

- Add an explict Limited response flag to announce and offer responses,
  instead of simply returning NoMatch. Clients/proxies skip-retry/back-off
  when Limited, but not when NoMatch. Limited takes priority over NoMatch.
Rod Hynes пре 1 година
родитељ
комит
b5999ea8cf

+ 16 - 14
psiphon/common/inproxy/api.go

@@ -230,15 +230,16 @@ type ProxyAnnounceRequest struct {
 type ProxyAnnounceResponse struct {
 	OperatorMessageJSON         string                               `cbor:"1,keyasint,omitempty"`
 	TacticsPayload              []byte                               `cbor:"2,keyasint,omitempty"`
-	NoMatch                     bool                                 `cbor:"3,keyasint,omitempty"`
-	ConnectionID                ID                                   `cbor:"4,keyasint,omitempty"`
-	ClientProxyProtocolVersion  int32                                `cbor:"5,keyasint,omitempty"`
-	ClientOfferSDP              webrtc.SessionDescription            `cbor:"6,keyasint,omitempty"`
-	ClientRootObfuscationSecret ObfuscationSecret                    `cbor:"7,keyasint,omitempty"`
-	DoDTLSRandomization         bool                                 `cbor:"8,keyasint,omitempty"`
-	TrafficShapingParameters    *DataChannelTrafficShapingParameters `cbor:"9,keyasint,omitempty"`
-	NetworkProtocol             NetworkProtocol                      `cbor:"10,keyasint,omitempty"`
-	DestinationAddress          string                               `cbor:"11,keyasint,omitempty"`
+	Limited                     bool                                 `cbor:"3,keyasint,omitempty"`
+	NoMatch                     bool                                 `cbor:"4,keyasint,omitempty"`
+	ConnectionID                ID                                   `cbor:"5,keyasint,omitempty"`
+	ClientProxyProtocolVersion  int32                                `cbor:"6,keyasint,omitempty"`
+	ClientOfferSDP              webrtc.SessionDescription            `cbor:"7,keyasint,omitempty"`
+	ClientRootObfuscationSecret ObfuscationSecret                    `cbor:"8,keyasint,omitempty"`
+	DoDTLSRandomization         bool                                 `cbor:"9,keyasint,omitempty"`
+	TrafficShapingParameters    *DataChannelTrafficShapingParameters `cbor:"10,keyasint,omitempty"`
+	NetworkProtocol             NetworkProtocol                      `cbor:"11,keyasint,omitempty"`
+	DestinationAddress          string                               `cbor:"12,keyasint,omitempty"`
 }
 
 // ClientOfferRequest is an API request sent from a client to a broker,
@@ -303,11 +304,12 @@ type DataChannelTrafficShapingParameters struct {
 // ClientRelayedPacketRequests until complete. ConnectionID identifies this
 // connection and its relayed BrokerServerReport.
 type ClientOfferResponse struct {
-	NoMatch                      bool                      `cbor:"1,keyasint,omitempty"`
-	ConnectionID                 ID                        `cbor:"2,keyasint,omitempty"`
-	SelectedProxyProtocolVersion int32                     `cbor:"3,keyasint,omitempty"`
-	ProxyAnswerSDP               webrtc.SessionDescription `cbor:"4,keyasint,omitempty"`
-	RelayPacketToServer          []byte                    `cbor:"5,keyasint,omitempty"`
+	Limited                      bool                      `cbor:"1,keyasint,omitempty"`
+	NoMatch                      bool                      `cbor:"2,keyasint,omitempty"`
+	ConnectionID                 ID                        `cbor:"3,keyasint,omitempty"`
+	SelectedProxyProtocolVersion int32                     `cbor:"4,keyasint,omitempty"`
+	ProxyAnswerSDP               webrtc.SessionDescription `cbor:"5,keyasint,omitempty"`
+	RelayPacketToServer          []byte                    `cbor:"6,keyasint,omitempty"`
 }
 
 // TODO: Encode SDPs using CBOR without field names, simliar to packed metrics?

+ 24 - 14
psiphon/common/inproxy/broker.go

@@ -611,11 +611,14 @@ func (b *Broker) handleProxyAnnounce(
 
 		// A no-match response is sent in the case of a timeout awaiting a
 		// match. The faster-failing rate or entry limiting case also results
-		// in a no-match response, rather than an error return from
-		// handleProxyAnnounce, so that the proxy doesn't receive a 404 and
-		// flag its BrokerClient as having failed.
+		// in a response, rather than an error return from handleProxyAnnounce,
+		// so that the proxy doesn't receive a 404 and flag its BrokerClient as
+		// having failed.
+		//
+		// When the timeout and limit case coincide, limit takes precedence in
+		// the response.
 
-		if timeout {
+		if timeout && !limited {
 
 			// Note: the respective proxy and broker timeouts,
 			// InproxyBrokerProxyAnnounceTimeout and
@@ -625,7 +628,7 @@ func (b *Broker) handleProxyAnnounce(
 
 			timedOut = true
 
-		} else if limited {
+		} else {
 
 			// Record the specific limit error in the proxy-announce broker event.
 
@@ -635,7 +638,8 @@ func (b *Broker) handleProxyAnnounce(
 		responsePayload, err := MarshalProxyAnnounceResponse(
 			&ProxyAnnounceResponse{
 				TacticsPayload: tacticsPayload,
-				NoMatch:        true,
+				Limited:        limited,
+				NoMatch:        timeout && !limited,
 			})
 		if err != nil {
 			return nil, errors.Trace(err)
@@ -836,11 +840,14 @@ func (b *Broker) handleClientOffer(
 
 		// A no-match response is sent in the case of a timeout awaiting a
 		// match. The faster-failing rate or entry limiting case also results
-		// in a no-match response, rather than an error return from
-		// handleClientOffer, so that the client doesn't receive a 404 and
-		// flag its BrokerClient as having failed.
+		// in a response, rather than an error return from handleClientOffer,
+		// so that the client doesn't receive a 404 and flag its BrokerClient
+		// as having failed.
+		//
+		// When the timeout and limit case coincide, limit takes precedence in
+		// the response.
 
-		if timeout {
+		if timeout && !limited {
 
 			// Note: the respective client and broker timeouts,
 			// InproxyBrokerClientOfferTimeout and
@@ -850,7 +857,7 @@ func (b *Broker) handleClientOffer(
 
 			timedOut = true
 
-		} else if limited {
+		} else {
 
 			// Record the specific limit error in the client-offer broker event.
 
@@ -858,7 +865,10 @@ func (b *Broker) handleClientOffer(
 		}
 
 		responsePayload, err := MarshalClientOfferResponse(
-			&ClientOfferResponse{NoMatch: true})
+			&ClientOfferResponse{
+				Limited: limited,
+				NoMatch: timeout && !limited,
+			})
 		if err != nil {
 			return nil, errors.Trace(err)
 		}
@@ -1127,7 +1137,7 @@ func (b *Broker) handleClientRelayedPacket(
 
 	// Next is given a nil ctx since we're not waiting for any other client to
 	// establish the session.
-	out, err := pendingServerReport.roundTrip.Next(
+	out, _, err := pendingServerReport.roundTrip.Next(
 		nil, relayedPacketRequest.PacketFromServer)
 	if err != nil {
 		return nil, errors.Trace(err)
@@ -1189,7 +1199,7 @@ func (b *Broker) initiateRelayedServerReport(
 		return nil, errors.Trace(err)
 	}
 
-	relayPacket, err := roundTrip.Next(nil, nil)
+	relayPacket, _, err := roundTrip.Next(nil, nil)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}

+ 34 - 49
psiphon/common/inproxy/brokerClient.go

@@ -31,6 +31,7 @@ import (
 // Timeouts should be aligned with Broker timeouts.
 
 const (
+	sessionHandshakeRoundTripTimeout  = 10 * time.Second
 	proxyAnnounceRequestTimeout       = 2 * time.Minute
 	proxyAnswerRequestTimeout         = 10 * time.Second
 	clientOfferRequestTimeout         = 10 * time.Second
@@ -110,20 +111,12 @@ func (b *BrokerClient) ProxyAnnounce(
 		return nil, errors.Trace(err)
 	}
 
-	timeout := common.ValueOrDefault(
+	requestTimeout := common.ValueOrDefault(
 		b.coordinator.AnnounceRequestTimeout(),
 		proxyAnnounceRequestTimeout)
 
-	// Increase the timeout to account for requestDelay, which is applied
-	// before the actual network round trip.
-	if requestDelay > 0 {
-		timeout += requestDelay
-	}
-
-	requestCtx, requestCancelFunc := context.WithTimeout(ctx, timeout)
-	defer requestCancelFunc()
-
-	responsePayload, err := b.roundTrip(requestCtx, requestDelay, requestPayload)
+	responsePayload, err := b.roundTrip(
+		ctx, requestDelay, requestTimeout, requestPayload)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
@@ -146,13 +139,12 @@ func (b *BrokerClient) ClientOffer(
 		return nil, errors.Trace(err)
 	}
 
-	requestCtx, requestCancelFunc := context.WithTimeout(
-		ctx, common.ValueOrDefault(
-			b.coordinator.OfferRequestTimeout(),
-			clientOfferRequestTimeout))
-	defer requestCancelFunc()
+	requestTimeout := common.ValueOrDefault(
+		b.coordinator.OfferRequestTimeout(),
+		clientOfferRequestTimeout)
 
-	responsePayload, err := b.roundTrip(requestCtx, 0, requestPayload)
+	responsePayload, err := b.roundTrip(
+		ctx, 0, requestTimeout, requestPayload)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
@@ -175,13 +167,12 @@ func (b *BrokerClient) ProxyAnswer(
 		return nil, errors.Trace(err)
 	}
 
-	requestCtx, requestCancelFunc := context.WithTimeout(
-		ctx, common.ValueOrDefault(
-			b.coordinator.AnswerRequestTimeout(),
-			proxyAnswerRequestTimeout))
-	defer requestCancelFunc()
+	requestTimeout := common.ValueOrDefault(
+		b.coordinator.AnswerRequestTimeout(),
+		proxyAnswerRequestTimeout)
 
-	responsePayload, err := b.roundTrip(requestCtx, 0, requestPayload)
+	responsePayload, err := b.roundTrip(
+		ctx, 0, requestTimeout, requestPayload)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
@@ -205,13 +196,12 @@ func (b *BrokerClient) ClientRelayedPacket(
 		return nil, errors.Trace(err)
 	}
 
-	requestCtx, requestCancelFunc := context.WithTimeout(
-		ctx, common.ValueOrDefault(
-			b.coordinator.RelayedPacketRequestTimeout(),
-			clientRelayedPacketRequestTimeout))
-	defer requestCancelFunc()
+	requestTimeout := common.ValueOrDefault(
+		b.coordinator.RelayedPacketRequestTimeout(),
+		clientRelayedPacketRequestTimeout)
 
-	responsePayload, err := b.roundTrip(requestCtx, 0, requestPayload)
+	responsePayload, err := b.roundTrip(
+		ctx, 0, requestTimeout, requestPayload)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
@@ -227,6 +217,7 @@ func (b *BrokerClient) ClientRelayedPacket(
 func (b *BrokerClient) roundTrip(
 	ctx context.Context,
 	requestDelay time.Duration,
+	requestTimeout time.Duration,
 	request []byte) ([]byte, error) {
 
 	// The round tripper may need to establish a transport-level connection;
@@ -260,35 +251,29 @@ func (b *BrokerClient) roundTrip(
 	// request/response association, the application-level request and
 	// response are tagged with a RoundTripID which is checked to ensure the
 	// association is maintained.
-
-	var preRoundTrip func(context.Context)
-	if requestDelay > 0 {
-
-		// Use the pre-round trip callback apply the requestDelay _after_ any
-		// waitToShareSession delay, otherwise any waitToShareSession may
-		// collapse staggered requests back together.
-		//
-		// The context passed to preRoundTrip should cancel the delay both in
-		// the case where the request is canceled and and in the case where
-		// the round tripper is closed.
-		//
-		// It's assumed that the caller has adjusted the ctx deadline to
-		// account for requestDelay.
-
-		preRoundTrip = func(ctx context.Context) {
-			common.SleepWithContext(ctx, requestDelay)
-		}
-	}
+	//
+	// InitiatorSessions.RoundTrip will apply sessionHandshakeTimeout to any
+	// round trips required for Noise session handshakes; apply requestDelay
+	// before the application-level request round trip; and apply
+	// requestTimeout to the network round trip following the delay, if any.
+	// Any time spent blocking on waitToShareSession is not included in
+	// requestDelay or requestTimeout.
 
 	waitToShareSession := true
 
+	sessionHandshakeTimeout := common.ValueOrDefault(
+		b.coordinator.SessionHandshakeRoundTripTimeout(),
+		sessionHandshakeRoundTripTimeout)
+
 	response, err := b.sessions.RoundTrip(
 		ctx,
 		roundTripper,
-		preRoundTrip,
 		b.coordinator.BrokerPublicKey(),
 		b.coordinator.BrokerRootObfuscationSecret(),
 		waitToShareSession,
+		sessionHandshakeTimeout,
+		requestDelay,
+		requestTimeout,
 		request)
 	if err != nil {
 

+ 11 - 4
psiphon/common/inproxy/client.go

@@ -345,14 +345,21 @@ func dialClientWebRTCConn(
 		return nil, false, errors.Trace(err)
 	}
 
-	if offerResponse.NoMatch {
-		return nil, false, errors.TraceNew("no proxy match")
-	}
+	// No retry when rate/entry limited or when the proxy protocols is
+	// incompatible; do retry on no-match, as a match may soon appear.
+
+	if offerResponse.Limited {
+		return nil, false, errors.TraceNew("limited")
+
+	} else if offerResponse.SelectedProxyProtocolVersion != ProxyProtocolVersion1 {
 
-	if offerResponse.SelectedProxyProtocolVersion != ProxyProtocolVersion1 {
 		return nil, false, errors.Tracef(
 			"Unsupported proxy protocol version: %d",
 			offerResponse.SelectedProxyProtocolVersion)
+
+	} else if offerResponse.NoMatch {
+
+		return nil, true, errors.TraceNew("no proxy match")
 	}
 
 	// Establish the WebRTC DataChannel connection

+ 7 - 7
psiphon/common/inproxy/coordinator.go

@@ -29,19 +29,18 @@ import (
 // blocking circumvention capabilities. A typical implementation is domain
 // fronted HTTPS. RoundTripper is used by clients and proxies to make
 // requests to brokers.
+//
+// The round trip implementation must apply any specified delay before the
+// network round trip begins; and apply the specified timeout to the network
+// round trip, excluding any delay.
 type RoundTripper interface {
 	RoundTrip(
 		ctx context.Context,
-		preRoundTrip PreRoundTripCallback,
+		roundTripDelay time.Duration,
+		roundTripTimeout time.Duration,
 		requestPayload []byte) (responsePayload []byte, err error)
 }
 
-// PreRoundTripCallback is a callback that is invoked by the RoundTripper
-// immediately before the network round trip, and which takes a context that
-// will be canceled both in the case the request is canceled and in case the
-// round tripper is closed.
-type PreRoundTripCallback func(context.Context)
-
 // RoundTripperFailedError is an error type that should be returned from
 // RoundTripper.RoundTrip when the round trip transport has permanently
 // failed. When RoundTrip returns an error of type RoundTripperFailedError to
@@ -163,6 +162,7 @@ type BrokerDialCoordinator interface {
 	// after closing its network resources.
 	BrokerClientRoundTripperFailed(roundTripper RoundTripper)
 
+	SessionHandshakeRoundTripTimeout() time.Duration
 	AnnounceRequestTimeout() time.Duration
 	AnnounceDelay() time.Duration
 	AnnounceDelayJitter() float64

+ 7 - 0
psiphon/common/inproxy/coordinator_test.go

@@ -45,6 +45,7 @@ type testBrokerDialCoordinator struct {
 	brokerClientRoundTripper          RoundTripper
 	brokerClientRoundTripperSucceeded func(RoundTripper)
 	brokerClientRoundTripperFailed    func(RoundTripper)
+	sessionHandshakeRoundTripTimeout  time.Duration
 	announceRequestTimeout            time.Duration
 	announceDelay                     time.Duration
 	announceDelayJitter               float64
@@ -115,6 +116,12 @@ func (t *testBrokerDialCoordinator) BrokerClientRoundTripperFailed(roundTripper
 	t.brokerClientRoundTripperFailed(roundTripper)
 }
 
+func (t *testBrokerDialCoordinator) SessionHandshakeRoundTripTimeout() time.Duration {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	return t.sessionHandshakeRoundTripTimeout
+}
+
 func (t *testBrokerDialCoordinator) AnnounceRequestTimeout() time.Duration {
 	t.mutex.Lock()
 	defer t.mutex.Unlock()

+ 8 - 4
psiphon/common/inproxy/inproxy_test.go

@@ -861,17 +861,21 @@ func newHTTPRoundTripper(endpointAddr string, path string) *httpRoundTripper {
 
 func (r *httpRoundTripper) RoundTrip(
 	ctx context.Context,
-	preRoundTrip PreRoundTripCallback,
+	roundTripDelay time.Duration,
+	roundTripTimeout time.Duration,
 	requestPayload []byte) ([]byte, error) {
 
-	if preRoundTrip != nil {
-		preRoundTrip(ctx)
+	if roundTripDelay > 0 {
+		common.SleepWithContext(ctx, roundTripDelay)
 	}
 
+	requestCtx, requestCancelFunc := context.WithTimeout(ctx, roundTripTimeout)
+	defer requestCancelFunc()
+
 	url := fmt.Sprintf("https://%s/%s", r.endpointAddr, r.path)
 
 	request, err := http.NewRequestWithContext(
-		ctx, "POST", url, bytes.NewReader(requestPayload))
+		requestCtx, "POST", url, bytes.NewReader(requestPayload))
 	if err != nil {
 		return nil, errors.Trace(err)
 	}

+ 23 - 13
psiphon/common/inproxy/proxy.go

@@ -541,28 +541,28 @@ func (p *Proxy) proxyOneClient(
 	// will also extend the base request timeout, as required, to account for
 	// any deliberate delay.
 
-	announceRequestDelay := time.Duration(0)
+	requestDelay := time.Duration(0)
 	announceDelay, announceDelayJitter := p.getAnnounceDelayParameters()
 	p.nextAnnounceMutex.Lock()
-	delay := prng.JitterDuration(announceDelay, announceDelayJitter)
+	nextDelay := prng.JitterDuration(announceDelay, announceDelayJitter)
 	if p.nextAnnounceBrokerClient != brokerClient {
 		// Reset the delay when the broker client changes.
 		p.nextAnnounceNotBefore = time.Time{}
 		p.nextAnnounceBrokerClient = brokerClient
 	}
 	if p.nextAnnounceNotBefore.IsZero() {
-		p.nextAnnounceNotBefore = time.Now().Add(delay)
+		p.nextAnnounceNotBefore = time.Now().Add(nextDelay)
 		// No delay for the very first announce request, so leave
 		// announceRequestDelay set to 0.
 	} else {
-		announceRequestDelay = time.Until(p.nextAnnounceNotBefore)
-		if announceRequestDelay < 0 {
+		requestDelay = time.Until(p.nextAnnounceNotBefore)
+		if requestDelay < 0 {
 			// This announce did not arrive until after the next delay already
 			// passed, so proceed with no delay.
-			p.nextAnnounceNotBefore = time.Now().Add(delay)
-			announceRequestDelay = 0
+			p.nextAnnounceNotBefore = time.Now().Add(nextDelay)
+			requestDelay = 0
 		} else {
-			p.nextAnnounceNotBefore = p.nextAnnounceNotBefore.Add(delay)
+			p.nextAnnounceNotBefore = p.nextAnnounceNotBefore.Add(nextDelay)
 		}
 	}
 	p.nextAnnounceMutex.Unlock()
@@ -574,7 +574,7 @@ func (p *Proxy) proxyOneClient(
 	// long-polling.
 	announceResponse, err := brokerClient.ProxyAnnounce(
 		ctx,
-		announceRequestDelay,
+		requestDelay,
 		&ProxyAnnounceRequest{
 			PersonalCompartmentIDs: brokerCoordinator.PersonalCompartmentIDs(),
 			Metrics:                metrics,
@@ -608,14 +608,24 @@ func (p *Proxy) proxyOneClient(
 		signalAnnounceDone()
 	}
 
-	if announceResponse.NoMatch {
-		return backOff, errors.TraceNew("no match")
-	}
+	// Trigger back-off back off when rate/entry limited or when the proxy
+	// protocols is incompatible; no back-off for no-match.
+
+	if announceResponse.Limited {
+
+		backOff = true
+		return backOff, errors.TraceNew("limited")
 
-	if announceResponse.ClientProxyProtocolVersion != ProxyProtocolVersion1 {
+	} else if announceResponse.ClientProxyProtocolVersion != ProxyProtocolVersion1 {
+
+		backOff = true
 		return backOff, errors.Tracef(
 			"Unsupported proxy protocol version: %d",
 			announceResponse.ClientProxyProtocolVersion)
+
+	} else if announceResponse.NoMatch {
+
+		return backOff, errors.TraceNew("no match")
 	}
 
 	// Trigger back-off if the following WebRTC operations fail to establish a

+ 79 - 30
psiphon/common/inproxy/session.go

@@ -276,14 +276,32 @@ func NewInitiatorSessions(
 // When waitToShareSession is true, RoundTrip will block until an existing,
 // non-established session is available to be shared.
 //
+// When making initial network round trips to establish a session,
+// sessionHandshakeTimeout is applied as the round trip timeout.
+//
+// When making the application-level request round trip, requestDelay, when >
+// 0, is applied before the request network round trip begins; requestDelay
+// may be used to spread out many concurrent requests, such as batch proxy
+// announcements, to avoid CDN rate limits.
+//
+// requestTimeout is applied to the application-level request network round
+// trip, and excludes any requestDelay; the distinct requestTimeout may be
+// used to set a longer timeout for long-polling requests, such as proxy
+// announcements.
+//
+// Any time spent blocking on waitToShareSession is not included in
+// requestDelay or requestTimeout.
+//
 // RoundTrip returns immediately when ctx becomes done.
 func (s *InitiatorSessions) RoundTrip(
 	ctx context.Context,
 	roundTripper RoundTripper,
-	preRoundTrip PreRoundTripCallback,
 	responderPublicKey SessionPublicKey,
 	responderRootObfuscationSecret ObfuscationSecret,
 	waitToShareSession bool,
+	sessionHandshakeTimeout time.Duration,
+	requestDelay time.Duration,
+	requestTimeout time.Duration,
 	request []byte) ([]byte, error) {
 
 	rt, err := s.NewRoundTrip(
@@ -297,7 +315,7 @@ func (s *InitiatorSessions) RoundTrip(
 
 	var in []byte
 	for {
-		out, err := rt.Next(ctx, in)
+		out, isRequestPacket, err := rt.Next(ctx, in)
 		if err != nil {
 			return nil, errors.Trace(err)
 		}
@@ -308,7 +326,38 @@ func (s *InitiatorSessions) RoundTrip(
 			}
 			return response, nil
 		}
-		in, err = roundTripper.RoundTrip(ctx, preRoundTrip, out)
+
+		// At this point, if sharing a session, any blocking on
+		// waitToShareSession is complete, and time elapsed in that blocking
+		// will not collapse delays or reduce timeouts. If not sharing, and
+		// establishing a new session, Noise session handshake round trips
+		// are required before the request payload round trip.
+		//
+		// Select the delay and timeout. For Noise session handshake round
+		// trips, use sessionHandshakeTimeout, which should be appropriate
+		// for a fast turn-around from the broker, and no delay. When sending
+		// the application-level request packet, use requestDelay and
+		// requestTimeout, which allows for applying a delay -- to spread out
+		// requests -- and a potentially longer timeout appropriate for a
+		// long-polling, slower turn-around from the broker.
+		//
+		// Delays and timeouts are passed down into the round tripper
+		// provider. Having the round tripper perform the delay sleep allows
+		// all delays to be interruped by any round tripper close, due to an
+		// overall broker client reset. Passing the timeout seperately, as
+		// opposed to adding to ctx, explicitly ensures that the timeout is
+		// applied only right before the network round trip and no sooner.
+
+		var delay, timeout time.Duration
+		if isRequestPacket {
+			delay = requestDelay
+			timeout = requestTimeout
+		} else {
+			// No delay for session handshake packet round trips.
+			timeout = sessionHandshakeTimeout
+		}
+
+		in, err = roundTripper.RoundTrip(ctx, delay, timeout, out)
 		if err != nil {
 
 			// There are no explicit retries here. Retrying in the case where
@@ -476,7 +525,7 @@ type InitiatorRoundTrip struct {
 // Next returns immediately when ctx becomes done.
 func (r *InitiatorRoundTrip) Next(
 	ctx context.Context,
-	receivedPacket []byte) (retSendPacket []byte, retErr error) {
+	receivedPacket []byte) (retSendPacket []byte, retIsRequestPacket bool, retErr error) {
 
 	// Note: don't clear or reset a session in the event of a bad/rejected
 	// packet as that would allow a malicious relay client to interrupt a
@@ -484,7 +533,7 @@ func (r *InitiatorRoundTrip) Next(
 	// packet and return an error.
 
 	// beginOrShareSession returns the next packet to send.
-	beginOrShareSession := func() ([]byte, error) {
+	beginOrShareSession := func() ([]byte, bool, error) {
 
 		// Check for an existing session, or create a new one if there's no
 		// existing session.
@@ -520,7 +569,7 @@ func (r *InitiatorRoundTrip) Next(
 		session, isNew, isReady, err := r.initiatorSessions.getSession(
 			r.responderPublicKey, newSession)
 		if err != nil {
-			return nil, errors.Trace(err)
+			return nil, false, errors.Trace(err)
 		}
 
 		if isNew {
@@ -576,11 +625,11 @@ func (r *InitiatorRoundTrip) Next(
 								// specified. It's expected that there will be retries by the
 								// RoundTrip caller.
 
-								return nil, errors.TraceNew("waitToShareSession failed")
+								return nil, false, errors.TraceNew("waitToShareSession failed")
 							}
 							// else, use the session
 						case <-ctx.Done():
-							return nil, errors.Trace(ctx.Err())
+							return nil, false, errors.Trace(ctx.Err())
 						}
 					}
 					r.session = session
@@ -592,7 +641,7 @@ func (r *InitiatorRoundTrip) Next(
 
 					r.session, err = newSession()
 					if err != nil {
-						return nil, errors.Trace(err)
+						return nil, false, errors.Trace(err)
 					}
 					r.sharingSession = false
 				}
@@ -606,20 +655,20 @@ func (r *InitiatorRoundTrip) Next(
 
 			sendPacket, err := r.session.sendPacket(r.requestPayload)
 			if err != nil {
-				return nil, errors.Trace(err)
+				return nil, false, errors.Trace(err)
 			}
 
-			return sendPacket, nil
+			return sendPacket, true, nil
 		}
 
 		// Begin the handshake for a new session.
 
 		_, sendPacket, _, err := r.session.nextHandshakePacket(nil)
 		if err != nil {
-			return nil, errors.Trace(err)
+			return nil, false, errors.Trace(err)
 		}
 
-		return sendPacket, nil
+		return sendPacket, false, nil
 
 	}
 
@@ -627,7 +676,7 @@ func (r *InitiatorRoundTrip) Next(
 	if ctx != nil {
 		err := ctx.Err()
 		if err != nil {
-			return nil, errors.Trace(err)
+			return nil, false, errors.Trace(err)
 		}
 	}
 
@@ -649,28 +698,28 @@ func (r *InitiatorRoundTrip) Next(
 		// packet from the peer is expected.
 
 		if receivedPacket != nil {
-			return nil, errors.TraceNew("unexpected received packet")
+			return nil, false, errors.TraceNew("unexpected received packet")
 		}
 
-		sendPacket, err := beginOrShareSession()
+		sendPacket, isRequestPacket, err := beginOrShareSession()
 
 		if err != nil {
-			return nil, errors.Trace(err)
+			return nil, false, errors.Trace(err)
 		}
-		return sendPacket, nil
+		return sendPacket, isRequestPacket, nil
 
 	}
 
 	// Not the first Next call, so a packet from the peer is expected.
 
 	if receivedPacket == nil {
-		return nil, errors.TraceNew("missing received packet")
+		return nil, false, errors.TraceNew("missing received packet")
 	}
 
 	if r.sharingSession || r.session.isEstablished() {
 
 		// When sharing an established and ready session, or once an owned
-		// session is eastablished, the next packet is post-handshake and
+		// session is established, the next packet is post-handshake and
 		// should be the round trip request response.
 
 		// Pre-unwrap here to check for a ResetSessionToken packet.
@@ -678,7 +727,7 @@ func (r *InitiatorRoundTrip) Next(
 		sessionPacket, err := unwrapSessionPacket(
 			r.session.receiveObfuscationSecret, true, nil, receivedPacket)
 		if err != nil {
-			return nil, errors.Trace(err)
+			return nil, false, errors.Trace(err)
 		}
 
 		// Reset the session when the packet is a valid ResetSessionToken. The
@@ -698,34 +747,34 @@ func (r *InitiatorRoundTrip) Next(
 			r.initiatorSessions.removeIfSession(r.responderPublicKey, r.session)
 			r.session = nil
 
-			sendPacket, err := beginOrShareSession()
+			sendPacket, isRequestPacket, err := beginOrShareSession()
 			if err != nil {
-				return nil, errors.Trace(err)
+				return nil, false, errors.Trace(err)
 			}
-			return sendPacket, nil
+			return sendPacket, isRequestPacket, nil
 		}
 
 		responsePayload, err := r.session.receiveUnmarshaledPacket(sessionPacket)
 		if err != nil {
-			return nil, errors.Trace(err)
+			return nil, false, errors.Trace(err)
 		}
 
 		var sessionRoundTrip SessionRoundTrip
 		err = unmarshalRecord(recordTypeSessionRoundTrip, responsePayload, &sessionRoundTrip)
 		if err != nil {
-			return nil, errors.Trace(err)
+			return nil, false, errors.Trace(err)
 		}
 
 		// Check that the response RoundTripID matches the request RoundTripID.
 
 		if sessionRoundTrip.RoundTripID != r.roundTripID {
-			return nil, errors.TraceNew("unexpected round trip ID")
+			return nil, false, errors.TraceNew("unexpected round trip ID")
 		}
 
 		// Store the response so it can be retrieved later.
 
 		r.response = sessionRoundTrip.Payload
-		return nil, nil
+		return nil, false, nil
 	}
 
 	// Continue the handshake. Since the first payload is sent to the
@@ -737,7 +786,7 @@ func (r *InitiatorRoundTrip) Next(
 
 	isEstablished, sendPacket, _, err := r.session.nextHandshakePacket(receivedPacket)
 	if err != nil {
-		return nil, errors.Trace(err)
+		return nil, false, errors.Trace(err)
 	}
 
 	if isEstablished {
@@ -752,7 +801,7 @@ func (r *InitiatorRoundTrip) Next(
 		r.initiatorSessions.setSession(r.responderPublicKey, r.session)
 	}
 
-	return sendPacket, nil
+	return sendPacket, isEstablished, nil
 }
 
 // TransportFailed marks any owned, not yet ready-to-share session as failed

+ 107 - 33
psiphon/common/inproxy/session_test.go

@@ -25,7 +25,6 @@ import (
 	"fmt"
 	"math"
 	"strings"
-	"sync/atomic"
 	"testing"
 	"time"
 
@@ -78,26 +77,32 @@ func runTestSessions() error {
 		return errors.Trace(err)
 	}
 
-	var preRoundTripCalls atomic.Int64
-	preRoundTrip := func(_ context.Context) {
-		preRoundTripCalls.Add(1)
-	}
-
 	initiatorSessions := NewInitiatorSessions(initiatorPrivateKey)
 
-	roundTripper := newTestSessionRoundTripper(responderSessions, &initiatorPublicKey)
-
 	waitToShareSession := true
 
+	sessionHandshakeTimeout := 100 * time.Millisecond
+	requestDelay := 1 * time.Microsecond
+	requestTimeout := 200 * time.Millisecond
+
+	roundTripper := newTestSessionRoundTripper(
+		responderSessions,
+		&initiatorPublicKey,
+		sessionHandshakeTimeout,
+		requestDelay,
+		requestTimeout)
+
 	request := roundTripper.MakeRequest()
 
 	response, err := initiatorSessions.RoundTrip(
 		context.Background(),
 		roundTripper,
-		preRoundTrip,
 		responderPublicKey,
 		responderRootObfuscationSecret,
 		waitToShareSession,
+		sessionHandshakeTimeout,
+		requestDelay,
+		requestTimeout,
 		request)
 	if err != nil {
 		return errors.Trace(err)
@@ -118,10 +123,12 @@ func runTestSessions() error {
 	response, err = initiatorSessions.RoundTrip(
 		context.Background(),
 		roundTripper,
-		preRoundTrip,
 		responderPublicKey,
 		responderRootObfuscationSecret,
 		waitToShareSession,
+		sessionHandshakeTimeout,
+		requestDelay,
+		requestTimeout,
 		request)
 	if err != nil {
 		return errors.Trace(err)
@@ -139,10 +146,12 @@ func runTestSessions() error {
 		_, err = initiatorSessions.RoundTrip(
 			context.Background(),
 			roundTripper,
-			preRoundTrip,
 			responderPublicKey,
 			responderRootObfuscationSecret,
 			waitToShareSession,
+			sessionHandshakeTimeout,
+			requestDelay,
+			requestTimeout,
 			roundTripper.MakeRequest())
 		if err != nil {
 			return errors.Trace(err)
@@ -156,10 +165,12 @@ func runTestSessions() error {
 	response, err = initiatorSessions.RoundTrip(
 		context.Background(),
 		roundTripper,
-		preRoundTrip,
 		responderPublicKey,
 		responderRootObfuscationSecret,
 		waitToShareSession,
+		sessionHandshakeTimeout,
+		requestDelay,
+		requestTimeout,
 		request)
 	if err != nil {
 		return errors.Trace(err)
@@ -176,7 +187,12 @@ func runTestSessions() error {
 
 	initiatorSessions = NewInitiatorSessions(initiatorPrivateKey)
 
-	failingRoundTripper := newTestSessionRoundTripper(nil, &initiatorPublicKey)
+	failingRoundTripper := newTestSessionRoundTripper(
+		nil,
+		&initiatorPublicKey,
+		sessionHandshakeTimeout,
+		requestDelay,
+		requestTimeout)
 
 	roundTripCount := 100
 
@@ -189,10 +205,12 @@ func runTestSessions() error {
 			_, err := initiatorSessions.RoundTrip(
 				context.Background(),
 				failingRoundTripper,
-				preRoundTrip,
 				responderPublicKey,
 				responderRootObfuscationSecret,
 				waitToShareSession,
+				sessionHandshakeTimeout,
+				requestDelay,
+				requestTimeout,
 				roundTripper.MakeRequest())
 			results <- err
 		}()
@@ -224,17 +242,24 @@ func runTestSessions() error {
 		return errors.Trace(err)
 	}
 
-	roundTripper = newTestSessionRoundTripper(responderSessions, &initiatorPublicKey)
+	roundTripper = newTestSessionRoundTripper(
+		responderSessions,
+		&initiatorPublicKey,
+		sessionHandshakeTimeout,
+		requestDelay,
+		requestTimeout)
 
 	request = roundTripper.MakeRequest()
 
 	response, err = initiatorSessions.RoundTrip(
 		context.Background(),
 		roundTripper,
-		preRoundTrip,
 		responderPublicKey,
 		responderRootObfuscationSecret,
 		waitToShareSession,
+		sessionHandshakeTimeout,
+		requestDelay,
+		requestTimeout,
 		request)
 	if err != nil {
 		return errors.Trace(err)
@@ -258,17 +283,24 @@ func runTestSessions() error {
 
 	responderSessions.SetKnownInitiatorPublicKeys([]SessionPublicKey{initiatorPublicKey})
 
-	roundTripper = newTestSessionRoundTripper(responderSessions, &initiatorPublicKey)
+	roundTripper = newTestSessionRoundTripper(
+		responderSessions,
+		&initiatorPublicKey,
+		sessionHandshakeTimeout,
+		requestDelay,
+		requestTimeout)
 
 	request = roundTripper.MakeRequest()
 
 	response, err = initiatorSessions.RoundTrip(
 		context.Background(),
 		roundTripper,
-		preRoundTrip,
 		responderPublicKey,
 		responderRootObfuscationSecret,
 		waitToShareSession,
+		sessionHandshakeTimeout,
+		requestDelay,
+		requestTimeout,
 		request)
 	if err != nil {
 		return errors.Trace(err)
@@ -318,10 +350,12 @@ func runTestSessions() error {
 	response, err = unknownInitiatorSessions.RoundTrip(
 		ctx,
 		roundTripper,
-		preRoundTrip,
 		responderPublicKey,
 		responderRootObfuscationSecret,
 		waitToShareSession,
+		sessionHandshakeTimeout,
+		requestDelay,
+		requestTimeout,
 		request)
 	if err == nil || !strings.HasSuffix(err.Error(), "unexpected initiator public key") {
 		return errors.Tracef("unexpected result: %v", err)
@@ -335,7 +369,12 @@ func runTestSessions() error {
 		return errors.Trace(err)
 	}
 
-	roundTripper = newTestSessionRoundTripper(responderSessions, nil)
+	roundTripper = newTestSessionRoundTripper(
+		responderSessions,
+		nil,
+		sessionHandshakeTimeout,
+		requestDelay,
+		requestTimeout)
 
 	clientCount := 10000
 	requestCount := 100
@@ -379,10 +418,12 @@ func runTestSessions() error {
 						response, err := initiatorSessions.RoundTrip(
 							context.Background(),
 							roundTripper,
-							preRoundTrip,
 							responderPublicKey,
 							responderRootObfuscationSecret,
 							waitToShareSession,
+							sessionHandshakeTimeout,
+							requestDelay,
+							requestTimeout,
 							request)
 						if err != nil {
 							requestResultChan <- errors.Trace(err)
@@ -418,25 +459,30 @@ func runTestSessions() error {
 		}
 	}
 
-	if preRoundTripCalls.Load() < int64(clientCount*requestCount) {
-		return errors.TraceNew("unexpected pre-round trip call count")
-	}
-
 	return nil
 }
 
 type testSessionRoundTripper struct {
-	sessions              *ResponderSessions
-	expectedPeerPublicKey *SessionPublicKey
+	sessions                        *ResponderSessions
+	expectedPeerPublicKey           *SessionPublicKey
+	expectedSessionHandshakeTimeout time.Duration
+	expectedRequestDelay            time.Duration
+	expectedRequestTimeout          time.Duration
 }
 
 func newTestSessionRoundTripper(
 	sessions *ResponderSessions,
-	expectedPeerPublicKey *SessionPublicKey) *testSessionRoundTripper {
+	expectedPeerPublicKey *SessionPublicKey,
+	expectedSessionHandshakeTimeout time.Duration,
+	expectedRequestDelay time.Duration,
+	expectedRequestTimeout time.Duration) *testSessionRoundTripper {
 
 	return &testSessionRoundTripper{
-		sessions:              sessions,
-		expectedPeerPublicKey: expectedPeerPublicKey,
+		sessions:                        sessions,
+		expectedPeerPublicKey:           expectedPeerPublicKey,
+		expectedSessionHandshakeTimeout: expectedSessionHandshakeTimeout,
+		expectedRequestDelay:            expectedRequestDelay,
+		expectedRequestTimeout:          expectedRequestTimeout,
 	}
 }
 
@@ -455,7 +501,8 @@ func (t *testSessionRoundTripper) ExpectedResponse(requestPayload []byte) []byte
 
 func (t *testSessionRoundTripper) RoundTrip(
 	ctx context.Context,
-	preRoundTrip PreRoundTripCallback,
+	roundTripDelay time.Duration,
+	roundTripTimeout time.Duration,
 	requestPayload []byte) ([]byte, error) {
 
 	err := ctx.Err()
@@ -467,10 +514,15 @@ func (t *testSessionRoundTripper) RoundTrip(
 		return nil, errors.TraceNew("closed")
 	}
 
-	if preRoundTrip != nil {
-		preRoundTrip(ctx)
+	if roundTripDelay > 0 {
+		common.SleepWithContext(ctx, roundTripDelay)
 	}
 
+	_, requestCancelFunc := context.WithTimeout(ctx, roundTripTimeout)
+	defer requestCancelFunc()
+
+	isRequestRoundTrip := false
+
 	unwrappedRequestHandler := func(initiatorID ID, unwrappedRequest []byte) ([]byte, error) {
 
 		if t.expectedPeerPublicKey != nil {
@@ -485,6 +537,8 @@ func (t *testSessionRoundTripper) RoundTrip(
 			}
 		}
 
+		isRequestRoundTrip = true
+
 		return t.ExpectedResponse(unwrappedRequest), nil
 	}
 
@@ -496,7 +550,27 @@ func (t *testSessionRoundTripper) RoundTrip(
 			fmt.Printf("HandlePacket returned packet and error: %v\n", err)
 			// Continue to relay packets
 		}
+	} else {
+
+		// Handshake round trips and request payload round trips should have the
+		// appropriate delays/timeouts.
+		if isRequestRoundTrip {
+			if roundTripDelay != t.expectedRequestDelay {
+				return nil, errors.TraceNew("unexpected round trip delay")
+			}
+			if roundTripTimeout != t.expectedRequestTimeout {
+				return nil, errors.TraceNew("unexpected round trip timeout")
+			}
+		} else {
+			if roundTripDelay != time.Duration(0) {
+				return nil, errors.TraceNew("unexpected round trip delay")
+			}
+			if roundTripTimeout != t.expectedSessionHandshakeTimeout {
+				return nil, errors.TraceNew("unexpected round trip timeout")
+			}
+		}
 	}
+
 	return responsePayload, nil
 }
 

+ 2 - 0
psiphon/common/parameters/parameters.go

@@ -394,6 +394,7 @@ const (
 	InproxyBrokerProxyAnnounceTimeout                  = "InproxyBrokerProxyAnnounceTimeout"
 	InproxyBrokerClientOfferTimeout                    = "InproxyBrokerClientOfferTimeout"
 	InproxyBrokerPendingServerRequestsTTL              = "InproxyBrokerPendingServerRequestsTTL"
+	InproxySessionHandshakeRoundTripTimeout            = "InproxySessionHandshakeRoundTripTimeout"
 	InproxyProxyAnnounceRequestTimeout                 = "InproxyProxyAnnounceRequestTimeout"
 	InproxyProxyAnnounceDelay                          = "InproxyProxyAnnounceDelay"
 	InproxyProxyAnnounceDelayJitter                    = "InproxyProxyAnnounceDelayJitter"
@@ -876,6 +877,7 @@ var defaultParameters = map[string]struct {
 	InproxyBrokerProxyAnnounceTimeout:                  {value: 2 * time.Minute, minimum: time.Duration(0), flags: serverSideOnly},
 	InproxyBrokerClientOfferTimeout:                    {value: 10 * time.Second, minimum: time.Duration(0), flags: serverSideOnly},
 	InproxyBrokerPendingServerRequestsTTL:              {value: 60 * time.Second, minimum: time.Duration(0), flags: serverSideOnly},
+	InproxySessionHandshakeRoundTripTimeout:            {value: 10 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyProxyAnnounceRequestTimeout:                 {value: 2*time.Minute + 10*time.Second, minimum: time.Duration(0)},
 	InproxyProxyAnnounceDelay:                          {value: 100 * time.Millisecond, minimum: time.Duration(0)},
 	InproxyProxyAnnounceDelayJitter:                    {value: 0.5, minimum: 0.0},

+ 33 - 8
psiphon/inproxy.go

@@ -205,6 +205,7 @@ type InproxyBrokerClientInstance struct {
 	roundTripper                  *InproxyBrokerRoundTripper
 	personalCompartmentIDs        []inproxy.ID
 	commonCompartmentIDs          []inproxy.ID
+	sessionHandshakeTimeout       time.Duration
 	announceRequestTimeout        time.Duration
 	announceDelay                 time.Duration
 	announceDelayJitter           float64
@@ -366,6 +367,7 @@ func NewInproxyBrokerClientInstance(
 		personalCompartmentIDs:      personalCompartmentIDs,
 		commonCompartmentIDs:        commonCompartmentIDs,
 
+		sessionHandshakeTimeout:       p.Duration(parameters.InproxySessionHandshakeRoundTripTimeout),
 		announceRequestTimeout:        p.Duration(parameters.InproxyProxyAnnounceRequestTimeout),
 		announceDelay:                 p.Duration(parameters.InproxyProxyAnnounceDelay),
 		announceDelayJitter:           p.Float(parameters.InproxyProxyAnnounceDelayJitter),
@@ -667,6 +669,11 @@ func (b *InproxyBrokerClientInstance) AnnounceRequestTimeout() time.Duration {
 	return b.announceRequestTimeout
 }
 
+// Implements the inproxy.BrokerDialCoordinator interface.
+func (b *InproxyBrokerClientInstance) SessionHandshakeRoundTripTimeout() time.Duration {
+	return b.sessionHandshakeTimeout
+}
+
 // Implements the inproxy.BrokerDialCoordinator interface.
 func (b *InproxyBrokerClientInstance) AnnounceDelay() time.Duration {
 	return b.announceDelay
@@ -1220,7 +1227,8 @@ func (rt *InproxyBrokerRoundTripper) Close() error {
 // response.
 func (rt *InproxyBrokerRoundTripper) RoundTrip(
 	ctx context.Context,
-	preRoundTrip inproxy.PreRoundTripCallback,
+	roundTripDelay time.Duration,
+	roundTripTimeout time.Duration,
 	requestPayload []byte) (_ []byte, retErr error) {
 
 	defer func() {
@@ -1237,11 +1245,28 @@ func (rt *InproxyBrokerRoundTripper) RoundTrip(
 	ctx, cancelFunc := common.MergeContextCancel(ctx, rt.runCtx)
 	defer cancelFunc()
 
-	// Invoke the pre-round trip callback. Currently, this callback is used to
-	// apply an announce request delay post-waitToShareSession, pre-network
-	// round trip, and cancelable by the above merged context.
-	if preRoundTrip != nil {
-		preRoundTrip(ctx)
+	// Apply any round trip delay. Currently, this is used to apply an
+	// announce request delay post-waitToShareSession, pre-network round
+	// trip, and cancelable by the above merged context.
+	if roundTripDelay > 0 {
+		common.SleepWithContext(ctx, roundTripDelay)
+	}
+
+	// Apply the round trip timeout after any delay is complete.
+	//
+	// This timeout includes any TLS handshake network round trips, as
+	// performed by the initial DialMeek and may be performed subsequently by
+	// net/http via MeekConn.RoundTrip. These extra round trips should be
+	// accounted for in the in the difference between client-side request
+	// timeouts, such as InproxyProxyAnswerRequestTimeout, and broker-side
+	// handler timeouts, such as InproxyBrokerProxyAnnounceTimeout, with the
+	// former allowing more time for network round trips.
+
+	requestCtx := ctx
+	if roundTripTimeout > 0 {
+		var requestCancelFunc context.CancelFunc
+		requestCtx, requestCancelFunc = context.WithTimeout(ctx, roundTripTimeout)
+		defer requestCancelFunc()
 	}
 
 	// The first RoundTrip caller will perform the DialMeek step, which
@@ -1268,7 +1293,7 @@ func (rt *InproxyBrokerRoundTripper) RoundTrip(
 		// DialMeek hasn't been called yet.
 
 		conn, err := DialMeek(
-			ctx,
+			requestCtx,
 			rt.brokerDialParams.meekConfig,
 			rt.brokerDialParams.dialConfig)
 
@@ -1321,7 +1346,7 @@ func (rt *InproxyBrokerRoundTripper) RoundTrip(
 		inproxy.BrokerEndPointName)
 
 	request, err := http.NewRequestWithContext(
-		ctx, "POST", url, bytes.NewBuffer(requestPayload))
+		requestCtx, "POST", url, bytes.NewBuffer(requestPayload))
 	if err != nil {
 		return nil, errors.Trace(err)
 	}