server.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. package wireguard
  2. import (
  3. "context"
  4. goerrors "errors"
  5. "io"
  6. "github.com/xtls/xray-core/common/buf"
  7. c "github.com/xtls/xray-core/common/ctx"
  8. "github.com/xtls/xray-core/common/errors"
  9. "github.com/xtls/xray-core/common/log"
  10. "github.com/xtls/xray-core/common/net"
  11. "github.com/xtls/xray-core/common/session"
  12. "github.com/xtls/xray-core/core"
  13. "github.com/xtls/xray-core/features/dns"
  14. "github.com/xtls/xray-core/features/policy"
  15. "github.com/xtls/xray-core/features/routing"
  16. "github.com/xtls/xray-core/transport"
  17. "github.com/xtls/xray-core/transport/internet/stat"
  18. )
  19. var nullDestination = net.TCPDestination(net.AnyIP, 0)
  20. type Server struct {
  21. bindServer *netBindServer
  22. info routingInfo
  23. policyManager policy.Manager
  24. }
  25. type routingInfo struct {
  26. ctx context.Context
  27. dispatcher routing.Dispatcher
  28. inboundTag *session.Inbound
  29. contentTag *session.Content
  30. }
  31. func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
  32. v := core.MustFromContext(ctx)
  33. endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
  34. if err != nil {
  35. return nil, err
  36. }
  37. server := &Server{
  38. bindServer: &netBindServer{
  39. netBind: netBind{
  40. dns: v.GetFeature(dns.ClientType()).(dns.Client),
  41. dnsOption: dns.IPOption{
  42. IPv4Enable: hasIPv4,
  43. IPv6Enable: hasIPv6,
  44. },
  45. },
  46. },
  47. policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
  48. }
  49. tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
  50. if err != nil {
  51. return nil, err
  52. }
  53. if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil {
  54. _ = tun.Close()
  55. return nil, err
  56. }
  57. return server, nil
  58. }
  59. // Network implements proxy.Inbound.
  60. func (*Server) Network() []net.Network {
  61. return []net.Network{net.Network_UDP}
  62. }
  63. // Process implements proxy.Inbound.
  64. func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
  65. s.info = routingInfo{
  66. ctx: ctx,
  67. dispatcher: dispatcher,
  68. inboundTag: session.InboundFromContext(ctx),
  69. contentTag: session.ContentFromContext(ctx),
  70. }
  71. ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
  72. if err != nil {
  73. return err
  74. }
  75. nep := ep.(*netEndpoint)
  76. nep.conn = conn
  77. reader := buf.NewPacketReader(conn)
  78. for {
  79. mpayload, err := reader.ReadMultiBuffer()
  80. if err != nil {
  81. return err
  82. }
  83. for _, payload := range mpayload {
  84. v, ok := <-s.bindServer.readQueue
  85. if !ok {
  86. return nil
  87. }
  88. i, err := payload.Read(v.buff)
  89. v.bytes = i
  90. v.endpoint = nep
  91. v.err = err
  92. v.waiter.Done()
  93. if err != nil && goerrors.Is(err, io.EOF) {
  94. nep.conn = nil
  95. return nil
  96. }
  97. }
  98. }
  99. }
  100. func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
  101. if s.info.dispatcher == nil {
  102. errors.LogError(s.info.ctx, "unexpected: dispatcher == nil")
  103. return
  104. }
  105. ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
  106. sid := session.NewID()
  107. ctx = c.ContextWithID(ctx, sid)
  108. inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs)
  109. if s.info.inboundTag != nil {
  110. inbound = *s.info.inboundTag
  111. }
  112. inbound.Name = "wireguard"
  113. inbound.CanSpliceCopy = 3
  114. // overwrite the source to use the tun address for each sub context.
  115. // Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context
  116. // Currently we have no way to link to the original source address
  117. inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
  118. ctx = session.ContextWithInbound(ctx, &inbound)
  119. if s.info.contentTag != nil {
  120. ctx = session.ContextWithContent(ctx, s.info.contentTag)
  121. }
  122. ctx = session.SubContextFromMuxInbound(ctx)
  123. ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
  124. From: nullDestination,
  125. To: dest,
  126. Status: log.AccessAccepted,
  127. Reason: "",
  128. })
  129. err := s.info.dispatcher.DispatchLink(ctx, dest, &transport.Link{
  130. Reader: buf.NewReader(conn),
  131. Writer: buf.NewWriter(conn),
  132. })
  133. if err != nil {
  134. errors.LogInfoInner(ctx, err, "connection ends")
  135. }
  136. cancel()
  137. conn.Close()
  138. }