Browse Source

Merge pull request #651 from mirokuratczyk/master

Fix Accept not called again until conn classified
Rod Hynes 2 years ago
parent
commit
ae17ec8e3d

+ 1 - 1
go.mod

@@ -19,7 +19,7 @@ require (
 	github.com/Psiphon-Labs/bolt v0.0.0-20200624191537-23cedaef7ad7
 	github.com/Psiphon-Labs/goptlib v0.0.0-20200406165125-c0e32a7a3464
 	github.com/Psiphon-Labs/quic-go v0.0.0-20230626192210-73f29effc9da
-	github.com/Psiphon-Labs/tls-tris v0.0.0-20210713133851-676a693d51ad
+	github.com/Psiphon-Labs/tls-tris v0.0.0-20230821160547-c948ccd6c156
 	github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f
 	github.com/bifurcation/mint v0.0.0-20180306135233-198357931e61
 	github.com/cheekybits/genny v0.0.0-20170328200008-9127e812e1e9

+ 2 - 0
go.sum

@@ -18,6 +18,8 @@ github.com/Psiphon-Labs/quic-go v0.0.0-20230626192210-73f29effc9da h1:TI2+ExyFR3
 github.com/Psiphon-Labs/quic-go v0.0.0-20230626192210-73f29effc9da/go.mod h1:wTIxqsKVrEQIxVIIYOEHuscY+PM3h6Wz79u5aF60fo0=
 github.com/Psiphon-Labs/tls-tris v0.0.0-20210713133851-676a693d51ad h1:m6HS84+b5xDPLj7D/ya1CeixyaHOCZoMbBilJ48y+Ts=
 github.com/Psiphon-Labs/tls-tris v0.0.0-20210713133851-676a693d51ad/go.mod h1:v3y9GXFo9Sf2mO6auD2ExGG7oDgrK8TI7eb49ZnUxrE=
+github.com/Psiphon-Labs/tls-tris v0.0.0-20230821160547-c948ccd6c156 h1:TlKg/9XkSlo5AqSJRVkTKIkwy/JXrQD6ybK3PZuAOwE=
+github.com/Psiphon-Labs/tls-tris v0.0.0-20230821160547-c948ccd6c156/go.mod h1:v3y9GXFo9Sf2mO6auD2ExGG7oDgrK8TI7eb49ZnUxrE=
 github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 h1:w1UutsfOrms1J05zt7ISrnJIXKzwaspym5BTKGx93EI=
 github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412/go.mod h1:WPjqKcmVOxf0XSf3YxCJs6N6AOSrOx3obionmG7T0y0=
 github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=

+ 108 - 59
psiphon/server/demux.go

@@ -22,20 +22,23 @@ package server
 import (
 	"bytes"
 	"context"
+	std_errors "errors"
 	"net"
+	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/sirupsen/logrus"
 )
 
 // protocolDemux enables a single listener to support multiple protocols
 // by demultiplexing each conn it accepts into the corresponding protocol
 // handler.
 type protocolDemux struct {
-	ctx           context.Context
-	cancelFunc    context.CancelFunc
-	innerListener net.Listener
-	classifiers   []protocolClassifier
-	accept        chan struct{}
+	ctx                       context.Context
+	cancelFunc                context.CancelFunc
+	innerListener             net.Listener
+	classifiers               []protocolClassifier
+	connClassificationTimeout time.Duration
 
 	conns []chan net.Conn
 }
@@ -58,8 +61,13 @@ type protocolClassifier struct {
 // newProtocolDemux returns a newly initialized ProtocolDemux and an
 // array of protocol listeners. For each protocol classifier in classifiers
 // there will be a corresponding protocol listener at the same index in the
-// array of returned protocol listeners.
-func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []protocolClassifier) (*protocolDemux, []protoListener) {
+// array of returned protocol listeners. If connClassificationTimeout is >0,
+// then any conn not classified in this amount of time will be closed.
+//
+// Limitation: the conn is also closed after reading maxBytesToMatch and
+// failing to find a match, which can be a fingerprint for a raw conn with no
+// preceding anti-probing measure, such as TLS passthrough.
+func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []protocolClassifier, connClassificationTimeout time.Duration) (*protocolDemux, []protoListener) {
 
 	ctx, cancelFunc := context.WithCancel(ctx)
 
@@ -69,12 +77,12 @@ func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []
 	}
 
 	p := protocolDemux{
-		ctx:           ctx,
-		cancelFunc:    cancelFunc,
-		innerListener: listener,
-		conns:         conns,
-		classifiers:   classifiers,
-		accept:        make(chan struct{}, 1),
+		ctx:                       ctx,
+		cancelFunc:                cancelFunc,
+		innerListener:             listener,
+		conns:                     conns,
+		classifiers:               classifiers,
+		connClassificationTimeout: connClassificationTimeout,
 	}
 
 	protoListeners := make([]protoListener, len(classifiers))
@@ -115,78 +123,124 @@ func (mux *protocolDemux) run() error {
 
 	for mux.ctx.Err() == nil {
 
-		// Accept first conn immediately and then wait for downstream listeners
-		// to request new conns.
+		// Accept new conn and spawn a goroutine where it is read until
+		// either:
+		// - It matches one of the configured protocols and is sent downstream
+		//   to the corresponding protocol listener
+		// - It does not match any of the configured protocols, an error
+		//   occurs, or mux.connClassificationTimeout elapses before the conn
+		//   is classified and the conn is closed
+		// New conns are accepted, and classified, continuously even if the
+		// downstream consumers are not ready to process them, which could
+		// result in spawning many goroutines that become blocked until the
+		// downstream consumers manage to catch up. Although, this scenario
+		// should be unlikely because the producer - accepting new conns - is
+		// bounded by network I/O and the consumer is not. Generally, the
+		// consumer continuously loops accepting new conns, from its
+		// corresponding protocol listener, and immediately spawns a goroutine
+		// to handle each new conn after it is accepted.
 
 		conn, err := mux.innerListener.Accept()
 		if err != nil {
 			if mux.ctx.Err() == nil {
 				log.WithTraceFields(LogFields{"error": err}).Debug("accept failed")
-				// TODO: add backoff before continue?
 			}
 			continue
 		}
 
 		go func() {
 
-			var acc bytes.Buffer
-			b := make([]byte, readBufferSize)
+			type classifiedConnResult struct {
+				index       int
+				acc         bytes.Buffer
+				err         error
+				errLogLevel logrus.Level
+			}
 
-			for mux.ctx.Err() == nil {
+			resultChannel := make(chan *classifiedConnResult, 2)
 
-				n, err := conn.Read(b)
-				if err != nil {
-					log.WithTraceFields(LogFields{"error": err}).Debug("read conn failed")
-					break // conn will be closed
-				}
+			var connClassifiedAfterFunc *time.Timer
 
-				acc.Write(b[:n])
+			if mux.connClassificationTimeout > 0 {
+				connClassifiedAfterFunc = time.AfterFunc(mux.connClassificationTimeout, func() {
+					resultChannel <- &classifiedConnResult{
+						err:         std_errors.New("conn classification timeout"),
+						errLogLevel: logrus.DebugLevel,
+					}
+				})
+			}
 
-				for i, detector := range mux.classifiers {
+			go func() {
+				var acc bytes.Buffer
+				b := make([]byte, readBufferSize)
 
-					if acc.Len() >= detector.minBytesToMatch {
+				for mux.ctx.Err() == nil {
 
-						if detector.match(acc.Bytes()) {
+					n, err := conn.Read(b)
+					if err != nil {
+						resultChannel <- &classifiedConnResult{
+							err:         errors.TraceMsg(err, "read conn failed"),
+							errLogLevel: logrus.DebugLevel,
+						}
+						return
+					}
 
-							// Found a match, replay buffered bytes in new conn
-							// and downstream.
-							go func() {
-								bConn := newBufferedConn(conn, acc)
-								select {
-								case mux.conns[i] <- bConn:
-								case <-mux.ctx.Done():
-									bConn.Close()
-								}
-							}()
+					acc.Write(b[:n])
 
+					for i, classifier := range mux.classifiers {
+						if acc.Len() >= classifier.minBytesToMatch && classifier.match(acc.Bytes()) {
+							resultChannel <- &classifiedConnResult{
+								index: i,
+								acc:   acc,
+							}
 							return
 						}
 					}
+
+					if maxBytesToMatch != 0 && acc.Len() > maxBytesToMatch {
+						// No match. Sample does not match any classifier and is
+						// longer than required by each.
+						resultChannel <- &classifiedConnResult{
+							err:         std_errors.New("no classifier match for conn"),
+							errLogLevel: logrus.WarnLevel,
+						}
+						return
+					}
 				}
 
-				if maxBytesToMatch != 0 && acc.Len() > maxBytesToMatch {
+				resultChannel <- &classifiedConnResult{
+					err:         mux.ctx.Err(),
+					errLogLevel: logrus.DebugLevel,
+				}
+			}()
 
-					// No match. Sample does not match any detector and is
-					// longer than required by each.
-					log.WithTrace().Warning("no detector match for conn")
+			result := <-resultChannel
 
-					break // conn will be closed
+			if connClassifiedAfterFunc != nil {
+				connClassifiedAfterFunc.Stop()
+			}
+
+			if result.err != nil {
+				log.WithTraceFields(LogFields{"error": result.err}).Log(result.errLogLevel, "conn classification failed")
+
+				err := conn.Close()
+				if err != nil {
+					log.WithTraceFields(LogFields{"error": err}).Debug("close failed")
 				}
+				return
 			}
 
-			// cleanup conn
-			err := conn.Close()
-			if err != nil {
-				log.WithTraceFields(LogFields{"error": err}).Debug("close conn failed")
+			// Found a match, replay buffered bytes in new conn and send
+			// downstream.
+			// TODO: subtract the time it took to classify the conn from the
+			// subsequent SSH handshake timeout (sshHandshakeTimeout).
+			bConn := newBufferedConn(conn, result.acc)
+			select {
+			case mux.conns[result.index] <- bConn:
+			case <-mux.ctx.Done():
+				bConn.Close()
 			}
 		}()
-
-		// Wait for one of the downstream listeners to request another conn.
-		select {
-		case <-mux.accept:
-		case <-mux.ctx.Done():
-			return mux.ctx.Err()
-		}
 	}
 
 	return mux.ctx.Err()
@@ -199,11 +253,6 @@ func (mux *protocolDemux) acceptForIndex(index int) (net.Conn, error) {
 	for mux.ctx.Err() == nil {
 		select {
 		case conn := <-mux.conns[index]:
-			// trigger another accept
-			select {
-			case mux.accept <- struct{}{}:
-			default: // don't block when a signal is already buffered
-			}
 			return conn, nil
 		case <-mux.ctx.Done():
 			return nil, errors.Trace(mux.ctx.Err())

+ 1 - 1
psiphon/server/demux_test.go

@@ -59,7 +59,7 @@ func runProtocolDemuxTest(tt *protocolDemuxTest) error {
 		}
 	}()
 
-	mux, protoListeners := newProtocolDemux(context.Background(), l, tt.classifiers)
+	mux, protoListeners := newProtocolDemux(context.Background(), l, tt.classifiers, 0)
 
 	errs := make([]chan error, len(protoListeners))
 	for i := range errs {

+ 3 - 0
psiphon/server/tlsTunnel.go

@@ -43,6 +43,9 @@ type TLSTunnelServer struct {
 	obfuscatorSeedHistory  *obfuscator.SeedHistory
 }
 
+// ListenTLSTunnel returns the listener of a new TLSTunnelServer.
+// Note: the first Read or Write call on a connection returned by the listener
+// will trigger the underlying TLS handshake.
 func ListenTLSTunnel(
 	support *SupportServices,
 	listener net.Listener,

+ 1 - 1
psiphon/server/tunnelServer.go

@@ -601,7 +601,7 @@ func (sshServer *sshServer) runMeekTLSOSSHDemuxListener(sshListener *sshListener
 		return
 	}
 
-	mux, listeners := newProtocolDemux(context.Background(), listener, []protocolClassifier{meekClassifier, tlsClassifier})
+	mux, listeners := newProtocolDemux(context.Background(), listener, []protocolClassifier{meekClassifier, tlsClassifier}, sshServer.support.Config.sshHandshakeTimeout)
 
 	var wg sync.WaitGroup
 

+ 8 - 0
psiphon/tlsTunnelConn.go

@@ -171,3 +171,11 @@ func (conn *TLSTunnelConn) GetMetrics() common.LogFields {
 	}
 	return logFields
 }
+
+func (conn *TLSTunnelConn) IsClosed() bool {
+	closer, ok := conn.Conn.(common.Closer)
+	if !ok {
+		return false
+	}
+	return closer.IsClosed()
+}

+ 1 - 1
vendor/github.com/Psiphon-Labs/tls-tris/handshake_server.go

@@ -69,7 +69,7 @@ func (c *Conn) serverHandshake() error {
 	// [Psiphon]
 	// The ClientHello with the passthrough message is now available. Route the
 	// client to passthrough based on message inspection. This code assumes the
-	// client TCP conn has been wrapped with recordingConn, which has recorded
+	// client TCP conn has been wrapped with recorderConn, which has recorded
 	// all bytes sent by the client, which will be replayed, byte-for-byte, to
 	// the passthrough; as a result, passthrough clients will perform their TLS
 	// handshake with the passthrough target, receive its certificate, and in the

+ 1 - 1
vendor/modules.txt

@@ -39,7 +39,7 @@ github.com/Psiphon-Labs/quic-go/internal/utils/linkedlist
 github.com/Psiphon-Labs/quic-go/internal/wire
 github.com/Psiphon-Labs/quic-go/logging
 github.com/Psiphon-Labs/quic-go/quicvarint
-# github.com/Psiphon-Labs/tls-tris v0.0.0-20210713133851-676a693d51ad
+# github.com/Psiphon-Labs/tls-tris v0.0.0-20230821160547-c948ccd6c156
 ## explicit
 github.com/Psiphon-Labs/tls-tris
 github.com/Psiphon-Labs/tls-tris/cipherhw