handler.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. package tun
  2. import (
  3. "context"
  4. "sync"
  5. "sync/atomic"
  6. "time"
  7. "github.com/xtls/xray-core/common"
  8. "github.com/xtls/xray-core/common/buf"
  9. c "github.com/xtls/xray-core/common/ctx"
  10. "github.com/xtls/xray-core/common/errors"
  11. "github.com/xtls/xray-core/common/net"
  12. "github.com/xtls/xray-core/common/protocol"
  13. "github.com/xtls/xray-core/common/session"
  14. "github.com/xtls/xray-core/common/signal/done"
  15. "github.com/xtls/xray-core/common/task"
  16. "github.com/xtls/xray-core/core"
  17. "github.com/xtls/xray-core/features/policy"
  18. "github.com/xtls/xray-core/features/routing"
  19. "github.com/xtls/xray-core/transport"
  20. "github.com/xtls/xray-core/transport/internet/stat"
  21. "github.com/xtls/xray-core/transport/pipe"
  22. "gvisor.dev/gvisor/pkg/buffer"
  23. "gvisor.dev/gvisor/pkg/tcpip"
  24. "gvisor.dev/gvisor/pkg/tcpip/checksum"
  25. "gvisor.dev/gvisor/pkg/tcpip/header"
  26. "gvisor.dev/gvisor/pkg/tcpip/stack"
  27. )
  28. type udpConn struct {
  29. lastActive atomic.Int64
  30. reader buf.Reader
  31. writer buf.Writer
  32. done *done.Instance
  33. cancel context.CancelFunc
  34. }
  35. // Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing
  36. type Handler struct {
  37. sync.Mutex
  38. ctx context.Context
  39. config *Config
  40. stack Stack
  41. policyManager policy.Manager
  42. dispatcher routing.Dispatcher
  43. udpConns map[net.Destination]*udpConn
  44. udpChecker *task.Periodic
  45. }
  46. // ConnectionHandler interface with the only method that stack is going to push new connections to
  47. type ConnectionHandler interface {
  48. HandleConnection(conn net.Conn, destination net.Destination)
  49. }
  50. // Handler implements ConnectionHandler
  51. var _ ConnectionHandler = (*Handler)(nil)
  52. func (t *Handler) policy() policy.Session {
  53. return t.policyManager.ForLevel(t.config.UserLevel)
  54. }
  55. func (t *Handler) cleanupUDP() error {
  56. t.Lock()
  57. defer t.Unlock()
  58. if len(t.udpConns) == 0 {
  59. return errors.New("no connections")
  60. }
  61. now := time.Now().Unix()
  62. for src, conn := range t.udpConns {
  63. if now-conn.lastActive.Load() > 300 {
  64. conn.cancel()
  65. common.Must(conn.done.Close())
  66. common.Must(common.Close(conn.writer))
  67. delete(t.udpConns, src)
  68. }
  69. }
  70. return nil
  71. }
  72. func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) {
  73. src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
  74. dest := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
  75. data := pkt.Data().AsRange().ToSlice()
  76. if len(data) == 0 {
  77. return
  78. }
  79. t.Lock()
  80. conn, found := t.udpConns[src]
  81. if !found {
  82. reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
  83. conn = &udpConn{reader: reader, writer: writer, done: done.New()}
  84. t.udpConns[src] = conn
  85. if t.udpChecker != nil && len(t.udpConns) == 1 {
  86. common.Must(t.udpChecker.Start())
  87. }
  88. t.Unlock()
  89. go func() {
  90. ctx, cancel := context.WithCancel(t.ctx)
  91. conn.cancel = cancel
  92. defer func() {
  93. cancel()
  94. t.Lock()
  95. delete(t.udpConns, src)
  96. t.Unlock()
  97. common.Must(conn.done.Close())
  98. common.Must(common.Close(conn.writer))
  99. }()
  100. inbound := &session.Inbound{
  101. Name: "tun",
  102. Source: src,
  103. CanSpliceCopy: 1,
  104. User: &protocol.MemoryUser{Level: t.config.UserLevel},
  105. }
  106. ctx = session.ContextWithInbound(c.ContextWithID(ctx, session.NewID()), inbound)
  107. ctx = session.SubContextFromMuxInbound(ctx)
  108. link := &transport.Link{
  109. Reader: &buf.TimeoutWrapperReader{Reader: conn.reader},
  110. Writer: &udpWriter{stack: ipStack, src: dest, dest: src},
  111. }
  112. t.dispatcher.DispatchLink(ctx, dest, link)
  113. }()
  114. } else {
  115. conn.lastActive.Store(time.Now().Unix())
  116. t.Unlock()
  117. }
  118. b := buf.New()
  119. b.Write(data)
  120. b.UDP = &dest
  121. conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
  122. }
  123. type udpWriter struct {
  124. stack *stack.Stack
  125. src net.Destination
  126. dest net.Destination
  127. }
  128. func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  129. for _, b := range mb {
  130. // Use b.UDP as source if available, otherwise use w.dest
  131. srcAddr := w.dest
  132. if b.UDP != nil {
  133. srcAddr = *b.UDP
  134. }
  135. // Validate address family matches
  136. if srcAddr.Address.Family() != w.src.Address.Family() {
  137. errors.LogWarning(context.Background(), "UDP return packet address family mismatch: expected ", w.src.Address.Family(), ", got ", srcAddr.Address.Family())
  138. b.Release()
  139. continue
  140. }
  141. netProto := header.IPv4ProtocolNumber
  142. if !w.src.Address.Family().IsIPv4() {
  143. netProto = header.IPv6ProtocolNumber
  144. }
  145. // Build route from actual response source to original client
  146. route, err := w.stack.FindRoute(
  147. defaultNIC,
  148. tcpip.AddrFromSlice(srcAddr.Address.IP()),
  149. tcpip.AddrFromSlice(w.src.Address.IP()),
  150. netProto,
  151. false,
  152. )
  153. if err != nil {
  154. b.Release()
  155. continue
  156. }
  157. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  158. ReserveHeaderBytes: header.UDPMinimumSize,
  159. Payload: buffer.MakeWithData(b.Bytes()),
  160. })
  161. udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
  162. udp.Encode(&header.UDPFields{
  163. SrcPort: uint16(srcAddr.Port),
  164. DstPort: uint16(w.src.Port),
  165. Length: uint16(pkt.Size()),
  166. })
  167. xsum := route.PseudoHeaderChecksum(header.UDPProtocolNumber, uint16(pkt.Size()))
  168. udp.SetChecksum(^udp.CalculateChecksum(checksum.Checksum(b.Bytes(), xsum)))
  169. route.WritePacket(stack.NetworkHeaderParams{
  170. Protocol: header.UDPProtocolNumber,
  171. TTL: 64,
  172. }, pkt)
  173. pkt.DecRef()
  174. route.Release()
  175. b.Release()
  176. }
  177. return nil
  178. }
  179. // Init the Handler instance with necessary parameters
  180. func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routing.Dispatcher) error {
  181. var err error
  182. t.ctx = core.ToBackgroundDetachedContext(ctx)
  183. t.policyManager = pm
  184. t.dispatcher = dispatcher
  185. t.udpConns = make(map[net.Destination]*udpConn)
  186. t.udpChecker = &task.Periodic{Interval: time.Minute, Execute: t.cleanupUDP}
  187. tunName := t.config.Name
  188. tunOptions := TunOptions{
  189. Name: tunName,
  190. MTU: t.config.MTU,
  191. }
  192. tunInterface, err := NewTun(tunOptions)
  193. if err != nil {
  194. return err
  195. }
  196. errors.LogInfo(t.ctx, tunName, " created")
  197. tunStackOptions := StackOptions{
  198. Tun: tunInterface,
  199. IdleTimeout: pm.ForLevel(t.config.UserLevel).Timeouts.ConnectionIdle,
  200. }
  201. tunStack, err := NewStack(t.ctx, tunStackOptions, t)
  202. if err != nil {
  203. _ = tunInterface.Close()
  204. return err
  205. }
  206. err = tunStack.Start()
  207. if err != nil {
  208. _ = tunStack.Close()
  209. _ = tunInterface.Close()
  210. return err
  211. }
  212. err = tunInterface.Start()
  213. if err != nil {
  214. _ = tunStack.Close()
  215. _ = tunInterface.Close()
  216. return err
  217. }
  218. t.stack = tunStack
  219. errors.LogInfo(t.ctx, tunName, " up")
  220. return nil
  221. }
  222. // HandleConnection pass the connection coming from the ip stack to the routing dispatcher
  223. func (t *Handler) HandleConnection(conn net.Conn, destination net.Destination) {
  224. sid := session.NewID()
  225. ctx := c.ContextWithID(t.ctx, sid)
  226. errors.LogInfo(ctx, "processing connection from: ", conn.RemoteAddr())
  227. inbound := session.Inbound{}
  228. inbound.Name = "tun"
  229. inbound.CanSpliceCopy = 1
  230. inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
  231. inbound.User = &protocol.MemoryUser{
  232. Level: t.config.UserLevel,
  233. }
  234. ctx = session.ContextWithInbound(ctx, &inbound)
  235. ctx = session.SubContextFromMuxInbound(ctx)
  236. link := &transport.Link{
  237. Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)},
  238. Writer: buf.NewWriter(conn),
  239. }
  240. if err := t.dispatcher.DispatchLink(ctx, destination, link); err != nil {
  241. errors.LogError(ctx, errors.New("connection closed").Base(err))
  242. return
  243. }
  244. errors.LogInfo(ctx, "connection completed")
  245. }
  246. // Network implements proxy.Inbound
  247. // and exists only to comply to proxy interface, declaring it doesn't listen on any network,
  248. // making the process not open any port for this inbound (input will be network interface)
  249. func (t *Handler) Network() []net.Network {
  250. return []net.Network{}
  251. }
  252. // Process implements proxy.Inbound
  253. // and exists only to comply to proxy interface, which should never get any inputs due to no listening ports
  254. func (t *Handler) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
  255. return nil
  256. }
  257. func init() {
  258. common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
  259. t := &Handler{config: config.(*Config)}
  260. err := core.RequireFeatures(ctx, func(pm policy.Manager, dispatcher routing.Dispatcher) error {
  261. return t.Init(ctx, pm, dispatcher)
  262. })
  263. return t, err
  264. }))
  265. }