Просмотр исходного кода

Add conn classification timeout

mirokuratczyk 2 лет назад
Родитель
Сommit
5a2d0a0a10
3 измененных файлов с 97 добавлено и 45 удалено
  1. 95 43
      psiphon/server/demux.go
  2. 1 1
      psiphon/server/demux_test.go
  3. 1 1
      psiphon/server/tunnelServer.go

+ 95 - 43
psiphon/server/demux.go

@@ -22,19 +22,23 @@ package server
 import (
 	"bytes"
 	"context"
+	std_errors "errors"
 	"net"
+	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/sirupsen/logrus"
 )
 
 // protocolDemux enables a single listener to support multiple protocols
 // by demultiplexing each conn it accepts into the corresponding protocol
 // handler.
 type protocolDemux struct {
-	ctx           context.Context
-	cancelFunc    context.CancelFunc
-	innerListener net.Listener
-	classifiers   []protocolClassifier
+	ctx                       context.Context
+	cancelFunc                context.CancelFunc
+	innerListener             net.Listener
+	classifiers               []protocolClassifier
+	connClassificationTimeout time.Duration
 
 	conns []chan net.Conn
 }
@@ -57,8 +61,13 @@ type protocolClassifier struct {
 // newProtocolDemux returns a newly initialized ProtocolDemux and an
 // array of protocol listeners. For each protocol classifier in classifiers
 // there will be a corresponding protocol listener at the same index in the
-// array of returned protocol listeners.
-func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []protocolClassifier) (*protocolDemux, []protoListener) {
+// array of returned protocol listeners. If connClassificationTimeout is >0,
+// then any conn not classified in this amount of time will be closed.
+//
+// Limitation: the conn is also closed after reading maxBytesToMatch and
+// failing to find a match, which can be a fingerprint for a raw conn with no
+// preceding anti-probing measure, such as TLS passthrough.
+func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []protocolClassifier, connClassificationTimeout time.Duration) (*protocolDemux, []protoListener) {
 
 	ctx, cancelFunc := context.WithCancel(ctx)
 
@@ -68,11 +77,12 @@ func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []
 	}
 
 	p := protocolDemux{
-		ctx:           ctx,
-		cancelFunc:    cancelFunc,
-		innerListener: listener,
-		conns:         conns,
-		classifiers:   classifiers,
+		ctx:                       ctx,
+		cancelFunc:                cancelFunc,
+		innerListener:             listener,
+		conns:                     conns,
+		classifiers:               classifiers,
+		connClassificationTimeout: connClassificationTimeout,
 	}
 
 	protoListeners := make([]protoListener, len(classifiers))
@@ -117,8 +127,9 @@ func (mux *protocolDemux) run() error {
 		// either:
 		// - It matches one of the configured protocols and is sent downstream
 		//   to the corresponding protocol listener
-		// - It does not match any of the configured protocols, or an error
-		//   occurs, and the conn is closed
+		// - It does not match any of the configured protocols, an error
+		//   occurs, or mux.connClassificationTimeout elapses before the conn
+		//   is classified and the conn is closed
 		// New conns are accepted, and classified, continuously even if the
 		// downstream consumers are not ready to process them, which could
 		// result in spawning many goroutines that become blocked until the
@@ -133,60 +144,101 @@ func (mux *protocolDemux) run() error {
 		if err != nil {
 			if mux.ctx.Err() == nil {
 				log.WithTraceFields(LogFields{"error": err}).Debug("accept failed")
-				// TODO: add backoff before continue?
 			}
 			continue
 		}
 
 		go func() {
 
-			var acc bytes.Buffer
-			b := make([]byte, readBufferSize)
+			type classifiedConnResult struct {
+				index       int
+				acc         bytes.Buffer
+				err         error
+				errLogLevel logrus.Level
+			}
 
-			for mux.ctx.Err() == nil {
+			resultChannel := make(chan *classifiedConnResult, 2)
 
-				n, err := conn.Read(b)
-				if err != nil {
-					log.WithTraceFields(LogFields{"error": err}).Debug("read conn failed")
-					break // conn will be closed
-				}
+			var connClassifiedAfterFunc *time.Timer
 
-				acc.Write(b[:n])
+			if mux.connClassificationTimeout > 0 {
+				connClassifiedAfterFunc = time.AfterFunc(mux.connClassificationTimeout, func() {
+					resultChannel <- &classifiedConnResult{
+						err:         std_errors.New("conn classification timeout"),
+						errLogLevel: logrus.DebugLevel,
+					}
+				})
+			}
 
-				for i, classifier := range mux.classifiers {
+			go func() {
+				var acc bytes.Buffer
+				b := make([]byte, readBufferSize)
 
-					if acc.Len() >= classifier.minBytesToMatch {
+				for mux.ctx.Err() == nil {
 
-						if classifier.match(acc.Bytes()) {
+					n, err := conn.Read(b)
+					if err != nil {
+						resultChannel <- &classifiedConnResult{
+							err:         errors.TraceMsg(err, "read conn failed"),
+							errLogLevel: logrus.DebugLevel,
+						}
+						return
+					}
 
-							// Found a match, replay buffered bytes in new conn
-							// and downstream.
-							bConn := newBufferedConn(conn, acc)
-							select {
-							case mux.conns[i] <- bConn:
-							case <-mux.ctx.Done():
-								bConn.Close()
-							}
+					acc.Write(b[:n])
 
+					for i, classifier := range mux.classifiers {
+						if acc.Len() >= classifier.minBytesToMatch && classifier.match(acc.Bytes()) {
+							resultChannel <- &classifiedConnResult{
+								index: i,
+								acc:   acc,
+							}
 							return
 						}
 					}
+
+					if maxBytesToMatch != 0 && acc.Len() > maxBytesToMatch {
+						// No match. Sample does not match any classifier and is
+						// longer than required by each.
+						resultChannel <- &classifiedConnResult{
+							err:         std_errors.New("no classifier match for conn"),
+							errLogLevel: logrus.WarnLevel,
+						}
+						return
+					}
 				}
 
-				if maxBytesToMatch != 0 && acc.Len() > maxBytesToMatch {
+				resultChannel <- &classifiedConnResult{
+					err:         mux.ctx.Err(),
+					errLogLevel: logrus.DebugLevel,
+				}
+			}()
 
-					// No match. Sample does not match any classifier and is
-					// longer than required by each.
-					log.WithTrace().Warning("no classifier match for conn")
+			result := <-resultChannel
 
-					break // conn will be closed
+			if connClassifiedAfterFunc != nil {
+				connClassifiedAfterFunc.Stop()
+			}
+
+			if result.err != nil {
+				log.WithTraceFields(LogFields{"error": result.err}).Log(result.errLogLevel, "conn classification failed")
+
+				err := conn.Close()
+				if err != nil {
+					log.WithTraceFields(LogFields{"error": err}).Debug("close failed")
 				}
+				return
 			}
 
-			// cleanup conn
-			err := conn.Close()
-			if err != nil {
-				log.WithTraceFields(LogFields{"error": err}).Debug("close conn failed")
+			// Found a match, replay buffered bytes in new conn and send
+			// downstream.
+			// TODO: subtract the time it took to classify the conn from the
+			// subsequent SSH handshake timeout (sshHandshakeTimeout).
+			bConn := newBufferedConn(conn, result.acc)
+			select {
+			case mux.conns[result.index] <- bConn:
+			case <-mux.ctx.Done():
+				bConn.Close()
 			}
 		}()
 	}

+ 1 - 1
psiphon/server/demux_test.go

@@ -59,7 +59,7 @@ func runProtocolDemuxTest(tt *protocolDemuxTest) error {
 		}
 	}()
 
-	mux, protoListeners := newProtocolDemux(context.Background(), l, tt.classifiers)
+	mux, protoListeners := newProtocolDemux(context.Background(), l, tt.classifiers, 0)
 
 	errs := make([]chan error, len(protoListeners))
 	for i := range errs {

+ 1 - 1
psiphon/server/tunnelServer.go

@@ -601,7 +601,7 @@ func (sshServer *sshServer) runMeekTLSOSSHDemuxListener(sshListener *sshListener
 		return
 	}
 
-	mux, listeners := newProtocolDemux(context.Background(), listener, []protocolClassifier{meekClassifier, tlsClassifier})
+	mux, listeners := newProtocolDemux(context.Background(), listener, []protocolClassifier{meekClassifier, tlsClassifier}, sshServer.support.Config.sshHandshakeTimeout)
 
 	var wg sync.WaitGroup