|
|
@@ -110,39 +110,7 @@ func (l *ShadowsocksListener) Accept() (net.Conn, error) {
|
|
|
return nil, errors.Trace(err)
|
|
|
}
|
|
|
|
|
|
- salt, reader, err := l.readSalt(conn)
|
|
|
- if err != nil {
|
|
|
- return nil, errors.TraceMsg(err, "failed to read salt")
|
|
|
- }
|
|
|
-
|
|
|
- // TODO: code mostly copied from [1]; use NewShadowsocksStreamAuthenticator instead?
|
|
|
- //
|
|
|
- // [1] https://github.com/Jigsaw-Code/outline-ss-server/blob/fa651d3e87cc0a94104babb3ae85253471a22ebc/service/tcp.go#L138
|
|
|
-
|
|
|
- // Hardcode key ID because all clients use the same cipher per server,
|
|
|
- // which is fine because the underlying SSH connection protects the
|
|
|
- // confidentiality and integrity of client traffic between the client and
|
|
|
- // server.
|
|
|
- keyID := "1"
|
|
|
-
|
|
|
- isServerSalt := l.server.saltGenerator.IsServerSalt(salt)
|
|
|
-
|
|
|
- if isServerSalt || !l.server.replayCache.Add(keyID, salt) {
|
|
|
-
|
|
|
- go drainConn(conn)
|
|
|
-
|
|
|
- var err error
|
|
|
- if isServerSalt {
|
|
|
- err = errors.TraceNew("server replay detected")
|
|
|
- } else {
|
|
|
- err = errors.TraceNew("client replay detected")
|
|
|
- }
|
|
|
-
|
|
|
- l.server.irregularTunnelLogger(conn.RemoteAddr().String(), err, nil)
|
|
|
-
|
|
|
- return nil, err
|
|
|
- }
|
|
|
-
|
|
|
+ reader := NewSaltReader(conn, l.server)
|
|
|
ssr := shadowsocks.NewReader(reader, l.server.key)
|
|
|
ssw := shadowsocks.NewWriter(conn, l.server.key)
|
|
|
ssw.SetSaltGenerator(l.server.saltGenerator)
|
|
|
@@ -151,46 +119,6 @@ func (l *ShadowsocksListener) Accept() (net.Conn, error) {
|
|
|
return NewShadowsocksConn(ssClientConn), nil
|
|
|
}
|
|
|
|
|
|
-func drainConn(conn net.Conn) {
|
|
|
- _, _ = io.Copy(io.Discard, conn)
|
|
|
- conn.Close()
|
|
|
-}
|
|
|
-
|
|
|
-func (l *ShadowsocksListener) readSalt(conn net.Conn) ([]byte, io.Reader, error) {
|
|
|
-
|
|
|
- type result struct {
|
|
|
- salt []byte
|
|
|
- err error
|
|
|
- }
|
|
|
-
|
|
|
- resultChannel := make(chan result)
|
|
|
-
|
|
|
- go func() {
|
|
|
- saltSize := l.server.key.SaltSize()
|
|
|
- salt := make([]byte, saltSize)
|
|
|
- if n, err := io.ReadFull(conn, salt); err != nil {
|
|
|
- resultChannel <- result{
|
|
|
- err: fmt.Errorf("reading conn failed after %d bytes: %w", n, err),
|
|
|
- }
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- resultChannel <- result{
|
|
|
- salt: salt,
|
|
|
- }
|
|
|
- }()
|
|
|
-
|
|
|
- select {
|
|
|
- case result := <-resultChannel:
|
|
|
- if result.err != nil {
|
|
|
- return nil, nil, result.err
|
|
|
- }
|
|
|
- return result.salt, io.MultiReader(bytes.NewReader(result.salt), conn), nil
|
|
|
- case <-l.server.support.TunnelServer.shutdownBroadcast:
|
|
|
- return nil, nil, errors.TraceNew("shutdown broadcast")
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
// ShadowsocksConn implements the net.Conn and common.MetricsSource interfaces.
|
|
|
type ShadowsocksConn struct {
|
|
|
net.Conn
|
|
|
@@ -221,3 +149,84 @@ func (conn *ShadowsocksConn) GetMetrics() common.LogFields {
|
|
|
|
|
|
return logFields
|
|
|
}
|
|
|
+
|
|
|
+type saltReader struct {
|
|
|
+ net.Conn
|
|
|
+ server *ShadowsocksServer
|
|
|
+ reader io.Reader
|
|
|
+}
|
|
|
+
|
|
|
+func NewSaltReader(conn net.Conn, server *ShadowsocksServer) *saltReader {
|
|
|
+ return &saltReader{
|
|
|
+ Conn: conn,
|
|
|
+ server: server,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Note: it is assumed that the underlying transport, net.Conn, is a reliable
|
|
|
+// stream transport, i.e. TCP, therefore it is required that the caller stop
|
|
|
+// calling Read() on an instance of saltReader after an error is returned.
|
|
|
+func (conn *saltReader) Read(b []byte) (int, error) {
|
|
|
+
|
|
|
+ if conn.reader == nil {
|
|
|
+ err := conn.init()
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.Trace(err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return conn.reader.Read(b)
|
|
|
+}
|
|
|
+
|
|
|
+func (conn *saltReader) init() error {
|
|
|
+
|
|
|
+ // Note: code adapted from https://github.com/Jigsaw-Code/outline-ss-server/blob/fa651d3e87cc0a94104babb3ae85253471a22ebc/service/tcp.go#L119.
|
|
|
+
|
|
|
+ salt, reader, err := readSalt(conn.Conn, conn.server.key.SaltSize())
|
|
|
+ if err != nil {
|
|
|
+ return errors.TraceMsg(err, "failed to read salt")
|
|
|
+ }
|
|
|
+
|
|
|
+ conn.reader = reader
|
|
|
+
|
|
|
+ // Hardcode key ID because all clients use the same cipher per server,
|
|
|
+ // which is fine because the underlying SSH connection protects the
|
|
|
+ // confidentiality and integrity of client traffic between the client and
|
|
|
+ // server.
|
|
|
+ keyID := "1"
|
|
|
+
|
|
|
+ isServerSalt := conn.server.saltGenerator.IsServerSalt(salt)
|
|
|
+
|
|
|
+ if isServerSalt || !conn.server.replayCache.Add(keyID, salt) {
|
|
|
+
|
|
|
+ go drainConn(conn)
|
|
|
+
|
|
|
+ var err error
|
|
|
+ if isServerSalt {
|
|
|
+ err = errors.TraceNew("server replay detected")
|
|
|
+ } else {
|
|
|
+ err = errors.TraceNew("client replay detected")
|
|
|
+ }
|
|
|
+
|
|
|
+ conn.server.irregularTunnelLogger(conn.RemoteAddr().String(), err, nil)
|
|
|
+
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func readSalt(conn net.Conn, saltSize int) ([]byte, io.Reader, error) {
|
|
|
+
|
|
|
+ salt := make([]byte, saltSize)
|
|
|
+ if n, err := io.ReadFull(conn, salt); err != nil {
|
|
|
+ return nil, nil, fmt.Errorf("reading conn failed after %d bytes: %w", n, err)
|
|
|
+ }
|
|
|
+
|
|
|
+ return salt, io.MultiReader(bytes.NewReader(salt), conn), nil
|
|
|
+}
|
|
|
+
|
|
|
+func drainConn(conn net.Conn) {
|
|
|
+ _, _ = io.Copy(io.Discard, conn)
|
|
|
+ conn.Close()
|
|
|
+}
|