tcp.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. // Copyright 2018 Jigsaw Operations LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // https://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package service
  15. import (
  16. "bytes"
  17. "container/list"
  18. "context"
  19. "errors"
  20. "fmt"
  21. "io"
  22. "log/slog"
  23. "net"
  24. "net/netip"
  25. "sync"
  26. "time"
  27. "github.com/Jigsaw-Code/outline-sdk/transport"
  28. "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
  29. "github.com/shadowsocks/go-shadowsocks2/socks"
  30. onet "github.com/Jigsaw-Code/outline-ss-server/net"
  31. "github.com/Jigsaw-Code/outline-ss-server/service/metrics"
  32. )
  33. // TCPConnMetrics is used to report metrics on TCP connections.
  34. type TCPConnMetrics interface {
  35. AddAuthenticated(accessKey string)
  36. AddClosed(status string, data metrics.ProxyMetrics, duration time.Duration)
  37. AddProbe(status, drainResult string, clientProxyBytes int64)
  38. }
  // remoteIP returns the IP address of the remote end of conn.
  //
  // It returns the zero netip.Addr when the connection reports no remote
  // address, or when the address is not a *net.TCPAddr and its string form
  // cannot be parsed as "ip:port".
  func remoteIP(conn net.Conn) netip.Addr {
  	switch peer := conn.RemoteAddr().(type) {
  	case nil:
  		return netip.Addr{}
  	case *net.TCPAddr:
  		// Fast path: no string round trip needed.
  		return peer.AddrPort().Addr()
  	default:
  		parsed, err := netip.ParseAddrPort(peer.String())
  		if err != nil {
  			return netip.Addr{}
  		}
  		return parsed.Addr()
  	}
  }
  // debugTCP logs a debug-level message, prefixed with "TCP: ", during TCP
  // access key searches. The record carries the cipher ID under the "ID" key
  // plus one caller-supplied attribute.
  func debugTCP(l *slog.Logger, template string, cipherID string, attr slog.Attr) {
  	// The explicit Enabled check is an optimization to reduce unnecessary
  	// allocations due to an interaction between Go's inlining/escape analysis
  	// and varargs functions like slog.Debug.
  	ctx := context.Background()
  	if !l.Enabled(ctx, slog.LevelDebug) {
  		return
  	}
  	l.LogAttrs(ctx, slog.LevelDebug, "TCP: "+template, slog.String("ID", cipherID), attr)
  }
  // bytesForKeyFinding is the number of bytes to read for finding the AccessKey.
  // It must satisfy provided >= bytesForKeyFinding >= required for every cipher in the list, where:
  //   - provided = saltSize + 2 + 2 * cipher.TagSize, the minimum number of bytes we will see in a valid connection
  //   - required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection.
  const bytesForKeyFinding = 50
  66. func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList, l *slog.Logger) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
  67. // We snapshot the list because it may be modified while we use it.
  68. ciphers := cipherList.SnapshotForClientIP(clientIP)
  69. firstBytes := make([]byte, bytesForKeyFinding)
  70. if n, err := io.ReadFull(clientReader, firstBytes); err != nil {
  71. return nil, clientReader, nil, 0, fmt.Errorf("reading header failed after %d bytes: %w", n, err)
  72. }
  73. findStartTime := time.Now()
  74. entry, elt := findEntry(firstBytes, ciphers, l)
  75. timeToCipher := time.Since(findStartTime)
  76. if entry == nil {
  77. // TODO: Ban and log client IPs with too many failures too quick to protect against DoS.
  78. return nil, clientReader, nil, timeToCipher, fmt.Errorf("could not find valid TCP cipher")
  79. }
  80. // Move the active cipher to the front, so that the search is quicker next time.
  81. cipherList.MarkUsedByClientIP(elt, clientIP)
  82. salt := firstBytes[:entry.CryptoKey.SaltSize()]
  83. return entry, io.MultiReader(bytes.NewReader(firstBytes), clientReader), salt, timeToCipher, nil
  84. }
  85. // Implements a trial decryption search. This assumes that all ciphers are AEAD.
  86. func findEntry(firstBytes []byte, ciphers []*list.Element, l *slog.Logger) (*CipherEntry, *list.Element) {
  87. // To hold the decrypted chunk length.
  88. chunkLenBuf := [2]byte{}
  89. for ci, elt := range ciphers {
  90. entry := elt.Value.(*CipherEntry)
  91. cryptoKey := entry.CryptoKey
  92. _, err := shadowsocks.Unpack(chunkLenBuf[:0], firstBytes[:cryptoKey.SaltSize()+2+cryptoKey.TagSize()], cryptoKey)
  93. if err != nil {
  94. debugTCP(l, "Failed to decrypt length.", entry.ID, slog.Any("err", err))
  95. continue
  96. }
  97. debugTCP(l, "Found cipher.", entry.ID, slog.Int("index", ci))
  98. return entry, elt
  99. }
  100. return nil, nil
  101. }
  102. type StreamAuthenticateFunc func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError)
  103. // NewShadowsocksStreamAuthenticator creates a stream authenticator that uses Shadowsocks.
  104. // TODO(fortuna): Offer alternative transports.
  105. func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksConnMetrics, l *slog.Logger) StreamAuthenticateFunc {
  106. if metrics == nil {
  107. metrics = &NoOpShadowsocksConnMetrics{}
  108. }
  109. if l == nil {
  110. l = noopLogger()
  111. }
  112. return func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError) {
  113. // Find the cipher and acess key id.
  114. cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), ciphers, l)
  115. metrics.AddCipherSearch(keyErr == nil, timeToCipher)
  116. if keyErr != nil {
  117. const status = "ERR_CIPHER"
  118. return "", nil, onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr)
  119. }
  120. var id string
  121. if cipherEntry != nil {
  122. id = cipherEntry.ID
  123. }
  124. // Check if the connection is a replay.
  125. isServerSalt := cipherEntry.SaltGenerator.IsServerSalt(clientSalt)
  126. // Only check the cache if findAccessKey succeeded and the salt is unrecognized.
  127. if isServerSalt || !replayCache.Add(cipherEntry.ID, clientSalt) {
  128. var status string
  129. if isServerSalt {
  130. status = "ERR_REPLAY_SERVER"
  131. } else {
  132. status = "ERR_REPLAY_CLIENT"
  133. }
  134. return id, nil, onet.NewConnectionError(status, "Replay detected", nil)
  135. }
  136. ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey)
  137. ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey)
  138. ssw.SetSaltGenerator(cipherEntry.SaltGenerator)
  139. return id, transport.WrapConn(clientConn, ssr, ssw), nil
  140. }
  141. }
  // streamHandler is the StreamHandler implementation returned by NewStreamHandler.
  type streamHandler struct {
  	// logger receives the handler's debug messages.
  	logger *slog.Logger
  	// NOTE(review): listenerId is not referenced by the code visible in this
  	// file — confirm where it is set and read.
  	listenerId string
  	// readTimeout bounds how long the client has to authenticate and send the
  	// target address.
  	readTimeout time.Duration
  	// authenticate authenticates an incoming connection and returns the
  	// connection to use for the rest of the session.
  	authenticate StreamAuthenticateFunc
  	// dialer is used to connect to target addresses.
  	dialer transport.StreamDialer
  }
  149. // NewStreamHandler creates a StreamHandler
  150. func NewStreamHandler(authenticate StreamAuthenticateFunc, timeout time.Duration) StreamHandler {
  151. return &streamHandler{
  152. logger: noopLogger(),
  153. readTimeout: timeout,
  154. authenticate: authenticate,
  155. dialer: MakeValidatingTCPStreamDialer(onet.RequirePublicIP, 0),
  156. }
  157. }
  // StreamHandler is a handler that handles stream connections.
  type StreamHandler interface {
  	// Handle processes a single stream connection and reports on it to
  	// connMetrics. The implementation in this file accepts a nil connMetrics
  	// and substitutes a no-op one.
  	Handle(ctx context.Context, conn transport.StreamConn, connMetrics TCPConnMetrics)
  	// SetLogger sets the logger used to log messages. Uses a no-op logger if nil.
  	SetLogger(l *slog.Logger)
  	// SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses.
  	SetTargetDialer(dialer transport.StreamDialer)
  }
  166. func (s *streamHandler) SetLogger(l *slog.Logger) {
  167. if l == nil {
  168. l = noopLogger()
  169. }
  170. s.logger = l
  171. }
  // SetTargetDialer sets the [transport.StreamDialer] used to connect to target addresses.
  // NOTE(review): unlike SetLogger, a nil dialer is stored as-is — confirm callers never pass nil.
  func (s *streamHandler) SetTargetDialer(dialer transport.StreamDialer) {
  	s.dialer = dialer
  }
  175. func ensureConnectionError(err error, fallbackStatus string, fallbackMsg string) *onet.ConnectionError {
  176. if err == nil {
  177. return nil
  178. }
  179. var connErr *onet.ConnectionError
  180. if errors.As(err, &connErr) {
  181. return connErr
  182. } else {
  183. return onet.NewConnectionError(fallbackStatus, fallbackMsg, err)
  184. }
  185. }
  186. type StreamAcceptFunc func() (transport.StreamConn, error)
  187. func WrapStreamAcceptFunc[T transport.StreamConn](f func() (T, error)) StreamAcceptFunc {
  188. return func() (transport.StreamConn, error) {
  189. return f()
  190. }
  191. }
  // StreamHandleFunc handles a single accepted stream connection. StreamServe
  // runs it in its own goroutine, closes conn once it returns, and cancels ctx
  // when serving stops.
  type StreamHandleFunc func(ctx context.Context, conn transport.StreamConn)
  193. // StreamServe repeatedly calls `accept` to obtain connections and `handle` to handle them until
  194. // accept() returns [ErrClosed]. When that happens, all connection handlers will be notified
  195. // via their [context.Context]. StreamServe will return after all pending handlers return.
  196. func StreamServe(accept StreamAcceptFunc, handle StreamHandleFunc) {
  197. var running sync.WaitGroup
  198. defer running.Wait()
  199. ctx, contextCancel := context.WithCancel(context.Background())
  200. defer contextCancel()
  201. for {
  202. clientConn, err := accept()
  203. if err != nil {
  204. if errors.Is(err, net.ErrClosed) {
  205. break
  206. }
  207. slog.Warn("Accept failed. Continuing to listen.", "err", err)
  208. continue
  209. }
  210. running.Add(1)
  211. go func() {
  212. defer running.Done()
  213. defer clientConn.Close()
  214. defer func() {
  215. if r := recover(); r != nil {
  216. slog.Warn("Panic in TCP handler. Continuing to listen.", "err", r)
  217. }
  218. }()
  219. handle(ctx, clientConn)
  220. }()
  221. }
  222. }
  223. func (h *streamHandler) Handle(ctx context.Context, clientConn transport.StreamConn, connMetrics TCPConnMetrics) {
  224. if connMetrics == nil {
  225. connMetrics = &NoOpTCPConnMetrics{}
  226. }
  227. var proxyMetrics metrics.ProxyMetrics
  228. measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy)
  229. connStart := time.Now()
  230. connError := h.handleConnection(ctx, measuredClientConn, connMetrics, &proxyMetrics)
  231. connDuration := time.Since(connStart)
  232. status := "OK"
  233. if connError != nil {
  234. status = connError.Status
  235. h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause))
  236. }
  237. connMetrics.AddClosed(status, proxyMetrics, connDuration)
  238. measuredClientConn.Close() // Closing after the metrics are added aids integration testing.
  239. h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Done.", slog.String("status", status), slog.Duration("duration", connDuration))
  240. }
  241. func getProxyRequest(clientConn transport.StreamConn) (string, error) {
  242. // TODO(fortuna): Use Shadowsocks proxy, HTTP CONNECT or SOCKS5 based on first byte:
  243. // case 1, 3 or 4: Shadowsocks (address type)
  244. // case 5: SOCKS5 (protocol version)
  245. // case "C": HTTP CONNECT (first char of method)
  246. tgtAddr, err := socks.ReadAddr(clientConn)
  247. if err != nil {
  248. return "", err
  249. }
  250. return tgtAddr.String(), nil
  251. }
  252. func proxyConnection(l *slog.Logger, ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError {
  253. tgtConn, dialErr := dialer.DialStream(ctx, tgtAddr)
  254. if dialErr != nil {
  255. // We don't drain so dial errors and invalid addresses are communicated quickly.
  256. return ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target")
  257. }
  258. defer tgtConn.Close()
  259. l.LogAttrs(nil, slog.LevelDebug, "Proxy connection.", slog.String("client", clientConn.RemoteAddr().String()), slog.String("target", tgtConn.RemoteAddr().String()))
  260. fromClientErrCh := make(chan error)
  261. go func() {
  262. _, fromClientErr := io.Copy(tgtConn, clientConn)
  263. if fromClientErr != nil {
  264. // Drain to prevent a close in the case of a cipher error.
  265. io.Copy(io.Discard, clientConn)
  266. }
  267. clientConn.CloseRead()
  268. // Send FIN to target.
  269. // We must do this after the drain is completed, otherwise the target will close its
  270. // connection with the proxy, which will, in turn, close the connection with the client.
  271. tgtConn.CloseWrite()
  272. fromClientErrCh <- fromClientErr
  273. }()
  274. _, fromTargetErr := io.Copy(clientConn, tgtConn)
  275. // Send FIN to client.
  276. clientConn.CloseWrite()
  277. tgtConn.CloseRead()
  278. fromClientErr := <-fromClientErrCh
  279. if fromClientErr != nil {
  280. return onet.NewConnectionError("ERR_RELAY_CLIENT", "Failed to relay traffic from client", fromClientErr)
  281. }
  282. if fromTargetErr != nil {
  283. return onet.NewConnectionError("ERR_RELAY_TARGET", "Failed to relay traffic from target", fromTargetErr)
  284. }
  285. return nil
  286. }
  287. func (h *streamHandler) handleConnection(ctx context.Context, outerConn transport.StreamConn, connMetrics TCPConnMetrics, proxyMetrics *metrics.ProxyMetrics) *onet.ConnectionError {
  288. // Set a deadline to receive the address to the target.
  289. readDeadline := time.Now().Add(h.readTimeout)
  290. if deadline, ok := ctx.Deadline(); ok {
  291. outerConn.SetDeadline(deadline)
  292. if deadline.Before(readDeadline) {
  293. readDeadline = deadline
  294. }
  295. }
  296. outerConn.SetReadDeadline(readDeadline)
  297. id, innerConn, authErr := h.authenticate(outerConn)
  298. if authErr != nil {
  299. // Drain to protect against probing attacks.
  300. h.absorbProbe(outerConn, connMetrics, authErr.Status, proxyMetrics)
  301. return authErr
  302. }
  303. connMetrics.AddAuthenticated(id)
  304. // Read target address and dial it.
  305. tgtAddr, err := getProxyRequest(innerConn)
  306. // Clear the deadline for the target address
  307. outerConn.SetReadDeadline(time.Time{})
  308. if err != nil {
  309. // Drain to prevent a close on cipher error.
  310. io.Copy(io.Discard, outerConn)
  311. return onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err)
  312. }
  313. dialer := transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) {
  314. tgtConn, err := h.dialer.DialStream(ctx, tgtAddr)
  315. if err != nil {
  316. return nil, err
  317. }
  318. tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy)
  319. return tgtConn, nil
  320. })
  321. return proxyConnection(h.logger, ctx, dialer, tgtAddr, innerConn)
  322. }
  323. // Keep the connection open until we hit the authentication deadline to protect against probing attacks
  324. // `proxyMetrics` is a pointer because its value is being mutated by `clientConn`.
  325. func (h *streamHandler) absorbProbe(clientConn io.ReadCloser, connMetrics TCPConnMetrics, status string, proxyMetrics *metrics.ProxyMetrics) {
  326. // This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe.
  327. _, drainErr := io.Copy(io.Discard, clientConn) // drain socket
  328. drainResult := drainErrToString(drainErr)
  329. h.logger.LogAttrs(nil, slog.LevelDebug, "Drain error.", slog.Any("err", drainErr), slog.String("result", drainResult))
  330. connMetrics.AddProbe(status, drainResult, proxyMetrics.ClientProxy)
  331. }
  // drainErrToString classifies the result of draining a connection:
  //   - "eof" when drainErr is nil, i.e. the peer closed the connection;
  //   - "timeout" when drainErr is, or wraps, a net.Error reporting a timeout;
  //   - "other" for any other error.
  func drainErrToString(drainErr error) string {
  	if drainErr == nil {
  		return "eof"
  	}
  	// errors.As also matches timeouts that were wrapped by an intermediate
  	// layer, which a direct type assertion would miss.
  	var netErr net.Error
  	if errors.As(drainErr, &netErr) && netErr.Timeout() {
  		return "timeout"
  	}
  	return "other"
  }
  343. // NoOpTCPConnMetrics is a [TCPConnMetrics] that doesn't do anything. Useful in tests
  344. // or if you don't want to track metrics.
  345. type NoOpTCPConnMetrics struct{}
  346. var _ TCPConnMetrics = (*NoOpTCPConnMetrics)(nil)
  347. func (m *NoOpTCPConnMetrics) AddAuthenticated(accessKey string) {}
  348. func (m *NoOpTCPConnMetrics) AddClosed(status string, data metrics.ProxyMetrics, duration time.Duration) {
  349. }
  350. func (m *NoOpTCPConnMetrics) AddProbe(status, drainResult string, clientProxyBytes int64) {}