Browse Source

Non-blocking Accept

Miro 1 year ago
parent
commit
c26975ac53
1 changed files with 82 additions and 73 deletions
  1. 82 73
      psiphon/server/shadowsocks.go

+ 82 - 73
psiphon/server/shadowsocks.go

@@ -110,39 +110,7 @@ func (l *ShadowsocksListener) Accept() (net.Conn, error) {
 		return nil, errors.Trace(err)
 	}
 
-	salt, reader, err := l.readSalt(conn)
-	if err != nil {
-		return nil, errors.TraceMsg(err, "failed to read salt")
-	}
-
-	// TODO: code mostly copied from [1]; use NewShadowsocksStreamAuthenticator instead?
-	//
-	// [1] https://github.com/Jigsaw-Code/outline-ss-server/blob/fa651d3e87cc0a94104babb3ae85253471a22ebc/service/tcp.go#L138
-
-	// Hardcode key ID because all clients use the same cipher per server,
-	// which is fine because the underlying SSH connection protects the
-	// confidentiality and integrity of client traffic between the client and
-	// server.
-	keyID := "1"
-
-	isServerSalt := l.server.saltGenerator.IsServerSalt(salt)
-
-	if isServerSalt || !l.server.replayCache.Add(keyID, salt) {
-
-		go drainConn(conn)
-
-		var err error
-		if isServerSalt {
-			err = errors.TraceNew("server replay detected")
-		} else {
-			err = errors.TraceNew("client replay detected")
-		}
-
-		l.server.irregularTunnelLogger(conn.RemoteAddr().String(), err, nil)
-
-		return nil, err
-	}
-
+	reader := NewSaltReader(conn, l.server)
 	ssr := shadowsocks.NewReader(reader, l.server.key)
 	ssw := shadowsocks.NewWriter(conn, l.server.key)
 	ssw.SetSaltGenerator(l.server.saltGenerator)
@@ -151,46 +119,6 @@ func (l *ShadowsocksListener) Accept() (net.Conn, error) {
 	return NewShadowsocksConn(ssClientConn), nil
 }
 
-func drainConn(conn net.Conn) {
-	_, _ = io.Copy(io.Discard, conn)
-	conn.Close()
-}
-
-func (l *ShadowsocksListener) readSalt(conn net.Conn) ([]byte, io.Reader, error) {
-
-	type result struct {
-		salt []byte
-		err  error
-	}
-
-	resultChannel := make(chan result)
-
-	go func() {
-		saltSize := l.server.key.SaltSize()
-		salt := make([]byte, saltSize)
-		if n, err := io.ReadFull(conn, salt); err != nil {
-			resultChannel <- result{
-				err: fmt.Errorf("reading conn failed after %d bytes: %w", n, err),
-			}
-			return
-		}
-
-		resultChannel <- result{
-			salt: salt,
-		}
-	}()
-
-	select {
-	case result := <-resultChannel:
-		if result.err != nil {
-			return nil, nil, result.err
-		}
-		return result.salt, io.MultiReader(bytes.NewReader(result.salt), conn), nil
-	case <-l.server.support.TunnelServer.shutdownBroadcast:
-		return nil, nil, errors.TraceNew("shutdown broadcast")
-	}
-}
-
 // ShadowsocksConn implements the net.Conn and common.MetricsSource interfaces.
 type ShadowsocksConn struct {
 	net.Conn
@@ -221,3 +149,84 @@ func (conn *ShadowsocksConn) GetMetrics() common.LogFields {
 
 	return logFields
 }
+
+type saltReader struct {
+	net.Conn
+	server *ShadowsocksServer
+	reader io.Reader
+}
+
+func NewSaltReader(conn net.Conn, server *ShadowsocksServer) *saltReader {
+	return &saltReader{
+		Conn:   conn,
+		server: server,
+	}
+}
+
+// Note: it is assumed that the underlying transport, net.Conn, is a reliable
+// stream transport, i.e. TCP, therefore it is required that the caller stop
+// calling Read() on an instance of saltReader after an error is returned.
+func (conn *saltReader) Read(b []byte) (int, error) {
+
+	if conn.reader == nil {
+		err := conn.init()
+		if err != nil {
+			return 0, errors.Trace(err)
+		}
+	}
+
+	return conn.reader.Read(b)
+}
+
+func (conn *saltReader) init() error {
+
+	// Note: code adapted from https://github.com/Jigsaw-Code/outline-ss-server/blob/fa651d3e87cc0a94104babb3ae85253471a22ebc/service/tcp.go#L119.
+
+	salt, reader, err := readSalt(conn.Conn, conn.server.key.SaltSize())
+	if err != nil {
+		return errors.TraceMsg(err, "failed to read salt")
+	}
+
+	conn.reader = reader
+
+	// Hardcode key ID because all clients use the same cipher per server,
+	// which is fine because the underlying SSH connection protects the
+	// confidentiality and integrity of client traffic between the client and
+	// server.
+	keyID := "1"
+
+	isServerSalt := conn.server.saltGenerator.IsServerSalt(salt)
+
+	if isServerSalt || !conn.server.replayCache.Add(keyID, salt) {
+
+		go drainConn(conn)
+
+		var err error
+		if isServerSalt {
+			err = errors.TraceNew("server replay detected")
+		} else {
+			err = errors.TraceNew("client replay detected")
+		}
+
+		conn.server.irregularTunnelLogger(conn.RemoteAddr().String(), err, nil)
+
+		return err
+	}
+
+	return nil
+}
+
+func readSalt(conn net.Conn, saltSize int) ([]byte, io.Reader, error) {
+
+	salt := make([]byte, saltSize)
+	if n, err := io.ReadFull(conn, salt); err != nil {
+		return nil, nil, fmt.Errorf("reading conn failed after %d bytes: %w", n, err)
+	}
+
+	return salt, io.MultiReader(bytes.NewReader(salt), conn), nil
+}
+
+func drainConn(conn net.Conn) {
+	_, _ = io.Copy(io.Discard, conn)
+	conn.Close()
+}