Просмотр исходного кода

Fix Accept not called again until conn classified

mirokuratczyk 2 лет назад
Родитель
Сommit
a7b46caaa0
2 измененных файлов с 34 добавлено и 29 удалено
  1. 26 29
      psiphon/server/demux.go
  2. 8 0
      psiphon/tlsTunnelConn.go

+ 26 - 29
psiphon/server/demux.go

@@ -35,7 +35,6 @@ type protocolDemux struct {
 	cancelFunc    context.CancelFunc
 	innerListener net.Listener
 	classifiers   []protocolClassifier
-	accept        chan struct{}
 
 	conns []chan net.Conn
 }
@@ -74,7 +73,6 @@ func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []
 		innerListener: listener,
 		conns:         conns,
 		classifiers:   classifiers,
-		accept:        make(chan struct{}, 1),
 	}
 
 	protoListeners := make([]protoListener, len(classifiers))
@@ -115,8 +113,21 @@ func (mux *protocolDemux) run() error {
 
 	for mux.ctx.Err() == nil {
 
-		// Accept first conn immediately and then wait for downstream listeners
-		// to request new conns.
+		// Accept new conn and spawn a goroutine where it is read until
+		// either:
+		// - It matches one of the configured protocols and is sent downstream
+		//   to the corresponding protocol listener
+		// - It does not match any of the configured protocols, or an error
+		//   occurs, and the conn is closed
+		// New conns are accepted, and classified, continuously even if the
+		// downstream consumers are not ready to process them, which could
+		// result in spawning many goroutines that become blocked until the
+		// downstream consumers manage to catch up. Although, this scenario
+		// should be unlikely because the producer - accepting new conns - is
+		// bounded by network I/O and the consumer is not. Generally, the
+		// consumer continuously loops accepting new conns, from its
+		// corresponding protocol listener, and immediately spawns a goroutine
+		// to handle each new conn after it is accepted.
 
 		conn, err := mux.innerListener.Accept()
 		if err != nil {
@@ -142,22 +153,20 @@ func (mux *protocolDemux) run() error {
 
 				acc.Write(b[:n])
 
-				for i, detector := range mux.classifiers {
+				for i, classifier := range mux.classifiers {
 
-					if acc.Len() >= detector.minBytesToMatch {
+					if acc.Len() >= classifier.minBytesToMatch {
 
-						if detector.match(acc.Bytes()) {
+						if classifier.match(acc.Bytes()) {
 
 							// Found a match, replay buffered bytes in new conn
 							// and downstream.
-							go func() {
-								bConn := newBufferedConn(conn, acc)
-								select {
-								case mux.conns[i] <- bConn:
-								case <-mux.ctx.Done():
-									bConn.Close()
-								}
-							}()
+							bConn := newBufferedConn(conn, acc)
+							select {
+							case mux.conns[i] <- bConn:
+							case <-mux.ctx.Done():
+								bConn.Close()
+							}
 
 							return
 						}
@@ -166,9 +175,9 @@ func (mux *protocolDemux) run() error {
 
 				if maxBytesToMatch != 0 && acc.Len() > maxBytesToMatch {
 
-					// No match. Sample does not match any detector and is
+					// No match. Sample does not match any classifier and is
 					// longer than required by each.
-					log.WithTrace().Warning("no detector match for conn")
+					log.WithTrace().Warning("no classifier match for conn")
 
 					break // conn will be closed
 				}
@@ -180,13 +189,6 @@ func (mux *protocolDemux) run() error {
 				log.WithTraceFields(LogFields{"error": err}).Debug("close conn failed")
 			}
 		}()
-
-		// Wait for one of the downstream listeners to request another conn.
-		select {
-		case <-mux.accept:
-		case <-mux.ctx.Done():
-			return mux.ctx.Err()
-		}
 	}
 
 	return mux.ctx.Err()
@@ -199,11 +201,6 @@ func (mux *protocolDemux) acceptForIndex(index int) (net.Conn, error) {
 	for mux.ctx.Err() == nil {
 		select {
 		case conn := <-mux.conns[index]:
-			// trigger another accept
-			select {
-			case mux.accept <- struct{}{}:
-			default: // don't block when a signal is already buffered
-			}
 			return conn, nil
 		case <-mux.ctx.Done():
 			return nil, errors.Trace(mux.ctx.Err())

+ 8 - 0
psiphon/tlsTunnelConn.go

@@ -171,3 +171,11 @@ func (conn *TLSTunnelConn) GetMetrics() common.LogFields {
 	}
 	return logFields
 }
+
+func (conn *TLSTunnelConn) IsClosed() bool {
+	closer, ok := conn.Conn.(common.Closer)
+	if !ok {
+		return false
+	}
+	return closer.IsClosed()
+}