Kaynağa Gözat

Always signal a connected request when the first tunnel is [re]established

Rod Hynes 10 yıl önce
ebeveyn
işleme
84452335ac
1 değiştirilmiş dosya ile 54 ekleme ve 37 silme
  1. 54 37
      psiphon/controller.go

+ 54 - 37
psiphon/controller.go

@@ -58,6 +58,7 @@ type Controller struct {
 	splitTunnelClassifier          *SplitTunnelClassifier
 	signalFetchRemoteServerList    chan struct{}
 	impairedProtocolClassification map[string]int
+	signalReportConnected          chan struct{}
 }
 
 // NewController initializes a new controller.
@@ -97,20 +98,19 @@ func NewController(config *Config) (controller *Controller, err error) {
 		runWaitGroup:           new(sync.WaitGroup),
 		// establishedTunnels and failedTunnels buffer sizes are large enough to
 		// receive full pools of tunnels without blocking. Senders should not block.
-		establishedTunnels:       make(chan *Tunnel, config.TunnelPoolSize),
-		failedTunnels:            make(chan *Tunnel, config.TunnelPoolSize),
-		tunnels:                  make([]*Tunnel, 0),
-		establishedOnce:          false,
-		startedConnectedReporter: false,
-		startedUpgradeDownloader: false,
-		isEstablishing:           false,
-		establishPendingConns:    new(Conns),
-		untunneledPendingConns:   untunneledPendingConns,
-		untunneledDialConfig:     untunneledDialConfig,
-		// A buffer allows at least one signal to be sent even when the receiver is
-		// not listening. Senders should not block.
-		signalFetchRemoteServerList:    make(chan struct{}, 1),
+		establishedTunnels:             make(chan *Tunnel, config.TunnelPoolSize),
+		failedTunnels:                  make(chan *Tunnel, config.TunnelPoolSize),
+		tunnels:                        make([]*Tunnel, 0),
+		establishedOnce:                false,
+		startedConnectedReporter:       false,
+		startedUpgradeDownloader:       false,
+		isEstablishing:                 false,
+		establishPendingConns:          new(Conns),
+		untunneledPendingConns:         untunneledPendingConns,
+		untunneledDialConfig:           untunneledDialConfig,
+		signalFetchRemoteServerList:    make(chan struct{}),
 		impairedProtocolClassification: make(map[string]int),
+		signalReportConnected:          make(chan struct{}),
 	}
 
 	controller.splitTunnelClassifier = NewSplitTunnelClassifier(config, controller)
@@ -274,7 +274,9 @@ func (controller *Controller) establishTunnelWatcher() {
 // comment in DoConnectedRequest for a description of the request mechanism.
 // To ensure we don't over- or under-count unique users, only one connected
 // request is made across all simultaneous multi-tunnels; and the connected
-// request is repeated periodically.
+// request is repeated periodically for very long-lived tunnels.
+// The signalReportConnected mechanism is used to trigger another connected
+// request immediately after a reconnect.
 func (controller *Controller) connectedReporter() {
 	defer controller.runWaitGroup.Done()
 loop:
@@ -302,8 +304,10 @@ loop:
 		}
 		timeout := time.After(duration)
 		select {
+		case <-controller.signalReportConnected:
 		case <-timeout:
 			// Make another connected request
+
 		case <-controller.shutdownBroadcast:
 			break loop
 		}
@@ -312,7 +316,7 @@ loop:
 	NoticeInfo("exiting connected reporter")
 }
 
-func (controller *Controller) startConnectedReporter() {
+func (controller *Controller) startOrSignalConnectedReporter() {
 	// session is nil when DisableApi is set
 	if controller.config.DisableApi {
 		return
@@ -324,6 +328,11 @@ func (controller *Controller) startConnectedReporter() {
 		controller.startedConnectedReporter = true
 		controller.runWaitGroup.Add(1)
 		go controller.connectedReporter()
+	} else {
+		select {
+		case controller.signalReportConnected <- *new(struct{}):
+		default:
+		}
 	}
 }
 
@@ -439,16 +448,39 @@ loop:
 		// !TODO! design issue: might not be enough server entries with region/caps to ever fill tunnel slots
 		// solution(?) target MIN(CountServerEntries(region, protocol), TunnelPoolSize)
 		case establishedTunnel := <-controller.establishedTunnels:
-			if controller.registerTunnel(establishedTunnel) {
+			tunnelCount, registered := controller.registerTunnel(establishedTunnel)
+			if registered {
 				NoticeActiveTunnel(establishedTunnel.serverEntry.IpAddress, establishedTunnel.protocol)
+
+				if tunnelCount == 1 {
+
+					// The split tunnel classifier is started once the first tunnel is
+					// established. This first tunnel is passed in to be used to make
+					// the routes data request.
+					// A long-running controller may run while the host device is present
+					// in different regions. In this case, we want the split tunnel logic
+					// to switch to routes for new regions and not classify traffic based
+					// on routes installed for older regions.
+					// We assume that when regions change, the host network will also
+					// change, and so all tunnels will fail and be re-established. Under
+					// that assumption, the classifier will be re-Start()-ed here when
+					// the region has changed.
+					controller.splitTunnelClassifier.Start(establishedTunnel)
+
+					// Signal a connected request on each 1st tunnel establishment. For
+					// multi-tunnels, the session is connected as long as at least one
+					// tunnel is established.
+					controller.startOrSignalConnectedReporter()
+
+					controller.startClientUpgradeDownloader(establishedTunnel.session)
+				}
+
 			} else {
 				controller.discardTunnel(establishedTunnel)
 			}
 			if controller.isFullyEstablished() {
 				controller.stopEstablishing()
 			}
-			controller.startConnectedReporter()
-			controller.startClientUpgradeDownloader(establishedTunnel.session)
 
 		case <-controller.shutdownBroadcast:
 			break loop
@@ -546,40 +578,25 @@ func (controller *Controller) discardTunnel(tunnel *Tunnel) {
 // registerTunnel adds the connected tunnel to the pool of active tunnels
 // which are candidates for port forwarding. Returns true if the pool has an
 // empty slot and false if the pool is full (caller should discard the tunnel).
-func (controller *Controller) registerTunnel(tunnel *Tunnel) bool {
+func (controller *Controller) registerTunnel(tunnel *Tunnel) (int, bool) {
 	controller.tunnelMutex.Lock()
 	defer controller.tunnelMutex.Unlock()
 	if len(controller.tunnels) >= controller.config.TunnelPoolSize {
-		return false
+		return len(controller.tunnels), false
 	}
 	// Perform a final check just in case we've established
 	// a duplicate connection.
 	for _, activeTunnel := range controller.tunnels {
 		if activeTunnel.serverEntry.IpAddress == tunnel.serverEntry.IpAddress {
 			NoticeAlert("duplicate tunnel: %s", tunnel.serverEntry.IpAddress)
-			return false
+			return len(controller.tunnels), false
 		}
 	}
 	controller.establishedOnce = true
 	controller.tunnels = append(controller.tunnels, tunnel)
 	NoticeTunnels(len(controller.tunnels))
 
-	// The split tunnel classifier is started once the first tunnel is
-	// established. This first tunnel is passed in to be used to make
-	// the routes data request.
-	// A long-running controller may run while the host device is present
-	// in different regions. In this case, we want the split tunnel logic
-	// to switch to routes for new regions and not classify traffic based
-	// on routes installed for older regions.
-	// We assume that when regions change, the host network will also
-	// change, and so all tunnels will fail and be re-established. Under
-	// that assumption, the classifier will be re-Start()-ed here when
-	// the region has changed.
-	if len(controller.tunnels) == 1 {
-		controller.splitTunnelClassifier.Start(tunnel)
-	}
-
-	return true
+	return len(controller.tunnels), true
 }
 
 // hasEstablishedOnce indicates if at least one active tunnel has