stack_gvisor.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. package tun
  2. import (
  3. "context"
  4. "time"
  5. "github.com/xtls/xray-core/common"
  6. "github.com/xtls/xray-core/common/buf"
  7. c "github.com/xtls/xray-core/common/ctx"
  8. "github.com/xtls/xray-core/common/errors"
  9. "github.com/xtls/xray-core/common/net"
  10. "github.com/xtls/xray-core/common/protocol"
  11. "github.com/xtls/xray-core/common/session"
  12. "github.com/xtls/xray-core/common/signal/done"
  13. "github.com/xtls/xray-core/transport"
  14. "github.com/xtls/xray-core/transport/pipe"
  15. "gvisor.dev/gvisor/pkg/buffer"
  16. "gvisor.dev/gvisor/pkg/tcpip"
  17. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  18. "gvisor.dev/gvisor/pkg/tcpip/checksum"
  19. "gvisor.dev/gvisor/pkg/tcpip/header"
  20. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  21. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  22. "gvisor.dev/gvisor/pkg/tcpip/stack"
  23. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  24. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  25. "gvisor.dev/gvisor/pkg/waiter"
  26. )
  27. const (
  28. defaultNIC tcpip.NICID = 1
  29. tcpRXBufMinSize = tcp.MinBufferSize
  30. tcpRXBufDefSize = tcp.DefaultSendBufferSize
  31. tcpRXBufMaxSize = 8 << 20 // 8MiB
  32. tcpTXBufMinSize = tcp.MinBufferSize
  33. tcpTXBufDefSize = tcp.DefaultReceiveBufferSize
  34. tcpTXBufMaxSize = 6 << 20 // 6MiB
  35. )
  36. // stackGVisor is ip stack implemented by gVisor package
  37. type stackGVisor struct {
  38. ctx context.Context
  39. tun GVisorTun
  40. idleTimeout time.Duration
  41. handler *Handler
  42. stack *stack.Stack
  43. endpoint stack.LinkEndpoint
  44. }
  45. // GVisorTun implements a bridge to connect gVisor ip stack to tun interface
  46. type GVisorTun interface {
  47. newEndpoint() (stack.LinkEndpoint, error)
  48. }
  // NewStack builds a new IP stack backed by gVisor. The stack is inert until
// Start is called. options.Tun must implement GVisorTun; if it does not (or
// is nil) an error is returned rather than panicking on the type assertion.
func NewStack(ctx context.Context, options StackOptions, handler *Handler) (Stack, error) {
	gTun, ok := options.Tun.(GVisorTun)
	if !ok {
		return nil, errors.New("tun device does not implement GVisorTun, cannot use gVisor stack")
	}
	return &stackGVisor{
		ctx:         ctx,
		tun:         gTun,
		idleTimeout: options.IdleTimeout,
		handler:     handler,
	}, nil
}
  59. // Start is called by Handler to bring stack to life
  60. func (t *stackGVisor) Start() error {
  61. linkEndpoint, err := t.tun.newEndpoint()
  62. if err != nil {
  63. return err
  64. }
  65. ipStack, err := createStack(linkEndpoint)
  66. if err != nil {
  67. return err
  68. }
  69. tcpForwarder := tcp.NewForwarder(ipStack, 0, 65535, func(r *tcp.ForwarderRequest) {
  70. go func(r *tcp.ForwarderRequest) {
  71. var wq waiter.Queue
  72. var id = r.ID()
  73. // Perform a TCP three-way handshake.
  74. ep, err := r.CreateEndpoint(&wq)
  75. if err != nil {
  76. errors.LogError(t.ctx, err.String())
  77. r.Complete(true)
  78. return
  79. }
  80. options := ep.SocketOptions()
  81. options.SetKeepAlive(false)
  82. options.SetReuseAddress(true)
  83. options.SetReusePort(true)
  84. t.handler.HandleConnection(
  85. gonet.NewTCPConn(&wq, ep),
  86. // local address on the gVisor side is connection destination
  87. net.TCPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)),
  88. )
  89. // close the socket
  90. ep.Close()
  91. // send connection complete upstream
  92. r.Complete(false)
  93. }(r)
  94. })
  95. ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
  96. // Use custom UDP packet handler instead of forwarder for FullCone NAT
  97. ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
  98. t.handler.HandleUDPPacket(id, pkt, ipStack)
  99. return true
  100. })
  101. t.stack = ipStack
  102. t.endpoint = linkEndpoint
  103. return nil
  104. }
  105. // Close is called by Handler to shut down the stack
  106. func (t *stackGVisor) Close() error {
  107. if t.stack == nil {
  108. return nil
  109. }
  110. t.endpoint.Attach(nil)
  111. t.stack.Close()
  112. for _, endpoint := range t.stack.CleanupEndpoints() {
  113. endpoint.Abort()
  114. }
  115. return nil
  116. }
  117. // createStack configure gVisor ip stack
  118. func createStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
  119. opts := stack.Options{
  120. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  121. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
  122. HandleLocal: false,
  123. }
  124. gStack := stack.New(opts)
  125. err := gStack.CreateNIC(defaultNIC, ep)
  126. if err != nil {
  127. return nil, errors.New(err.String())
  128. }
  129. gStack.SetRouteTable([]tcpip.Route{
  130. {Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
  131. {Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
  132. })
  133. err = gStack.SetSpoofing(defaultNIC, true)
  134. if err != nil {
  135. return nil, errors.New(err.String())
  136. }
  137. err = gStack.SetPromiscuousMode(defaultNIC, true)
  138. if err != nil {
  139. return nil, errors.New(err.String())
  140. }
  141. cOpt := tcpip.CongestionControlOption("cubic")
  142. gStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
  143. sOpt := tcpip.TCPSACKEnabled(true)
  144. gStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
  145. mOpt := tcpip.TCPModerateReceiveBufferOption(true)
  146. gStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
  147. tcpRXBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{
  148. Min: tcpRXBufMinSize,
  149. Default: tcpRXBufDefSize,
  150. Max: tcpRXBufMaxSize,
  151. }
  152. err = gStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRXBufOpt)
  153. if err != nil {
  154. return nil, errors.New(err.String())
  155. }
  156. tcpTXBufOpt := tcpip.TCPSendBufferSizeRangeOption{
  157. Min: tcpTXBufMinSize,
  158. Default: tcpTXBufDefSize,
  159. Max: tcpTXBufMaxSize,
  160. }
  161. err = gStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTXBufOpt)
  162. if err != nil {
  163. return nil, errors.New(err.String())
  164. }
  165. return gStack, nil
  166. }
  // HandleUDPPacket handles an incoming UDP packet from the gVisor stack for
// FullCone NAT. Packets are grouped by their source endpoint. The first packet
// from a source creates a udpConn whose pipe feeds a DispatchLink goroutine,
// and every packet (first or not) is queued into that pipe tagged with its
// destination. Replies are injected back into ipStack by a udpWriter.
// Packets with an empty payload are ignored.
func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) {
	// For an inbound packet the gVisor "remote" side is the sender (the tun
	// client) and the "local" side is the address the client is talking to.
	src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
	dest := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
	data := pkt.Data().AsRange().ToSlice()
	if len(data) == 0 {
		return
	}

	t.Lock()
	conn, found := t.udpConns[src]
	if found {
		conn.lastActive.Store(time.Now().Unix())
		t.Unlock()
	} else {
		reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
		// Create the cancel func before publishing conn in t.udpConns so that
		// anything finding conn in the map never observes a nil cancel.
		ctx, cancel := context.WithCancel(t.ctx)
		conn = &udpConn{reader: reader, writer: writer, done: done.New(), cancel: cancel}
		// The first packet counts as activity too; without this the conn starts
		// with a zero lastActive (presumably read by the idle checker — confirm).
		conn.lastActive.Store(time.Now().Unix())
		t.udpConns[src] = conn
		if t.udpChecker != nil && len(t.udpConns) == 1 {
			common.Must(t.udpChecker.Start())
		}
		t.Unlock()

		go func() {
			defer func() {
				cancel()
				t.Lock()
				// Only remove our own entry. If this conn was already dropped from
				// the map and src re-registered by a newer conn, leave that one alone.
				if t.udpConns[src] == conn {
					delete(t.udpConns, src)
				}
				t.Unlock()
				common.Must(conn.done.Close())
				common.Must(common.Close(conn.writer))
			}()
			inbound := &session.Inbound{
				Name:          "tun",
				Source:        src,
				CanSpliceCopy: 1,
				User:          &protocol.MemoryUser{Level: t.config.UserLevel},
			}
			ctx = session.ContextWithInbound(c.ContextWithID(ctx, session.NewID()), inbound)
			ctx = session.SubContextFromMuxInbound(ctx)
			link := &transport.Link{
				Reader: &buf.TimeoutWrapperReader{Reader: conn.reader},
				// udpWriter addresses replies to w.src and uses w.dest as their
				// default source, so src must be the client and dest the target.
				Writer: &udpWriter{stack: ipStack, src: src, dest: dest},
			}
			t.dispatcher.DispatchLink(ctx, dest, link)
		}()
	}

	// buf.New has a fixed capacity and Write silently truncates beyond it, so
	// size the buffer for the whole datagram.
	b := buf.NewWithSize(int32(len(data)))
	b.Write(data)
	b.UDP = &dest
	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
}
  219. type udpWriter struct {
  220. stack *stack.Stack
  221. src net.Destination
  222. dest net.Destination
  223. }
  224. func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  225. for _, b := range mb {
  226. // Use b.UDP as source if available, otherwise use w.dest
  227. srcAddr := w.dest
  228. if b.UDP != nil {
  229. srcAddr = *b.UDP
  230. }
  231. // Validate address family matches
  232. if srcAddr.Address.Family() != w.src.Address.Family() {
  233. errors.LogWarning(context.Background(), "UDP return packet address family mismatch: expected ", w.src.Address.Family(), ", got ", srcAddr.Address.Family())
  234. b.Release()
  235. continue
  236. }
  237. payload := b.Bytes()
  238. udpLen := header.UDPMinimumSize + len(payload)
  239. srcIP := tcpip.AddrFromSlice(srcAddr.Address.IP())
  240. dstIP := tcpip.AddrFromSlice(w.src.Address.IP())
  241. // Build packet with appropriate IP header size
  242. isIPv4 := w.src.Address.Family().IsIPv4()
  243. ipHdrSize := header.IPv6MinimumSize
  244. netProto := header.IPv6ProtocolNumber
  245. if isIPv4 {
  246. ipHdrSize = header.IPv4MinimumSize
  247. netProto = header.IPv4ProtocolNumber
  248. }
  249. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  250. ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
  251. Payload: buffer.MakeWithData(payload),
  252. })
  253. // Build UDP header
  254. udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
  255. udpHdr.Encode(&header.UDPFields{
  256. SrcPort: uint16(srcAddr.Port),
  257. DstPort: uint16(w.src.Port),
  258. Length: uint16(udpLen),
  259. })
  260. // Calculate and set UDP checksum
  261. xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
  262. udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
  263. // Build IP header
  264. if isIPv4 {
  265. ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
  266. ipHdr.Encode(&header.IPv4Fields{
  267. TotalLength: uint16(header.IPv4MinimumSize + udpLen),
  268. TTL: 64,
  269. Protocol: uint8(header.UDPProtocolNumber),
  270. SrcAddr: srcIP,
  271. DstAddr: dstIP,
  272. })
  273. ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
  274. } else {
  275. ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
  276. ipHdr.Encode(&header.IPv6Fields{
  277. PayloadLength: uint16(udpLen),
  278. TransportProtocol: header.UDPProtocolNumber,
  279. HopLimit: 64,
  280. SrcAddr: srcIP,
  281. DstAddr: dstIP,
  282. })
  283. }
  284. // Write raw packet to network stack
  285. views := pkt.AsSlices()
  286. var data []byte
  287. for _, view := range views {
  288. data = append(data, view...)
  289. }
  290. w.stack.WriteRawPacket(defaultNIC, netProto, buffer.MakeWithData(data))
  291. pkt.DecRef()
  292. b.Release()
  293. }
  294. return nil
  295. }