worker.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. package inbound
  2. import (
  3. "context"
  4. "sync"
  5. "sync/atomic"
  6. "time"
  7. "github.com/xtls/xray-core/app/proxyman"
  8. "github.com/xtls/xray-core/common"
  9. "github.com/xtls/xray-core/common/buf"
  10. "github.com/xtls/xray-core/common/net"
  11. "github.com/xtls/xray-core/common/serial"
  12. "github.com/xtls/xray-core/common/session"
  13. "github.com/xtls/xray-core/common/signal/done"
  14. "github.com/xtls/xray-core/common/task"
  15. "github.com/xtls/xray-core/features/routing"
  16. "github.com/xtls/xray-core/features/stats"
  17. "github.com/xtls/xray-core/proxy"
  18. "github.com/xtls/xray-core/transport/internet"
  19. "github.com/xtls/xray-core/transport/internet/tcp"
  20. "github.com/xtls/xray-core/transport/internet/udp"
  21. "github.com/xtls/xray-core/transport/pipe"
  22. )
  23. type worker interface {
  24. Start() error
  25. Close() error
  26. Port() net.Port
  27. Proxy() proxy.Inbound
  28. }
  29. type tcpWorker struct {
  30. address net.Address
  31. port net.Port
  32. proxy proxy.Inbound
  33. stream *internet.MemoryStreamConfig
  34. recvOrigDest bool
  35. tag string
  36. dispatcher routing.Dispatcher
  37. sniffingConfig *proxyman.SniffingConfig
  38. sniffingMatcher *proxyman.SniffingMatcher
  39. uplinkCounter stats.Counter
  40. downlinkCounter stats.Counter
  41. hub internet.Listener
  42. ctx context.Context
  43. }
  44. func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyMode {
  45. if s == nil || s.SocketSettings == nil {
  46. return internet.SocketConfig_Off
  47. }
  48. return s.SocketSettings.Tproxy
  49. }
  50. func (w *tcpWorker) callback(conn internet.Connection) {
  51. ctx, cancel := context.WithCancel(w.ctx)
  52. sid := session.NewID()
  53. ctx = session.ContextWithID(ctx, sid)
  54. if w.recvOrigDest {
  55. var dest net.Destination
  56. switch getTProxyType(w.stream) {
  57. case internet.SocketConfig_Redirect:
  58. d, err := tcp.GetOriginalDestination(conn)
  59. if err != nil {
  60. newError("failed to get original destination").Base(err).WriteToLog(session.ExportIDToError(ctx))
  61. } else {
  62. dest = d
  63. }
  64. case internet.SocketConfig_TProxy:
  65. dest = net.DestinationFromAddr(conn.LocalAddr())
  66. }
  67. if dest.IsValid() {
  68. ctx = session.ContextWithOutbound(ctx, &session.Outbound{
  69. Target: dest,
  70. })
  71. }
  72. }
  73. if w.uplinkCounter != nil || w.downlinkCounter != nil {
  74. conn = &internet.StatCouterConnection{
  75. Connection: conn,
  76. ReadCounter: w.uplinkCounter,
  77. WriteCounter: w.downlinkCounter,
  78. }
  79. }
  80. ctx = session.ContextWithInbound(ctx, &session.Inbound{
  81. Source: net.DestinationFromAddr(conn.RemoteAddr()),
  82. Gateway: net.TCPDestination(w.address, w.port),
  83. Tag: w.tag,
  84. Conn: conn,
  85. })
  86. content := new(session.Content)
  87. if w.sniffingConfig != nil {
  88. content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
  89. content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
  90. content.SniffingRequest.ExcludedDomainMatcher = w.sniffingMatcher.ExDomain
  91. content.SniffingRequest.ExcludedIPMatcher = w.sniffingMatcher.ExIP
  92. content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
  93. }
  94. ctx = session.ContextWithContent(ctx, content)
  95. if err := w.proxy.Process(ctx, net.Network_TCP, conn, w.dispatcher); err != nil {
  96. newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
  97. }
  98. cancel()
  99. if err := conn.Close(); err != nil {
  100. newError("failed to close connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
  101. }
  102. }
  103. func (w *tcpWorker) Proxy() proxy.Inbound {
  104. return w.proxy
  105. }
  106. func (w *tcpWorker) Start() error {
  107. ctx := context.Background()
  108. hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn internet.Connection) {
  109. go w.callback(conn)
  110. })
  111. if err != nil {
  112. return newError("failed to listen TCP on ", w.port).AtWarning().Base(err)
  113. }
  114. w.hub = hub
  115. return nil
  116. }
  117. func (w *tcpWorker) Close() error {
  118. var errors []interface{}
  119. if w.hub != nil {
  120. if err := common.Close(w.hub); err != nil {
  121. errors = append(errors, err)
  122. }
  123. if err := common.Close(w.proxy); err != nil {
  124. errors = append(errors, err)
  125. }
  126. }
  127. if len(errors) > 0 {
  128. return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
  129. }
  130. return nil
  131. }
  132. func (w *tcpWorker) Port() net.Port {
  133. return w.port
  134. }
  135. type udpConn struct {
  136. lastActivityTime int64 // in seconds
  137. reader buf.Reader
  138. writer buf.Writer
  139. output func([]byte) (int, error)
  140. remote net.Addr
  141. local net.Addr
  142. done *done.Instance
  143. uplink stats.Counter
  144. downlink stats.Counter
  145. }
  146. func (c *udpConn) updateActivity() {
  147. atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix())
  148. }
  149. // ReadMultiBuffer implements buf.Reader
  150. func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
  151. mb, err := c.reader.ReadMultiBuffer()
  152. if err != nil {
  153. return nil, err
  154. }
  155. c.updateActivity()
  156. if c.uplink != nil {
  157. c.uplink.Add(int64(mb.Len()))
  158. }
  159. return mb, nil
  160. }
  161. func (c *udpConn) Read(buf []byte) (int, error) {
  162. panic("not implemented")
  163. }
  164. // Write implements io.Writer.
  165. func (c *udpConn) Write(buf []byte) (int, error) {
  166. n, err := c.output(buf)
  167. if c.downlink != nil {
  168. c.downlink.Add(int64(n))
  169. }
  170. if err == nil {
  171. c.updateActivity()
  172. }
  173. return n, err
  174. }
  175. func (c *udpConn) Close() error {
  176. common.Must(c.done.Close())
  177. common.Must(common.Close(c.writer))
  178. return nil
  179. }
  180. func (c *udpConn) RemoteAddr() net.Addr {
  181. return c.remote
  182. }
  183. func (c *udpConn) LocalAddr() net.Addr {
  184. return c.local
  185. }
  186. func (*udpConn) SetDeadline(time.Time) error {
  187. return nil
  188. }
  189. func (*udpConn) SetReadDeadline(time.Time) error {
  190. return nil
  191. }
  192. func (*udpConn) SetWriteDeadline(time.Time) error {
  193. return nil
  194. }
  195. type connID struct {
  196. src net.Destination
  197. dest net.Destination
  198. }
  199. type udpWorker struct {
  200. sync.RWMutex
  201. proxy proxy.Inbound
  202. hub *udp.Hub
  203. address net.Address
  204. port net.Port
  205. tag string
  206. stream *internet.MemoryStreamConfig
  207. dispatcher routing.Dispatcher
  208. sniffingConfig *proxyman.SniffingConfig
  209. uplinkCounter stats.Counter
  210. downlinkCounter stats.Counter
  211. checker *task.Periodic
  212. activeConn map[connID]*udpConn
  213. ctx context.Context
  214. cone bool
  215. }
  216. func (w *udpWorker) getConnection(id connID) (*udpConn, bool) {
  217. w.Lock()
  218. defer w.Unlock()
  219. if conn, found := w.activeConn[id]; found && !conn.done.Done() {
  220. return conn, true
  221. }
  222. pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
  223. conn := &udpConn{
  224. reader: pReader,
  225. writer: pWriter,
  226. output: func(b []byte) (int, error) {
  227. return w.hub.WriteTo(b, id.src)
  228. },
  229. remote: &net.UDPAddr{
  230. IP: id.src.Address.IP(),
  231. Port: int(id.src.Port),
  232. },
  233. local: &net.UDPAddr{
  234. IP: w.address.IP(),
  235. Port: int(w.port),
  236. },
  237. done: done.New(),
  238. uplink: w.uplinkCounter,
  239. downlink: w.downlinkCounter,
  240. }
  241. w.activeConn[id] = conn
  242. conn.updateActivity()
  243. return conn, false
  244. }
  245. func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) {
  246. id := connID{
  247. src: source,
  248. }
  249. if originalDest.IsValid() {
  250. if !w.cone {
  251. id.dest = originalDest
  252. }
  253. b.UDP = &originalDest
  254. }
  255. conn, existing := w.getConnection(id)
  256. // payload will be discarded in pipe is full.
  257. conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
  258. if !existing {
  259. common.Must(w.checker.Start())
  260. go func() {
  261. ctx := w.ctx
  262. sid := session.NewID()
  263. ctx = session.ContextWithID(ctx, sid)
  264. if originalDest.IsValid() {
  265. ctx = session.ContextWithOutbound(ctx, &session.Outbound{
  266. Target: originalDest,
  267. })
  268. }
  269. ctx = session.ContextWithInbound(ctx, &session.Inbound{
  270. Source: source,
  271. Gateway: net.UDPDestination(w.address, w.port),
  272. Tag: w.tag,
  273. })
  274. content := new(session.Content)
  275. if w.sniffingConfig != nil {
  276. content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
  277. content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
  278. content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
  279. }
  280. ctx = session.ContextWithContent(ctx, content)
  281. if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil {
  282. newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
  283. }
  284. conn.Close()
  285. w.removeConn(id)
  286. }()
  287. }
  288. }
  289. func (w *udpWorker) removeConn(id connID) {
  290. w.Lock()
  291. delete(w.activeConn, id)
  292. w.Unlock()
  293. }
  294. func (w *udpWorker) handlePackets() {
  295. receive := w.hub.Receive()
  296. for payload := range receive {
  297. w.callback(payload.Payload, payload.Source, payload.Target)
  298. }
  299. }
  300. func (w *udpWorker) clean() error {
  301. nowSec := time.Now().Unix()
  302. w.Lock()
  303. defer w.Unlock()
  304. if len(w.activeConn) == 0 {
  305. return newError("no more connections. stopping...")
  306. }
  307. for addr, conn := range w.activeConn {
  308. if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 300 {
  309. delete(w.activeConn, addr)
  310. conn.Close()
  311. }
  312. }
  313. if len(w.activeConn) == 0 {
  314. w.activeConn = make(map[connID]*udpConn, 16)
  315. }
  316. return nil
  317. }
  318. func (w *udpWorker) Start() error {
  319. w.activeConn = make(map[connID]*udpConn, 16)
  320. ctx := context.Background()
  321. h, err := udp.ListenUDP(ctx, w.address, w.port, w.stream, udp.HubCapacity(256))
  322. if err != nil {
  323. return err
  324. }
  325. w.cone = w.ctx.Value("cone").(bool)
  326. w.checker = &task.Periodic{
  327. Interval: time.Minute,
  328. Execute: w.clean,
  329. }
  330. w.hub = h
  331. go w.handlePackets()
  332. return nil
  333. }
  334. func (w *udpWorker) Close() error {
  335. w.Lock()
  336. defer w.Unlock()
  337. var errors []interface{}
  338. if w.hub != nil {
  339. if err := w.hub.Close(); err != nil {
  340. errors = append(errors, err)
  341. }
  342. }
  343. if w.checker != nil {
  344. if err := w.checker.Close(); err != nil {
  345. errors = append(errors, err)
  346. }
  347. }
  348. if err := common.Close(w.proxy); err != nil {
  349. errors = append(errors, err)
  350. }
  351. if len(errors) > 0 {
  352. return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
  353. }
  354. return nil
  355. }
  356. func (w *udpWorker) Port() net.Port {
  357. return w.port
  358. }
  359. func (w *udpWorker) Proxy() proxy.Inbound {
  360. return w.proxy
  361. }
  362. type dsWorker struct {
  363. address net.Address
  364. proxy proxy.Inbound
  365. stream *internet.MemoryStreamConfig
  366. tag string
  367. dispatcher routing.Dispatcher
  368. sniffingConfig *proxyman.SniffingConfig
  369. sniffingMatcher *proxyman.SniffingMatcher
  370. uplinkCounter stats.Counter
  371. downlinkCounter stats.Counter
  372. hub internet.Listener
  373. ctx context.Context
  374. }
  375. func (w *dsWorker) callback(conn internet.Connection) {
  376. ctx, cancel := context.WithCancel(w.ctx)
  377. sid := session.NewID()
  378. ctx = session.ContextWithID(ctx, sid)
  379. if w.uplinkCounter != nil || w.downlinkCounter != nil {
  380. conn = &internet.StatCouterConnection{
  381. Connection: conn,
  382. ReadCounter: w.uplinkCounter,
  383. WriteCounter: w.downlinkCounter,
  384. }
  385. }
  386. ctx = session.ContextWithInbound(ctx, &session.Inbound{
  387. Source: net.DestinationFromAddr(conn.RemoteAddr()),
  388. Gateway: net.UnixDestination(w.address),
  389. Tag: w.tag,
  390. Conn: conn,
  391. })
  392. content := new(session.Content)
  393. if w.sniffingConfig != nil {
  394. content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
  395. content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
  396. content.SniffingRequest.ExcludedDomainMatcher = w.sniffingMatcher.ExDomain
  397. content.SniffingRequest.ExcludedIPMatcher = w.sniffingMatcher.ExIP
  398. content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
  399. }
  400. ctx = session.ContextWithContent(ctx, content)
  401. if err := w.proxy.Process(ctx, net.Network_UNIX, conn, w.dispatcher); err != nil {
  402. newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
  403. }
  404. cancel()
  405. if err := conn.Close(); err != nil {
  406. newError("failed to close connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
  407. }
  408. }
  409. func (w *dsWorker) Proxy() proxy.Inbound {
  410. return w.proxy
  411. }
  412. func (w *dsWorker) Port() net.Port {
  413. return net.Port(0)
  414. }
  415. func (w *dsWorker) Start() error {
  416. ctx := context.Background()
  417. hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn internet.Connection) {
  418. go w.callback(conn)
  419. })
  420. if err != nil {
  421. return newError("failed to listen Unix Domain Socket on ", w.address).AtWarning().Base(err)
  422. }
  423. w.hub = hub
  424. return nil
  425. }
  426. func (w *dsWorker) Close() error {
  427. var errors []interface{}
  428. if w.hub != nil {
  429. if err := common.Close(w.hub); err != nil {
  430. errors = append(errors, err)
  431. }
  432. if err := common.Close(w.proxy); err != nil {
  433. errors = append(errors, err)
  434. }
  435. }
  436. if len(errors) > 0 {
  437. return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
  438. }
  439. return nil
  440. }