|
|
@@ -22,20 +22,23 @@ package server
|
|
|
import (
|
|
|
"bytes"
|
|
|
"context"
|
|
|
+ std_errors "errors"
|
|
|
"net"
|
|
|
+ "time"
|
|
|
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
|
|
|
+ "github.com/sirupsen/logrus"
|
|
|
)
|
|
|
|
|
|
// protocolDemux enables a single listener to support multiple protocols
|
|
|
// by demultiplexing each conn it accepts into the corresponding protocol
|
|
|
// handler.
|
|
|
type protocolDemux struct {
|
|
|
- ctx context.Context
|
|
|
- cancelFunc context.CancelFunc
|
|
|
- innerListener net.Listener
|
|
|
- classifiers []protocolClassifier
|
|
|
- accept chan struct{}
|
|
|
+ ctx context.Context
|
|
|
+ cancelFunc context.CancelFunc
|
|
|
+ innerListener net.Listener
|
|
|
+ classifiers []protocolClassifier
|
|
|
+ connClassificationTimeout time.Duration
|
|
|
|
|
|
conns []chan net.Conn
|
|
|
}
|
|
|
@@ -58,8 +61,13 @@ type protocolClassifier struct {
|
|
|
// newProtocolDemux returns a newly initialized ProtocolDemux and an
|
|
|
// array of protocol listeners. For each protocol classifier in classifiers
|
|
|
// there will be a corresponding protocol listener at the same index in the
|
|
|
-// array of returned protocol listeners.
|
|
|
-func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []protocolClassifier) (*protocolDemux, []protoListener) {
|
|
|
+// array of returned protocol listeners. If connClassificationTimeout is >0,
|
|
|
+// then any conn not classified in this amount of time will be closed.
|
|
|
+//
|
|
|
+// Limitation: the conn is also closed after reading maxBytesToMatch and
|
|
|
+// failing to find a match, which can be a fingerprint for a raw conn with no
|
|
|
+// preceding anti-probing measure, such as TLS passthrough.
|
|
|
+func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []protocolClassifier, connClassificationTimeout time.Duration) (*protocolDemux, []protoListener) {
|
|
|
|
|
|
ctx, cancelFunc := context.WithCancel(ctx)
|
|
|
|
|
|
@@ -69,12 +77,12 @@ func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []
|
|
|
}
|
|
|
|
|
|
p := protocolDemux{
|
|
|
- ctx: ctx,
|
|
|
- cancelFunc: cancelFunc,
|
|
|
- innerListener: listener,
|
|
|
- conns: conns,
|
|
|
- classifiers: classifiers,
|
|
|
- accept: make(chan struct{}, 1),
|
|
|
+ ctx: ctx,
|
|
|
+ cancelFunc: cancelFunc,
|
|
|
+ innerListener: listener,
|
|
|
+ conns: conns,
|
|
|
+ classifiers: classifiers,
|
|
|
+ connClassificationTimeout: connClassificationTimeout,
|
|
|
}
|
|
|
|
|
|
protoListeners := make([]protoListener, len(classifiers))
|
|
|
@@ -115,78 +123,124 @@ func (mux *protocolDemux) run() error {
|
|
|
|
|
|
for mux.ctx.Err() == nil {
|
|
|
|
|
|
- // Accept first conn immediately and then wait for downstream listeners
|
|
|
- // to request new conns.
|
|
|
+ // Accept new conn and spawn a goroutine where it is read until
|
|
|
+ // either:
|
|
|
+ // - It matches one of the configured protocols and is sent downstream
|
|
|
+ // to the corresponding protocol listener
|
|
|
+ // - It does not match any of the configured protocols, an error
|
|
|
+ // occurs, or mux.connClassificationTimeout elapses before the conn
|
|
|
+ // is classified and the conn is closed
|
|
|
+ // New conns are accepted, and classified, continuously even if the
|
|
|
+ // downstream consumers are not ready to process them, which could
|
|
|
+ // result in spawning many goroutines that become blocked until the
|
|
|
+	// downstream consumers manage to catch up. However, this scenario
|
|
|
+ // should be unlikely because the producer - accepting new conns - is
|
|
|
+ // bounded by network I/O and the consumer is not. Generally, the
|
|
|
+ // consumer continuously loops accepting new conns, from its
|
|
|
+ // corresponding protocol listener, and immediately spawns a goroutine
|
|
|
+ // to handle each new conn after it is accepted.
|
|
|
|
|
|
conn, err := mux.innerListener.Accept()
|
|
|
if err != nil {
|
|
|
if mux.ctx.Err() == nil {
|
|
|
log.WithTraceFields(LogFields{"error": err}).Debug("accept failed")
|
|
|
- // TODO: add backoff before continue?
|
|
|
}
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
- var acc bytes.Buffer
|
|
|
- b := make([]byte, readBufferSize)
|
|
|
+ type classifiedConnResult struct {
|
|
|
+ index int
|
|
|
+ acc bytes.Buffer
|
|
|
+ err error
|
|
|
+ errLogLevel logrus.Level
|
|
|
+ }
|
|
|
|
|
|
- for mux.ctx.Err() == nil {
|
|
|
+ resultChannel := make(chan *classifiedConnResult, 2)
|
|
|
|
|
|
- n, err := conn.Read(b)
|
|
|
- if err != nil {
|
|
|
- log.WithTraceFields(LogFields{"error": err}).Debug("read conn failed")
|
|
|
- break // conn will be closed
|
|
|
- }
|
|
|
+ var connClassifiedAfterFunc *time.Timer
|
|
|
|
|
|
- acc.Write(b[:n])
|
|
|
+ if mux.connClassificationTimeout > 0 {
|
|
|
+ connClassifiedAfterFunc = time.AfterFunc(mux.connClassificationTimeout, func() {
|
|
|
+ resultChannel <- &classifiedConnResult{
|
|
|
+ err: std_errors.New("conn classification timeout"),
|
|
|
+ errLogLevel: logrus.DebugLevel,
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
|
|
|
- for i, detector := range mux.classifiers {
|
|
|
+ go func() {
|
|
|
+ var acc bytes.Buffer
|
|
|
+ b := make([]byte, readBufferSize)
|
|
|
|
|
|
- if acc.Len() >= detector.minBytesToMatch {
|
|
|
+ for mux.ctx.Err() == nil {
|
|
|
|
|
|
- if detector.match(acc.Bytes()) {
|
|
|
+ n, err := conn.Read(b)
|
|
|
+ if err != nil {
|
|
|
+ resultChannel <- &classifiedConnResult{
|
|
|
+ err: errors.TraceMsg(err, "read conn failed"),
|
|
|
+ errLogLevel: logrus.DebugLevel,
|
|
|
+ }
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- // Found a match, replay buffered bytes in new conn
|
|
|
- // and downstream.
|
|
|
- go func() {
|
|
|
- bConn := newBufferedConn(conn, acc)
|
|
|
- select {
|
|
|
- case mux.conns[i] <- bConn:
|
|
|
- case <-mux.ctx.Done():
|
|
|
- bConn.Close()
|
|
|
- }
|
|
|
- }()
|
|
|
+ acc.Write(b[:n])
|
|
|
|
|
|
+ for i, classifier := range mux.classifiers {
|
|
|
+ if acc.Len() >= classifier.minBytesToMatch && classifier.match(acc.Bytes()) {
|
|
|
+ resultChannel <- &classifiedConnResult{
|
|
|
+ index: i,
|
|
|
+ acc: acc,
|
|
|
+ }
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ if maxBytesToMatch != 0 && acc.Len() > maxBytesToMatch {
|
|
|
+ // No match. Sample does not match any classifier and is
|
|
|
+ // longer than required by each.
|
|
|
+ resultChannel <- &classifiedConnResult{
|
|
|
+ err: std_errors.New("no classifier match for conn"),
|
|
|
+ errLogLevel: logrus.WarnLevel,
|
|
|
+ }
|
|
|
+ return
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- if maxBytesToMatch != 0 && acc.Len() > maxBytesToMatch {
|
|
|
+ resultChannel <- &classifiedConnResult{
|
|
|
+ err: mux.ctx.Err(),
|
|
|
+ errLogLevel: logrus.DebugLevel,
|
|
|
+ }
|
|
|
+ }()
|
|
|
|
|
|
- // No match. Sample does not match any detector and is
|
|
|
- // longer than required by each.
|
|
|
- log.WithTrace().Warning("no detector match for conn")
|
|
|
+ result := <-resultChannel
|
|
|
|
|
|
- break // conn will be closed
|
|
|
+ if connClassifiedAfterFunc != nil {
|
|
|
+ connClassifiedAfterFunc.Stop()
|
|
|
+ }
|
|
|
+
|
|
|
+ if result.err != nil {
|
|
|
+ log.WithTraceFields(LogFields{"error": result.err}).Log(result.errLogLevel, "conn classification failed")
|
|
|
+
|
|
|
+ err := conn.Close()
|
|
|
+ if err != nil {
|
|
|
+ log.WithTraceFields(LogFields{"error": err}).Debug("close failed")
|
|
|
}
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
- // cleanup conn
|
|
|
- err := conn.Close()
|
|
|
- if err != nil {
|
|
|
- log.WithTraceFields(LogFields{"error": err}).Debug("close conn failed")
|
|
|
+ // Found a match, replay buffered bytes in new conn and send
|
|
|
+ // downstream.
|
|
|
+ // TODO: subtract the time it took to classify the conn from the
|
|
|
+ // subsequent SSH handshake timeout (sshHandshakeTimeout).
|
|
|
+ bConn := newBufferedConn(conn, result.acc)
|
|
|
+ select {
|
|
|
+ case mux.conns[result.index] <- bConn:
|
|
|
+ case <-mux.ctx.Done():
|
|
|
+ bConn.Close()
|
|
|
}
|
|
|
}()
|
|
|
-
|
|
|
- // Wait for one of the downstream listeners to request another conn.
|
|
|
- select {
|
|
|
- case <-mux.accept:
|
|
|
- case <-mux.ctx.Done():
|
|
|
- return mux.ctx.Err()
|
|
|
- }
|
|
|
}
|
|
|
|
|
|
return mux.ctx.Err()
|
|
|
@@ -199,11 +253,6 @@ func (mux *protocolDemux) acceptForIndex(index int) (net.Conn, error) {
|
|
|
for mux.ctx.Err() == nil {
|
|
|
select {
|
|
|
case conn := <-mux.conns[index]:
|
|
|
- // trigger another accept
|
|
|
- select {
|
|
|
- case mux.accept <- struct{}{}:
|
|
|
- default: // don't block when a signal is already buffered
|
|
|
- }
|
|
|
return conn, nil
|
|
|
case <-mux.ctx.Done():
|
|
|
return nil, errors.Trace(mux.ctx.Err())
|