handler.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. package tun
  2. import (
  3. "context"
  4. "sync"
  5. "sync/atomic"
  6. "time"
  7. "github.com/xtls/xray-core/common"
  8. "github.com/xtls/xray-core/common/buf"
  9. c "github.com/xtls/xray-core/common/ctx"
  10. "github.com/xtls/xray-core/common/errors"
  11. "github.com/xtls/xray-core/common/net"
  12. "github.com/xtls/xray-core/common/protocol"
  13. "github.com/xtls/xray-core/common/session"
  14. "github.com/xtls/xray-core/common/signal/done"
  15. "github.com/xtls/xray-core/common/task"
  16. "github.com/xtls/xray-core/core"
  17. "github.com/xtls/xray-core/features/policy"
  18. "github.com/xtls/xray-core/features/routing"
  19. "github.com/xtls/xray-core/transport"
  20. "github.com/xtls/xray-core/transport/internet/stat"
  21. "github.com/xtls/xray-core/transport/pipe"
  22. "gvisor.dev/gvisor/pkg/buffer"
  23. "gvisor.dev/gvisor/pkg/tcpip"
  24. "gvisor.dev/gvisor/pkg/tcpip/checksum"
  25. "gvisor.dev/gvisor/pkg/tcpip/header"
  26. "gvisor.dev/gvisor/pkg/tcpip/stack"
  27. )
  28. // udpConnID represents a UDP connection identifier
  29. type udpConnID struct {
  30. src net.Destination
  31. dest net.Destination
  32. }
  33. // udpConn represents a UDP connection for packet handling
  34. type udpConn struct {
  35. lastActivityTime int64 // in seconds
  36. reader buf.Reader
  37. writer buf.Writer
  38. output func([]byte, net.Destination) (int, error)
  39. remote net.Addr
  40. local net.Addr
  41. done *done.Instance
  42. inactive bool
  43. cancel context.CancelFunc
  44. }
  45. func (c *udpConn) setInactive() {
  46. c.inactive = true
  47. }
  48. func (c *udpConn) updateActivity() {
  49. atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix())
  50. }
  51. // ReadMultiBuffer implements buf.Reader
  52. func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
  53. mb, err := c.reader.ReadMultiBuffer()
  54. if err != nil {
  55. return nil, err
  56. }
  57. c.updateActivity()
  58. return mb, nil
  59. }
  60. func (c *udpConn) Read(buf []byte) (int, error) {
  61. return 0, errors.New("Read not supported, use ReadMultiBuffer instead")
  62. }
  63. // Write implements io.Writer
  64. func (c *udpConn) Write(data []byte) (int, error) {
  65. // Extract destination from the first buffer if available
  66. // For now, write with empty destination (will be filled by output function)
  67. n, err := c.output(data, net.Destination{})
  68. if err == nil {
  69. c.updateActivity()
  70. }
  71. return n, err
  72. }
  73. func (c *udpConn) Close() error {
  74. if c.cancel != nil {
  75. c.cancel()
  76. }
  77. common.Must(c.done.Close())
  78. common.Must(common.Close(c.writer))
  79. return nil
  80. }
  81. func (c *udpConn) RemoteAddr() net.Addr {
  82. return c.remote
  83. }
  84. func (c *udpConn) LocalAddr() net.Addr {
  85. return c.local
  86. }
  87. func (*udpConn) SetDeadline(time.Time) error {
  88. return nil
  89. }
  90. func (*udpConn) SetReadDeadline(time.Time) error {
  91. return nil
  92. }
  93. func (*udpConn) SetWriteDeadline(time.Time) error {
  94. return nil
  95. }
  96. // Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing
  97. type Handler struct {
  98. sync.RWMutex
  99. ctx context.Context
  100. config *Config
  101. stack Stack
  102. policyManager policy.Manager
  103. dispatcher routing.Dispatcher
  104. cone bool
  105. // UDP connection management
  106. udpConns map[udpConnID]*udpConn
  107. udpChecker *task.Periodic
  108. }
  109. // ConnectionHandler interface with the only method that stack is going to push new connections to
  110. type ConnectionHandler interface {
  111. HandleConnection(conn net.Conn, destination net.Destination)
  112. }
  113. // Handler implements ConnectionHandler
  114. var _ ConnectionHandler = (*Handler)(nil)
  115. func (t *Handler) policy() policy.Session {
  116. p := t.policyManager.ForLevel(t.config.UserLevel)
  117. return p
  118. }
  119. // getUDPConn gets or creates a UDP connection for the given source and destination
  120. func (t *Handler) getUDPConn(source, dest net.Destination, ipStack *stack.Stack) (*udpConn, bool) {
  121. t.Lock()
  122. defer t.Unlock()
  123. id := udpConnID{
  124. src: source,
  125. }
  126. if !t.cone {
  127. id.dest = dest
  128. }
  129. if conn, found := t.udpConns[id]; found && !conn.done.Done() {
  130. conn.updateActivity()
  131. return conn, true
  132. }
  133. pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
  134. conn := &udpConn{
  135. reader: pReader,
  136. writer: pWriter,
  137. output: func(data []byte, returnDest net.Destination) (int, error) {
  138. // Write UDP packet back to the stack with proper source address
  139. return t.writeUDPPacket(ipStack, data, returnDest, source)
  140. },
  141. remote: &net.UDPAddr{
  142. IP: source.Address.IP(),
  143. Port: int(source.Port),
  144. },
  145. local: &net.UDPAddr{
  146. IP: dest.Address.IP(),
  147. Port: int(dest.Port),
  148. },
  149. done: done.New(),
  150. }
  151. t.udpConns[id] = conn
  152. conn.updateActivity()
  153. return conn, false
  154. }
  155. // removeUDPConn removes a UDP connection
  156. func (t *Handler) removeUDPConn(id udpConnID) {
  157. t.Lock()
  158. delete(t.udpConns, id)
  159. t.Unlock()
  160. }
  161. // cleanupUDPConns removes inactive UDP connections
  162. func (t *Handler) cleanupUDPConns() error {
  163. nowSec := time.Now().Unix()
  164. t.Lock()
  165. defer t.Unlock()
  166. if len(t.udpConns) == 0 {
  167. return errors.New("UDP connection cleanup stopped: no active connections remaining")
  168. }
  169. for id, conn := range t.udpConns {
  170. if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 300 { // 5 minutes
  171. if !conn.inactive {
  172. conn.setInactive()
  173. conn.Close()
  174. delete(t.udpConns, id)
  175. }
  176. }
  177. }
  178. return nil
  179. }
  180. // writeUDPPacket writes a UDP packet back to the gVisor stack with custom source address
  181. func (t *Handler) writeUDPPacket(ipStack *stack.Stack, data []byte, dest, source net.Destination) (int, error) {
  182. // Build UDP+IP packet with proper headers using gVisor's header builders
  183. // Determine IP version
  184. var ipHdrLen, udpHdrLen int
  185. isIPv4 := dest.Address.Family().IsIPv4()
  186. if isIPv4 {
  187. ipHdrLen = header.IPv4MinimumSize
  188. } else {
  189. ipHdrLen = header.IPv6MinimumSize
  190. }
  191. udpHdrLen = header.UDPMinimumSize
  192. totalLen := ipHdrLen + udpHdrLen + len(data)
  193. packet := make([]byte, totalLen)
  194. // Build UDP header
  195. udpHeader := header.UDP(packet[ipHdrLen:])
  196. udpHeader.Encode(&header.UDPFields{
  197. SrcPort: uint16(dest.Port), // Source is the original destination
  198. DstPort: uint16(source.Port), // Destination is the original source
  199. Length: uint16(udpHdrLen + len(data)),
  200. })
  201. // Copy payload
  202. copy(packet[ipHdrLen+udpHdrLen:], data)
  203. // Build IP header and calculate checksums
  204. if isIPv4 {
  205. ipv4Header := header.IPv4(packet)
  206. ipv4Header.Encode(&header.IPv4Fields{
  207. TOS: 0,
  208. TotalLength: uint16(totalLen),
  209. ID: 0,
  210. Flags: 0,
  211. FragmentOffset: 0,
  212. TTL: 64,
  213. Protocol: uint8(header.UDPProtocolNumber),
  214. SrcAddr: tcpip.AddrFromSlice(dest.Address.IP()),
  215. DstAddr: tcpip.AddrFromSlice(source.Address.IP()),
  216. })
  217. ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum())
  218. // Calculate UDP checksum
  219. xsum := header.PseudoHeaderChecksum(
  220. header.UDPProtocolNumber,
  221. tcpip.AddrFromSlice(dest.Address.IP()),
  222. tcpip.AddrFromSlice(source.Address.IP()),
  223. uint16(udpHdrLen+len(data)),
  224. )
  225. xsum = checksum.Checksum(data, xsum)
  226. udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
  227. } else {
  228. ipv6Header := header.IPv6(packet)
  229. ipv6Header.Encode(&header.IPv6Fields{
  230. TrafficClass: 0,
  231. FlowLabel: 0,
  232. PayloadLength: uint16(udpHdrLen + len(data)),
  233. TransportProtocol: header.UDPProtocolNumber,
  234. HopLimit: 64,
  235. SrcAddr: tcpip.AddrFromSlice(dest.Address.IP()),
  236. DstAddr: tcpip.AddrFromSlice(source.Address.IP()),
  237. })
  238. // Calculate UDP checksum for IPv6
  239. xsum := header.PseudoHeaderChecksum(
  240. header.UDPProtocolNumber,
  241. tcpip.AddrFromSlice(dest.Address.IP()),
  242. tcpip.AddrFromSlice(source.Address.IP()),
  243. uint16(udpHdrLen+len(data)),
  244. )
  245. xsum = checksum.Checksum(data, xsum)
  246. udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
  247. }
  248. // Write packet to stack
  249. var proto tcpip.NetworkProtocolNumber
  250. if isIPv4 {
  251. proto = header.IPv4ProtocolNumber
  252. } else {
  253. proto = header.IPv6ProtocolNumber
  254. }
  255. buf := buffer.MakeWithData(packet)
  256. if err := ipStack.WriteRawPacket(defaultNIC, proto, buf); err != nil {
  257. return 0, errors.New("failed to write packet: " + err.String())
  258. }
  259. return len(data), nil
  260. }
  261. // HandleUDPPacket processes a raw UDP packet from gVisor
  262. func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) {
  263. // Extract packet information
  264. source := net.UDPDestination(
  265. net.IPAddress(id.RemoteAddress.AsSlice()),
  266. net.Port(id.RemotePort),
  267. )
  268. dest := net.UDPDestination(
  269. net.IPAddress(id.LocalAddress.AsSlice()),
  270. net.Port(id.LocalPort),
  271. )
  272. // Extract UDP payload
  273. data := pkt.Data().AsRange().ToSlice()
  274. if len(data) == 0 {
  275. return
  276. }
  277. // Get or create connection for this source
  278. conn, existing := t.getUDPConn(source, dest, ipStack)
  279. // Create buffer and set UDP destination
  280. b := buf.New()
  281. b.Write(data)
  282. b.UDP = &dest
  283. // Write to connection pipe
  284. conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
  285. if !existing {
  286. // Start checker for cleanup (only once)
  287. t.Lock()
  288. if t.udpChecker != nil && len(t.udpConns) == 1 {
  289. common.Must(t.udpChecker.Start())
  290. }
  291. t.Unlock()
  292. // Start handling this connection
  293. go func() {
  294. connID := udpConnID{
  295. src: source,
  296. }
  297. if !t.cone {
  298. connID.dest = dest
  299. }
  300. ctx, cancel := context.WithCancel(t.ctx)
  301. conn.cancel = cancel
  302. sid := session.NewID()
  303. ctx = c.ContextWithID(ctx, sid)
  304. inbound := session.Inbound{}
  305. inbound.Name = "tun"
  306. inbound.Source = source
  307. inbound.User = &protocol.MemoryUser{
  308. Level: t.config.UserLevel,
  309. }
  310. ctx = session.ContextWithInbound(ctx, &inbound)
  311. ctx = session.SubContextFromMuxInbound(ctx)
  312. link := &transport.Link{
  313. Reader: conn.reader,
  314. Writer: buf.NewWriter(conn),
  315. }
  316. if err := t.dispatcher.DispatchLink(ctx, dest, link); err != nil {
  317. errors.LogError(ctx, errors.New("UDP connection ended").Base(err))
  318. }
  319. conn.Close()
  320. if !conn.inactive {
  321. conn.setInactive()
  322. t.removeUDPConn(connID)
  323. }
  324. }()
  325. }
  326. }
  327. // Init the Handler instance with necessary parameters
  328. func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routing.Dispatcher) error {
  329. var err error
  330. t.ctx = core.ToBackgroundDetachedContext(ctx)
  331. t.policyManager = pm
  332. t.dispatcher = dispatcher
  333. t.cone = ctx.Value("cone").(bool)
  334. // Initialize UDP connection manager
  335. t.udpConns = make(map[udpConnID]*udpConn)
  336. t.udpChecker = &task.Periodic{
  337. Interval: time.Minute,
  338. Execute: t.cleanupUDPConns,
  339. }
  340. tunName := t.config.Name
  341. tunOptions := TunOptions{
  342. Name: tunName,
  343. MTU: t.config.MTU,
  344. }
  345. tunInterface, err := NewTun(tunOptions)
  346. if err != nil {
  347. return err
  348. }
  349. errors.LogInfo(t.ctx, tunName, " created")
  350. tunStackOptions := StackOptions{
  351. Tun: tunInterface,
  352. IdleTimeout: pm.ForLevel(t.config.UserLevel).Timeouts.ConnectionIdle,
  353. }
  354. tunStack, err := NewStack(t.ctx, tunStackOptions, t)
  355. if err != nil {
  356. _ = tunInterface.Close()
  357. return err
  358. }
  359. err = tunStack.Start()
  360. if err != nil {
  361. _ = tunStack.Close()
  362. _ = tunInterface.Close()
  363. return err
  364. }
  365. err = tunInterface.Start()
  366. if err != nil {
  367. _ = tunStack.Close()
  368. _ = tunInterface.Close()
  369. return err
  370. }
  371. t.stack = tunStack
  372. errors.LogInfo(t.ctx, tunName, " up")
  373. return nil
  374. }
  375. // HandleConnection pass the connection coming from the ip stack to the routing dispatcher
  376. func (t *Handler) HandleConnection(conn net.Conn, destination net.Destination) {
  377. sid := session.NewID()
  378. ctx := c.ContextWithID(t.ctx, sid)
  379. errors.LogInfo(ctx, "processing connection from: ", conn.RemoteAddr())
  380. inbound := session.Inbound{}
  381. inbound.Name = "tun"
  382. inbound.CanSpliceCopy = 1
  383. inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
  384. inbound.User = &protocol.MemoryUser{
  385. Level: t.config.UserLevel,
  386. }
  387. ctx = session.ContextWithInbound(ctx, &inbound)
  388. ctx = session.SubContextFromMuxInbound(ctx)
  389. var link *transport.Link
  390. if destination.Network == net.Network_UDP {
  391. // For UDP, use PacketReader to preserve packet boundaries
  392. link = &transport.Link{
  393. Reader: buf.NewPacketReader(conn),
  394. Writer: buf.NewWriter(conn),
  395. }
  396. } else {
  397. link = &transport.Link{
  398. Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)},
  399. Writer: buf.NewWriter(conn),
  400. }
  401. }
  402. if err := t.dispatcher.DispatchLink(ctx, destination, link); err != nil {
  403. errors.LogError(ctx, errors.New("connection closed").Base(err))
  404. return
  405. }
  406. errors.LogInfo(ctx, "connection completed")
  407. }
  408. // Network implements proxy.Inbound
  409. // and exists only to comply to proxy interface, declaring it doesn't listen on any network,
  410. // making the process not open any port for this inbound (input will be network interface)
  411. func (t *Handler) Network() []net.Network {
  412. return []net.Network{}
  413. }
  414. // Process implements proxy.Inbound
  415. // and exists only to comply to proxy interface, which should never get any inputs due to no listening ports
  416. func (t *Handler) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
  417. return nil
  418. }
  419. func init() {
  420. common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
  421. t := &Handler{config: config.(*Config)}
  422. err := core.RequireFeatures(ctx, func(pm policy.Manager, dispatcher routing.Dispatcher) error {
  423. return t.Init(ctx, pm, dispatcher)
  424. })
  425. return t, err
  426. }))
  427. }