server.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package wireguard
  2. import (
  3. "context"
  4. goerrors "errors"
  5. "io"
  6. "github.com/xtls/xray-core/common"
  7. "github.com/xtls/xray-core/common/buf"
  8. c "github.com/xtls/xray-core/common/ctx"
  9. "github.com/xtls/xray-core/common/errors"
  10. "github.com/xtls/xray-core/common/log"
  11. "github.com/xtls/xray-core/common/net"
  12. "github.com/xtls/xray-core/common/session"
  13. "github.com/xtls/xray-core/common/signal"
  14. "github.com/xtls/xray-core/common/task"
  15. "github.com/xtls/xray-core/core"
  16. "github.com/xtls/xray-core/features/dns"
  17. "github.com/xtls/xray-core/features/policy"
  18. "github.com/xtls/xray-core/features/routing"
  19. "github.com/xtls/xray-core/transport/internet/stat"
  20. )
  21. var nullDestination = net.TCPDestination(net.AnyIP, 0)
  22. type Server struct {
  23. bindServer *netBindServer
  24. info routingInfo
  25. policyManager policy.Manager
  26. tag string
  27. sniffingRequest session.SniffingRequest
  28. }
  29. type routingInfo struct {
  30. ctx context.Context
  31. dispatcher routing.Dispatcher
  32. }
  33. func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
  34. v := core.MustFromContext(ctx)
  35. endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
  36. if err != nil {
  37. return nil, err
  38. }
  39. server := &Server{
  40. bindServer: &netBindServer{
  41. netBind: netBind{
  42. dns: v.GetFeature(dns.ClientType()).(dns.Client),
  43. dnsOption: dns.IPOption{
  44. IPv4Enable: hasIPv4,
  45. IPv6Enable: hasIPv6,
  46. },
  47. },
  48. },
  49. policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
  50. }
  51. // Retrieve tag and sniffing config from context (set by AlwaysOnInboundHandler)
  52. if inbound := session.InboundFromContext(ctx); inbound != nil {
  53. server.tag = inbound.Tag
  54. }
  55. if content := session.ContentFromContext(ctx); content != nil {
  56. server.sniffingRequest = content.SniffingRequest
  57. }
  58. tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
  59. if err != nil {
  60. return nil, err
  61. }
  62. if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil {
  63. _ = tun.Close()
  64. return nil, err
  65. }
  66. return server, nil
  67. }
  68. // Network implements proxy.Inbound.
  69. func (*Server) Network() []net.Network {
  70. return []net.Network{net.Network_UDP}
  71. }
  72. // Process implements proxy.Inbound.
  73. func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
  74. s.info = routingInfo{
  75. ctx: ctx,
  76. dispatcher: dispatcher,
  77. }
  78. ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
  79. if err != nil {
  80. return err
  81. }
  82. nep := ep.(*netEndpoint)
  83. nep.conn = conn
  84. reader := buf.NewPacketReader(conn)
  85. for {
  86. mpayload, err := reader.ReadMultiBuffer()
  87. if err != nil {
  88. return err
  89. }
  90. for _, payload := range mpayload {
  91. v, ok := <-s.bindServer.readQueue
  92. if !ok {
  93. return nil
  94. }
  95. i, err := payload.Read(v.buff)
  96. v.bytes = i
  97. v.endpoint = nep
  98. v.err = err
  99. v.waiter.Done()
  100. if err != nil && goerrors.Is(err, io.EOF) {
  101. nep.conn = nil
  102. return nil
  103. }
  104. }
  105. }
  106. }
  107. func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
  108. if s.info.dispatcher == nil {
  109. errors.LogError(s.info.ctx, "unexpected: dispatcher == nil")
  110. return
  111. }
  112. defer conn.Close()
  113. ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
  114. sid := session.NewID()
  115. ctx = c.ContextWithID(ctx, sid)
  116. inbound := session.Inbound{
  117. Name: "wireguard",
  118. Tag: s.tag,
  119. CanSpliceCopy: 3,
  120. // overwrite the source to use the tun address for each sub context.
  121. // Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context
  122. // Currently we have no way to link to the original source address
  123. Source: net.DestinationFromAddr(conn.RemoteAddr()),
  124. }
  125. ctx = session.ContextWithInbound(ctx, &inbound)
  126. ctx = session.ContextWithContent(ctx, &session.Content{
  127. SniffingRequest: s.sniffingRequest,
  128. })
  129. ctx = session.SubContextFromMuxInbound(ctx)
  130. plcy := s.policyManager.ForLevel(0)
  131. timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
  132. ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
  133. From: nullDestination,
  134. To: dest,
  135. Status: log.AccessAccepted,
  136. Reason: "",
  137. })
  138. link, err := s.info.dispatcher.Dispatch(ctx, dest)
  139. if err != nil {
  140. errors.LogErrorInner(ctx, err, "dispatch connection")
  141. }
  142. defer cancel()
  143. requestDone := func() error {
  144. defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
  145. if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
  146. return errors.New("failed to transport all TCP request").Base(err)
  147. }
  148. return nil
  149. }
  150. responseDone := func() error {
  151. defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
  152. if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
  153. return errors.New("failed to transport all TCP response").Base(err)
  154. }
  155. return nil
  156. }
  157. requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
  158. if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
  159. common.Interrupt(link.Reader)
  160. common.Interrupt(link.Writer)
  161. errors.LogDebugInner(ctx, err, "connection ends")
  162. return
  163. }
  164. }