udp.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. // Copyright 2018 Jigsaw Operations LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // https://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package service
  15. import (
  16. "context"
  17. "errors"
  18. "fmt"
  19. "log/slog"
  20. "net"
  21. "net/netip"
  22. "runtime/debug"
  23. "sync"
  24. "time"
  25. "github.com/Jigsaw-Code/outline-sdk/transport"
  26. "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
  27. "github.com/shadowsocks/go-shadowsocks2/socks"
  28. onet "github.com/Jigsaw-Code/outline-ss-server/net"
  29. )
  // UDPConnMetrics receives the metrics reported for a single UDP NAT entry.
  type UDPConnMetrics interface {
  	// AddPacketFromClient records one packet read from the client, the number of
  	// bytes forwarded to its target, and the outcome status ("OK" or an error
  	// status).
  	AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64)
  	// AddPacketFromTarget records one packet read from a target, the number of
  	// bytes written back to the client, and the outcome status.
  	AddPacketFromTarget(status string, targetProxyBytes, proxyClientBytes int64)
  	// RemoveNatEntry is called once the NAT entry has stopped relaying replies.
  	RemoveNatEntry()
  }
  36. type UDPMetrics interface {
  37. AddUDPNatEntry(clientAddr net.Addr, accessKey string) UDPConnMetrics
  38. }
  // serverUDPBufferSize is the size, in bytes, of each buffer the server
  // allocates for reading and writing UDP packets (64 KiB).
  const serverUDPBufferSize = 1 << 16
  // debugUDP logs a debug-level message "UDP: <template>" tagged with the cipher
  // ID and one extra attribute.
  //
  // The Enabled guard is an optimization: it avoids allocations caused by an
  // interaction between Go's inlining/escape analysis and variadic logging
  // functions such as slog.Debug. A real context is passed because slog's API
  // expects a non-nil one (staticcheck SA1012).
  func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) {
  	ctx := context.Background()
  	if l.Enabled(ctx, slog.LevelDebug) {
  		l.LogAttrs(ctx, slog.LevelDebug, "UDP: "+template, slog.String("ID", cipherID), attr)
  	}
  }
  // debugUDPAddr logs a debug-level message "UDP: <template>" tagged with the
  // given address and one extra attribute. Like debugUDP, it checks Enabled first
  // so the address is only formatted when debug logging is on, and passes a
  // non-nil context as slog expects.
  func debugUDPAddr(l *slog.Logger, template string, addr net.Addr, attr slog.Attr) {
  	ctx := context.Background()
  	if l.Enabled(ctx, slog.LevelDebug) {
  		l.LogAttrs(ctx, slog.LevelDebug, "UDP: "+template, slog.String("address", addr.String()), attr)
  	}
  }
  54. // Decrypts src into dst. It tries each cipher until it finds one that authenticates
  55. // correctly. dst and src must not overlap.
  56. func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList, l *slog.Logger) ([]byte, string, *shadowsocks.EncryptionKey, error) {
  57. // Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
  58. // We snapshot the list because it may be modified while we use it.
  59. snapshot := cipherList.SnapshotForClientIP(clientIP)
  60. for ci, entry := range snapshot {
  61. id, cryptoKey := entry.Value.(*CipherEntry).ID, entry.Value.(*CipherEntry).CryptoKey
  62. buf, err := shadowsocks.Unpack(dst, src, cryptoKey)
  63. if err != nil {
  64. debugUDP(l, "Failed to unpack.", id, slog.Any("err", err))
  65. continue
  66. }
  67. debugUDP(l, "Found cipher.", id, slog.Int("index", ci))
  68. // Move the active cipher to the front, so that the search is quicker next time.
  69. cipherList.MarkUsedByClientIP(entry, clientIP)
  70. return buf, id, cryptoKey, nil
  71. }
  72. return nil, "", nil, errors.New("could not find valid UDP cipher")
  73. }
  74. type packetHandler struct {
  75. logger *slog.Logger
  76. natTimeout time.Duration
  77. ciphers CipherList
  78. m UDPMetrics
  79. ssm ShadowsocksConnMetrics
  80. targetIPValidator onet.TargetIPValidator
  81. targetListener transport.PacketListener
  82. }
  83. // NewPacketHandler creates a PacketHandler
  84. func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetrics, ssMetrics ShadowsocksConnMetrics) PacketHandler {
  85. if m == nil {
  86. m = &NoOpUDPMetrics{}
  87. }
  88. if ssMetrics == nil {
  89. ssMetrics = &NoOpShadowsocksConnMetrics{}
  90. }
  91. return &packetHandler{
  92. logger: noopLogger(),
  93. natTimeout: natTimeout,
  94. ciphers: cipherList,
  95. m: m,
  96. ssm: ssMetrics,
  97. targetIPValidator: onet.RequirePublicIP,
  98. targetListener: MakeTargetUDPListener(0),
  99. }
  100. }
  101. // PacketHandler is a running UDP shadowsocks proxy that can be stopped.
  102. type PacketHandler interface {
  103. // SetLogger sets the logger used to log messages. Uses a no-op logger if nil.
  104. SetLogger(l *slog.Logger)
  105. // SetTargetIPValidator sets the function to be used to validate the target IP addresses.
  106. SetTargetIPValidator(targetIPValidator onet.TargetIPValidator)
  107. // SetTargetPacketListener sets the packet listener to use for target connections.
  108. SetTargetPacketListener(targetListener transport.PacketListener)
  109. // Handle returns after clientConn closes and all the sub goroutines return.
  110. Handle(clientConn net.PacketConn)
  111. }
  112. func (h *packetHandler) SetLogger(l *slog.Logger) {
  113. if l == nil {
  114. l = noopLogger()
  115. }
  116. h.logger = l
  117. }
  118. func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) {
  119. h.targetIPValidator = targetIPValidator
  120. }
  121. func (h *packetHandler) SetTargetPacketListener(targetListener transport.PacketListener) {
  122. h.targetListener = targetListener
  123. }
  124. // Listen on addr for encrypted packets and basically do UDP NAT.
  125. // We take the ciphers as a pointer because it gets replaced on config updates.
  126. func (h *packetHandler) Handle(clientConn net.PacketConn) {
  127. nm := newNATmap(h.natTimeout, h.m, h.logger)
  128. defer nm.Close()
  129. cipherBuf := make([]byte, serverUDPBufferSize)
  130. textBuf := make([]byte, serverUDPBufferSize)
  131. for {
  132. clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf)
  133. if errors.Is(err, net.ErrClosed) {
  134. break
  135. }
  136. var proxyTargetBytes int
  137. var targetConn *natconn
  138. connError := func() (connError *onet.ConnectionError) {
  139. defer func() {
  140. if r := recover(); r != nil {
  141. slog.Error("Panic in UDP loop: %v. Continuing to listen.", r)
  142. debug.PrintStack()
  143. }
  144. }()
  145. // Error from ReadFrom
  146. if err != nil {
  147. return onet.NewConnectionError("ERR_READ", "Failed to read from client", err)
  148. }
  149. defer slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAddr.String()))
  150. debugUDPAddr(h.logger, "Outbound packet.", clientAddr, slog.Int("bytes", clientProxyBytes))
  151. cipherData := cipherBuf[:clientProxyBytes]
  152. var payload []byte
  153. var tgtUDPAddr *net.UDPAddr
  154. targetConn = nm.Get(clientAddr.String())
  155. if targetConn == nil {
  156. ip := clientAddr.(*net.UDPAddr).AddrPort().Addr()
  157. var textData []byte
  158. var cryptoKey *shadowsocks.EncryptionKey
  159. unpackStart := time.Now()
  160. textData, keyID, cryptoKey, err := findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers, h.logger)
  161. timeToCipher := time.Since(unpackStart)
  162. h.ssm.AddCipherSearch(err == nil, timeToCipher)
  163. if err != nil {
  164. return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err)
  165. }
  166. var onetErr *onet.ConnectionError
  167. if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil {
  168. return onetErr
  169. }
  170. udpConn, err := h.targetListener.ListenPacket(context.Background())
  171. if err != nil {
  172. return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create a `PacketConn`", err)
  173. }
  174. targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, keyID)
  175. } else {
  176. unpackStart := time.Now()
  177. textData, err := shadowsocks.Unpack(nil, cipherData, targetConn.cryptoKey)
  178. timeToCipher := time.Since(unpackStart)
  179. h.ssm.AddCipherSearch(err == nil, timeToCipher)
  180. if err != nil {
  181. return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err)
  182. }
  183. var onetErr *onet.ConnectionError
  184. if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil {
  185. return onetErr
  186. }
  187. }
  188. debugUDPAddr(h.logger, "Proxy exit.", clientAddr, slog.Any("target", targetConn.LocalAddr()))
  189. proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
  190. if err != nil {
  191. return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err)
  192. }
  193. return nil
  194. }()
  195. status := "OK"
  196. if connError != nil {
  197. slog.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause))
  198. status = connError.Status
  199. }
  200. if targetConn != nil {
  201. targetConn.metrics.AddPacketFromClient(status, int64(clientProxyBytes), int64(proxyTargetBytes))
  202. }
  203. }
  204. }
  205. // Given the decrypted contents of a UDP packet, return
  206. // the payload and the destination address, or an error if
  207. // this packet cannot or should not be forwarded.
  208. func (h *packetHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, *onet.ConnectionError) {
  209. tgtAddr := socks.SplitAddr(textData)
  210. if tgtAddr == nil {
  211. return nil, nil, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", nil)
  212. }
  213. tgtUDPAddr, err := net.ResolveUDPAddr("udp", tgtAddr.String())
  214. if err != nil {
  215. return nil, nil, onet.NewConnectionError("ERR_RESOLVE_ADDRESS", fmt.Sprintf("Failed to resolve target address %v", tgtAddr), err)
  216. }
  217. if err := h.targetIPValidator(tgtUDPAddr.IP); err != nil {
  218. return nil, nil, ensureConnectionError(err, "ERR_ADDRESS_INVALID", "invalid address")
  219. }
  220. payload := textData[len(tgtAddr):]
  221. return payload, tgtUDPAddr, nil
  222. }
  // isDNS reports whether addr uses port 53, the well-known DNS port.
  //
  // It is called for every packet written to and read from a target. A non-nil
  // *net.UDPAddr is therefore checked directly, which avoids formatting and
  // re-parsing the address. Other address types fall back to parsing
  // addr.String(); an unparseable string yields an empty port and so false.
  func isDNS(addr net.Addr) bool {
  	if udpAddr, ok := addr.(*net.UDPAddr); ok && udpAddr != nil {
  		return udpAddr.Port == 53
  	}
  	_, port, _ := net.SplitHostPort(addr.String())
  	return port == "53"
  }
  227. type natconn struct {
  228. net.PacketConn
  229. cryptoKey *shadowsocks.EncryptionKey
  230. metrics UDPConnMetrics
  231. // NAT timeout to apply for non-DNS packets.
  232. defaultTimeout time.Duration
  233. // Current read deadline of PacketConn. Used to avoid decreasing the
  234. // deadline. Initially zero.
  235. readDeadline time.Time
  236. // If the connection has only sent one DNS query, it will close
  237. // if it receives a DNS response.
  238. fastClose sync.Once
  239. }
  240. func (c *natconn) onWrite(addr net.Addr) {
  241. // Fast close is only allowed if there has been exactly one write,
  242. // and it was a DNS query.
  243. isDNS := isDNS(addr)
  244. isFirstWrite := c.readDeadline.IsZero()
  245. if !isDNS || !isFirstWrite {
  246. // Disable fast close. (Idempotent.)
  247. c.fastClose.Do(func() {})
  248. }
  249. timeout := c.defaultTimeout
  250. if isDNS {
  251. // Shorten timeout as required by RFC 5452 Section 10.
  252. timeout = 17 * time.Second
  253. }
  254. newDeadline := time.Now().Add(timeout)
  255. if newDeadline.After(c.readDeadline) {
  256. c.readDeadline = newDeadline
  257. c.SetReadDeadline(newDeadline)
  258. }
  259. }
  260. func (c *natconn) onRead(addr net.Addr) {
  261. c.fastClose.Do(func() {
  262. if isDNS(addr) {
  263. // The next ReadFrom() should time out immediately.
  264. c.SetReadDeadline(time.Now())
  265. }
  266. })
  267. }
  268. func (c *natconn) WriteTo(buf []byte, dst net.Addr) (int, error) {
  269. c.onWrite(dst)
  270. return c.PacketConn.WriteTo(buf, dst)
  271. }
  272. func (c *natconn) ReadFrom(buf []byte) (int, net.Addr, error) {
  273. n, addr, err := c.PacketConn.ReadFrom(buf)
  274. if err == nil {
  275. c.onRead(addr)
  276. }
  277. return n, addr, err
  278. }
  279. // Packet NAT table
  280. type natmap struct {
  281. sync.RWMutex
  282. keyConn map[string]*natconn
  283. logger *slog.Logger
  284. timeout time.Duration
  285. metrics UDPMetrics
  286. }
  287. func newNATmap(timeout time.Duration, sm UDPMetrics, l *slog.Logger) *natmap {
  288. m := &natmap{logger: l, metrics: sm}
  289. m.keyConn = make(map[string]*natconn)
  290. m.timeout = timeout
  291. return m
  292. }
  293. func (m *natmap) Get(key string) *natconn {
  294. m.RLock()
  295. defer m.RUnlock()
  296. return m.keyConn[key]
  297. }
  298. func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, connMetrics UDPConnMetrics) *natconn {
  299. entry := &natconn{
  300. PacketConn: pc,
  301. cryptoKey: cryptoKey,
  302. metrics: connMetrics,
  303. defaultTimeout: m.timeout,
  304. }
  305. m.Lock()
  306. defer m.Unlock()
  307. m.keyConn[key] = entry
  308. return entry
  309. }
  310. func (m *natmap) del(key string) net.PacketConn {
  311. m.Lock()
  312. defer m.Unlock()
  313. entry, ok := m.keyConn[key]
  314. if ok {
  315. delete(m.keyConn, key)
  316. return entry
  317. }
  318. return nil
  319. }
  320. func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, keyID string) *natconn {
  321. connMetrics := m.metrics.AddUDPNatEntry(clientAddr, keyID)
  322. entry := m.set(clientAddr.String(), targetConn, cryptoKey, connMetrics)
  323. go func() {
  324. timedCopy(clientAddr, clientConn, entry, m.logger)
  325. connMetrics.RemoveNatEntry()
  326. if pc := m.del(clientAddr.String()); pc != nil {
  327. pc.Close()
  328. }
  329. }()
  330. return entry
  331. }
  332. func (m *natmap) Close() error {
  333. m.Lock()
  334. defer m.Unlock()
  335. var err error
  336. now := time.Now()
  337. for _, pc := range m.keyConn {
  338. if e := pc.SetReadDeadline(now); e != nil {
  339. err = e
  340. }
  341. }
  342. return err
  343. }
  344. // Get the maximum length of the shadowsocks address header by parsing
  345. // and serializing an IPv6 address from the example range.
  346. var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345"))
  347. // copy from target to client until read timeout
  348. func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, l *slog.Logger) {
  349. // pkt is used for in-place encryption of downstream UDP packets, with the layout
  350. // [padding?][salt][address][body][tag][extra]
  351. // Padding is only used if the address is IPv4.
  352. pkt := make([]byte, serverUDPBufferSize)
  353. saltSize := targetConn.cryptoKey.SaltSize()
  354. // Leave enough room at the beginning of the packet for a max-length header (i.e. IPv6).
  355. bodyStart := saltSize + maxAddrLen
  356. expired := false
  357. for {
  358. var bodyLen, proxyClientBytes int
  359. connError := func() (connError *onet.ConnectionError) {
  360. var (
  361. raddr net.Addr
  362. err error
  363. )
  364. // `readBuf` receives the plaintext body in `pkt`:
  365. // [padding?][salt][address][body][tag][unused]
  366. // |-- bodyStart --|[ readBuf ]
  367. readBuf := pkt[bodyStart:]
  368. bodyLen, raddr, err = targetConn.ReadFrom(readBuf)
  369. if err != nil {
  370. if netErr, ok := err.(net.Error); ok {
  371. if netErr.Timeout() {
  372. expired = true
  373. return nil
  374. }
  375. }
  376. return onet.NewConnectionError("ERR_READ", "Failed to read from target", err)
  377. }
  378. debugUDPAddr(l, "Got response.", clientAddr, slog.Any("target", raddr))
  379. srcAddr := socks.ParseAddr(raddr.String())
  380. addrStart := bodyStart - len(srcAddr)
  381. // `plainTextBuf` concatenates the SOCKS address and body:
  382. // [padding?][salt][address][body][tag][unused]
  383. // |-- addrStart -|[plaintextBuf ]
  384. plaintextBuf := pkt[addrStart : bodyStart+bodyLen]
  385. copy(plaintextBuf, srcAddr)
  386. // saltStart is 0 if raddr is IPv6.
  387. saltStart := addrStart - saltSize
  388. // `packBuf` adds space for the salt and tag.
  389. // `buf` shows the space that was used.
  390. // [padding?][salt][address][body][tag][unused]
  391. // [ packBuf ]
  392. // [ buf ]
  393. packBuf := pkt[saltStart:]
  394. buf, err := shadowsocks.Pack(packBuf, plaintextBuf, targetConn.cryptoKey) // Encrypt in-place
  395. if err != nil {
  396. return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err)
  397. }
  398. proxyClientBytes, err = clientConn.WriteTo(buf, clientAddr)
  399. if err != nil {
  400. return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err)
  401. }
  402. return nil
  403. }()
  404. status := "OK"
  405. if connError != nil {
  406. slog.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause))
  407. status = connError.Status
  408. }
  409. if expired {
  410. break
  411. }
  412. targetConn.metrics.AddPacketFromTarget(status, int64(bodyLen), int64(proxyClientBytes))
  413. }
  414. }
  415. // NoOpUDPConnMetrics is a [UDPConnMetrics] that doesn't do anything. Useful in tests
  416. // or if you don't want to track metrics.
  417. type NoOpUDPConnMetrics struct{}
  418. var _ UDPConnMetrics = (*NoOpUDPConnMetrics)(nil)
  419. func (m *NoOpUDPConnMetrics) AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64) {
  420. }
  421. func (m *NoOpUDPConnMetrics) AddPacketFromTarget(status string, targetProxyBytes, proxyClientBytes int64) {
  422. }
  423. func (m *NoOpUDPConnMetrics) RemoveNatEntry() {}
  424. // NoOpUDPMetrics is a [UDPMetrics] that doesn't do anything. Useful in tests
  425. // or if you don't want to track metrics.
  426. type NoOpUDPMetrics struct{}
  427. var _ UDPMetrics = (*NoOpUDPMetrics)(nil)
  428. func (m *NoOpUDPMetrics) AddUDPNatEntry(clientAddr net.Addr, accessKey string) UDPConnMetrics {
  429. return &NoOpUDPConnMetrics{}
  430. }