| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395 |
- // Copyright 2018 Jigsaw Operations LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // https://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package service
- import (
- "bytes"
- "container/list"
- "context"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "net"
- "net/netip"
- "sync"
- "time"
- "github.com/Jigsaw-Code/outline-sdk/transport"
- "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
- "github.com/shadowsocks/go-shadowsocks2/socks"
- onet "github.com/Jigsaw-Code/outline-ss-server/net"
- "github.com/Jigsaw-Code/outline-ss-server/service/metrics"
- )
// TCPConnMetrics is used to report metrics on TCP connections.
type TCPConnMetrics interface {
	// AddAuthenticated is called with the ID returned by the authenticator
	// once a client connection has been authenticated.
	AddAuthenticated(accessKey string)
	// AddClosed is called when the handling of a connection is done. status is
	// "OK" or the status of the connection error, data holds the traffic
	// counters collected for the connection and duration is how long the
	// connection was handled.
	AddClosed(status string, data metrics.ProxyMetrics, duration time.Duration)
	// AddProbe is called after a connection that failed authentication has been
	// drained. status is the authentication error status, drainResult describes
	// how the drain ended and clientProxyBytes is the client-to-proxy byte counter.
	AddProbe(status, drainResult string, clientProxyBytes int64)
}
// remoteIP returns the IP address of the remote end of conn.
//
// A *net.TCPAddr is converted directly; any other address type is parsed from
// its "ip:port" string form. The zero netip.Addr is returned when conn reports
// no remote address or when the address cannot be parsed.
func remoteIP(conn net.Conn) netip.Addr {
	switch remote := conn.RemoteAddr().(type) {
	case nil:
		return netip.Addr{}
	case *net.TCPAddr:
		return remote.AddrPort().Addr()
	default:
		parsed, err := netip.ParseAddrPort(remote.String())
		if err != nil {
			return netip.Addr{}
		}
		return parsed.Addr()
	}
}
// debugTCP logs a debug-level message during TCP access key searches.
//
// The message is template prefixed with "TCP: ". The record carries cipherID
// under the "ID" key plus the extra attribute attr. Nothing is built or logged
// when debug logging is disabled on l.
func debugTCP(l *slog.Logger, template string, cipherID string, attr slog.Attr) {
	// Never hand a nil Context to the logger: handlers receive it as-is from
	// callers that do not normalize it.
	ctx := context.Background()
	// This is an optimization to reduce unnecessary allocations due to an interaction
	// between Go's inlining/escape analysis and varargs functions like slog.Debug.
	if !l.Enabled(ctx, slog.LevelDebug) {
		return
	}
	l.LogAttrs(ctx, slog.LevelDebug, "TCP: "+template, slog.String("ID", cipherID), attr)
}
// bytesForKeyFinding is the number of bytes to read for finding the AccessKey.
// It must satisfy provided >= bytesForKeyFinding >= required for every cipher in the list, where:
//   - provided = saltSize + 2 + 2 * cipher.TagSize, the minimum number of bytes we will see in a valid connection;
//   - required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection.
const bytesForKeyFinding = 50
- func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList, l *slog.Logger) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
- // We snapshot the list because it may be modified while we use it.
- ciphers := cipherList.SnapshotForClientIP(clientIP)
- firstBytes := make([]byte, bytesForKeyFinding)
- if n, err := io.ReadFull(clientReader, firstBytes); err != nil {
- return nil, clientReader, nil, 0, fmt.Errorf("reading header failed after %d bytes: %w", n, err)
- }
- findStartTime := time.Now()
- entry, elt := findEntry(firstBytes, ciphers, l)
- timeToCipher := time.Since(findStartTime)
- if entry == nil {
- // TODO: Ban and log client IPs with too many failures too quick to protect against DoS.
- return nil, clientReader, nil, timeToCipher, fmt.Errorf("could not find valid TCP cipher")
- }
- // Move the active cipher to the front, so that the search is quicker next time.
- cipherList.MarkUsedByClientIP(elt, clientIP)
- salt := firstBytes[:entry.CryptoKey.SaltSize()]
- return entry, io.MultiReader(bytes.NewReader(firstBytes), clientReader), salt, timeToCipher, nil
- }
- // Implements a trial decryption search. This assumes that all ciphers are AEAD.
- func findEntry(firstBytes []byte, ciphers []*list.Element, l *slog.Logger) (*CipherEntry, *list.Element) {
- // To hold the decrypted chunk length.
- chunkLenBuf := [2]byte{}
- for ci, elt := range ciphers {
- entry := elt.Value.(*CipherEntry)
- cryptoKey := entry.CryptoKey
- _, err := shadowsocks.Unpack(chunkLenBuf[:0], firstBytes[:cryptoKey.SaltSize()+2+cryptoKey.TagSize()], cryptoKey)
- if err != nil {
- debugTCP(l, "Failed to decrypt length.", entry.ID, slog.Any("err", err))
- continue
- }
- debugTCP(l, "Found cipher.", entry.ID, slog.Int("index", ci))
- return entry, elt
- }
- return nil, nil
- }
// StreamAuthenticateFunc authenticates a client stream connection. It returns
// an ID for the authenticated client, the connection to use for the proxied
// stream, and a connection error when authentication fails.
type StreamAuthenticateFunc func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError)
- // NewShadowsocksStreamAuthenticator creates a stream authenticator that uses Shadowsocks.
- // TODO(fortuna): Offer alternative transports.
- func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksConnMetrics, l *slog.Logger) StreamAuthenticateFunc {
- if metrics == nil {
- metrics = &NoOpShadowsocksConnMetrics{}
- }
- if l == nil {
- l = noopLogger()
- }
- return func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError) {
- // Find the cipher and acess key id.
- cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), ciphers, l)
- metrics.AddCipherSearch(keyErr == nil, timeToCipher)
- if keyErr != nil {
- const status = "ERR_CIPHER"
- return "", nil, onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr)
- }
- var id string
- if cipherEntry != nil {
- id = cipherEntry.ID
- }
- // Check if the connection is a replay.
- isServerSalt := cipherEntry.SaltGenerator.IsServerSalt(clientSalt)
- // Only check the cache if findAccessKey succeeded and the salt is unrecognized.
- if isServerSalt || !replayCache.Add(cipherEntry.ID, clientSalt) {
- var status string
- if isServerSalt {
- status = "ERR_REPLAY_SERVER"
- } else {
- status = "ERR_REPLAY_CLIENT"
- }
- return id, nil, onet.NewConnectionError(status, "Replay detected", nil)
- }
- ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey)
- ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey)
- ssw.SetSaltGenerator(cipherEntry.SaltGenerator)
- return id, transport.WrapConn(clientConn, ssr, ssw), nil
- }
- }
// streamHandler is the StreamHandler implementation returned by NewStreamHandler.
type streamHandler struct {
	// logger receives the handler's debug messages.
	logger *slog.Logger
	// NOTE(review): listenerId is neither read nor written in the code visible
	// in this file section — confirm where it is used.
	listenerId string
	// readTimeout bounds how long the handler waits for the client to
	// authenticate and send the target address.
	readTimeout time.Duration
	// authenticate identifies the client and returns the connection that
	// carries the proxied stream.
	authenticate StreamAuthenticateFunc
	// dialer opens the connection to the target address.
	dialer transport.StreamDialer
}
- // NewStreamHandler creates a StreamHandler
- func NewStreamHandler(authenticate StreamAuthenticateFunc, timeout time.Duration) StreamHandler {
- return &streamHandler{
- logger: noopLogger(),
- readTimeout: timeout,
- authenticate: authenticate,
- dialer: MakeValidatingTCPStreamDialer(onet.RequirePublicIP, 0),
- }
- }
// StreamHandler is a handler that handles stream connections.
type StreamHandler interface {
	// Handle processes a single client connection and reports its outcome to
	// connMetrics.
	Handle(ctx context.Context, conn transport.StreamConn, connMetrics TCPConnMetrics)
	// SetLogger sets the logger used to log messages. Uses a no-op logger if nil.
	SetLogger(l *slog.Logger)
	// SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses.
	SetTargetDialer(dialer transport.StreamDialer)
}
- func (s *streamHandler) SetLogger(l *slog.Logger) {
- if l == nil {
- l = noopLogger()
- }
- s.logger = l
- }
- func (s *streamHandler) SetTargetDialer(dialer transport.StreamDialer) {
- s.dialer = dialer
- }
- func ensureConnectionError(err error, fallbackStatus string, fallbackMsg string) *onet.ConnectionError {
- if err == nil {
- return nil
- }
- var connErr *onet.ConnectionError
- if errors.As(err, &connErr) {
- return connErr
- } else {
- return onet.NewConnectionError(fallbackStatus, fallbackMsg, err)
- }
- }
// StreamAcceptFunc returns the next incoming stream connection, or an error
// when no connection could be accepted.
type StreamAcceptFunc func() (transport.StreamConn, error)
- func WrapStreamAcceptFunc[T transport.StreamConn](f func() (T, error)) StreamAcceptFunc {
- return func() (transport.StreamConn, error) {
- return f()
- }
- }
// StreamHandleFunc handles a single accepted stream connection. The context is
// the one through which StreamServe notifies handlers that it is stopping.
type StreamHandleFunc func(ctx context.Context, conn transport.StreamConn)
- // StreamServe repeatedly calls `accept` to obtain connections and `handle` to handle them until
- // accept() returns [ErrClosed]. When that happens, all connection handlers will be notified
- // via their [context.Context]. StreamServe will return after all pending handlers return.
- func StreamServe(accept StreamAcceptFunc, handle StreamHandleFunc) {
- var running sync.WaitGroup
- defer running.Wait()
- ctx, contextCancel := context.WithCancel(context.Background())
- defer contextCancel()
- for {
- clientConn, err := accept()
- if err != nil {
- if errors.Is(err, net.ErrClosed) {
- break
- }
- slog.Warn("Accept failed. Continuing to listen.", "err", err)
- continue
- }
- running.Add(1)
- go func() {
- defer running.Done()
- defer clientConn.Close()
- defer func() {
- if r := recover(); r != nil {
- slog.Warn("Panic in TCP handler. Continuing to listen.", "err", r)
- }
- }()
- handle(ctx, clientConn)
- }()
- }
- }
- func (h *streamHandler) Handle(ctx context.Context, clientConn transport.StreamConn, connMetrics TCPConnMetrics) {
- if connMetrics == nil {
- connMetrics = &NoOpTCPConnMetrics{}
- }
- var proxyMetrics metrics.ProxyMetrics
- measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy)
- connStart := time.Now()
- connError := h.handleConnection(ctx, measuredClientConn, connMetrics, &proxyMetrics)
- connDuration := time.Since(connStart)
- status := "OK"
- if connError != nil {
- status = connError.Status
- h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause))
- }
- connMetrics.AddClosed(status, proxyMetrics, connDuration)
- measuredClientConn.Close() // Closing after the metrics are added aids integration testing.
- h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Done.", slog.String("status", status), slog.Duration("duration", connDuration))
- }
- func getProxyRequest(clientConn transport.StreamConn) (string, error) {
- // TODO(fortuna): Use Shadowsocks proxy, HTTP CONNECT or SOCKS5 based on first byte:
- // case 1, 3 or 4: Shadowsocks (address type)
- // case 5: SOCKS5 (protocol version)
- // case "C": HTTP CONNECT (first char of method)
- tgtAddr, err := socks.ReadAddr(clientConn)
- if err != nil {
- return "", err
- }
- return tgtAddr.String(), nil
- }
- func proxyConnection(l *slog.Logger, ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError {
- tgtConn, dialErr := dialer.DialStream(ctx, tgtAddr)
- if dialErr != nil {
- // We don't drain so dial errors and invalid addresses are communicated quickly.
- return ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target")
- }
- defer tgtConn.Close()
- l.LogAttrs(nil, slog.LevelDebug, "Proxy connection.", slog.String("client", clientConn.RemoteAddr().String()), slog.String("target", tgtConn.RemoteAddr().String()))
- fromClientErrCh := make(chan error)
- go func() {
- _, fromClientErr := io.Copy(tgtConn, clientConn)
- if fromClientErr != nil {
- // Drain to prevent a close in the case of a cipher error.
- io.Copy(io.Discard, clientConn)
- }
- clientConn.CloseRead()
- // Send FIN to target.
- // We must do this after the drain is completed, otherwise the target will close its
- // connection with the proxy, which will, in turn, close the connection with the client.
- tgtConn.CloseWrite()
- fromClientErrCh <- fromClientErr
- }()
- _, fromTargetErr := io.Copy(clientConn, tgtConn)
- // Send FIN to client.
- clientConn.CloseWrite()
- tgtConn.CloseRead()
- fromClientErr := <-fromClientErrCh
- if fromClientErr != nil {
- return onet.NewConnectionError("ERR_RELAY_CLIENT", "Failed to relay traffic from client", fromClientErr)
- }
- if fromTargetErr != nil {
- return onet.NewConnectionError("ERR_RELAY_TARGET", "Failed to relay traffic from target", fromTargetErr)
- }
- return nil
- }
- func (h *streamHandler) handleConnection(ctx context.Context, outerConn transport.StreamConn, connMetrics TCPConnMetrics, proxyMetrics *metrics.ProxyMetrics) *onet.ConnectionError {
- // Set a deadline to receive the address to the target.
- readDeadline := time.Now().Add(h.readTimeout)
- if deadline, ok := ctx.Deadline(); ok {
- outerConn.SetDeadline(deadline)
- if deadline.Before(readDeadline) {
- readDeadline = deadline
- }
- }
- outerConn.SetReadDeadline(readDeadline)
- id, innerConn, authErr := h.authenticate(outerConn)
- if authErr != nil {
- // Drain to protect against probing attacks.
- h.absorbProbe(outerConn, connMetrics, authErr.Status, proxyMetrics)
- return authErr
- }
- connMetrics.AddAuthenticated(id)
- // Read target address and dial it.
- tgtAddr, err := getProxyRequest(innerConn)
- // Clear the deadline for the target address
- outerConn.SetReadDeadline(time.Time{})
- if err != nil {
- // Drain to prevent a close on cipher error.
- io.Copy(io.Discard, outerConn)
- return onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err)
- }
- dialer := transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) {
- tgtConn, err := h.dialer.DialStream(ctx, tgtAddr)
- if err != nil {
- return nil, err
- }
- tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy)
- return tgtConn, nil
- })
- return proxyConnection(h.logger, ctx, dialer, tgtAddr, innerConn)
- }
- // Keep the connection open until we hit the authentication deadline to protect against probing attacks
- // `proxyMetrics` is a pointer because its value is being mutated by `clientConn`.
- func (h *streamHandler) absorbProbe(clientConn io.ReadCloser, connMetrics TCPConnMetrics, status string, proxyMetrics *metrics.ProxyMetrics) {
- // This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe.
- _, drainErr := io.Copy(io.Discard, clientConn) // drain socket
- drainResult := drainErrToString(drainErr)
- h.logger.LogAttrs(nil, slog.LevelDebug, "Drain error.", slog.Any("err", drainErr), slog.String("result", drainResult))
- connMetrics.AddProbe(status, drainResult, proxyMetrics.ClientProxy)
- }
// drainErrToString classifies the error that ended a drain:
//   - "eof" when drainErr is nil (the client closed the stream),
//   - "timeout" when drainErr is, or wraps, a net.Error that reports a timeout,
//   - "other" for anything else.
func drainErrToString(drainErr error) string {
	if drainErr == nil {
		return "eof"
	}
	// errors.As also recognizes timeouts that were wrapped on their way up,
	// which a direct type assertion would miss.
	var netErr net.Error
	if errors.As(drainErr, &netErr) && netErr.Timeout() {
		return "timeout"
	}
	return "other"
}
// NoOpTCPConnMetrics is a [TCPConnMetrics] that doesn't do anything. Useful in tests
// or if you don't want to track metrics.
type NoOpTCPConnMetrics struct{}

// Compile-time check that *NoOpTCPConnMetrics implements TCPConnMetrics.
var _ TCPConnMetrics = (*NoOpTCPConnMetrics)(nil)

// AddAuthenticated does nothing.
func (m *NoOpTCPConnMetrics) AddAuthenticated(accessKey string) {}

// AddClosed does nothing.
func (m *NoOpTCPConnMetrics) AddClosed(status string, data metrics.ProxyMetrics, duration time.Duration) {
}

// AddProbe does nothing.
func (m *NoOpTCPConnMetrics) AddProbe(status, drainResult string, clientProxyBytes int64) {}
|