default.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. package dispatcher
  2. import (
  3. "context"
  4. go_errors "errors"
  5. "regexp"
  6. "strings"
  7. "sync"
  8. "time"
  9. "github.com/xtls/xray-core/common"
  10. "github.com/xtls/xray-core/common/buf"
  11. "github.com/xtls/xray-core/common/errors"
  12. "github.com/xtls/xray-core/common/log"
  13. "github.com/xtls/xray-core/common/net"
  14. "github.com/xtls/xray-core/common/protocol"
  15. "github.com/xtls/xray-core/common/session"
  16. "github.com/xtls/xray-core/core"
  17. "github.com/xtls/xray-core/features/dns"
  18. "github.com/xtls/xray-core/features/outbound"
  19. "github.com/xtls/xray-core/features/policy"
  20. "github.com/xtls/xray-core/features/routing"
  21. routing_session "github.com/xtls/xray-core/features/routing/session"
  22. "github.com/xtls/xray-core/features/stats"
  23. "github.com/xtls/xray-core/transport"
  24. "github.com/xtls/xray-core/transport/pipe"
  25. )
  26. var errSniffingTimeout = errors.New("timeout on sniffing")
  27. type cachedReader struct {
  28. sync.Mutex
  29. reader buf.TimeoutReader // *pipe.Reader or *buf.TimeoutWrapperReader
  30. cache buf.MultiBuffer
  31. }
  32. func (r *cachedReader) Cache(b *buf.Buffer, deadline time.Duration) error {
  33. mb, err := r.reader.ReadMultiBufferTimeout(deadline)
  34. if err != nil {
  35. return err
  36. }
  37. r.Lock()
  38. if !mb.IsEmpty() {
  39. r.cache, _ = buf.MergeMulti(r.cache, mb)
  40. }
  41. b.Clear()
  42. rawBytes := b.Extend(min(r.cache.Len(), b.Cap()))
  43. n := r.cache.Copy(rawBytes)
  44. b.Resize(0, int32(n))
  45. r.Unlock()
  46. return nil
  47. }
  48. func (r *cachedReader) readInternal() buf.MultiBuffer {
  49. r.Lock()
  50. defer r.Unlock()
  51. if r.cache != nil && !r.cache.IsEmpty() {
  52. mb := r.cache
  53. r.cache = nil
  54. return mb
  55. }
  56. return nil
  57. }
  58. func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
  59. mb := r.readInternal()
  60. if mb != nil {
  61. return mb, nil
  62. }
  63. return r.reader.ReadMultiBuffer()
  64. }
  65. func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
  66. mb := r.readInternal()
  67. if mb != nil {
  68. return mb, nil
  69. }
  70. return r.reader.ReadMultiBufferTimeout(timeout)
  71. }
  72. func (r *cachedReader) Interrupt() {
  73. r.Lock()
  74. if r.cache != nil {
  75. r.cache = buf.ReleaseMulti(r.cache)
  76. }
  77. r.Unlock()
  78. if p, ok := r.reader.(*pipe.Reader); ok {
  79. p.Interrupt()
  80. }
  81. }
  82. // DefaultDispatcher is a default implementation of Dispatcher.
  83. type DefaultDispatcher struct {
  84. ohm outbound.Manager
  85. router routing.Router
  86. policy policy.Manager
  87. stats stats.Manager
  88. fdns dns.FakeDNSEngine
  89. }
  90. func init() {
  91. common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
  92. d := new(DefaultDispatcher)
  93. if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
  94. core.OptionalFeatures(ctx, func(fdns dns.FakeDNSEngine) {
  95. d.fdns = fdns
  96. })
  97. return d.Init(config.(*Config), om, router, pm, sm)
  98. }); err != nil {
  99. return nil, err
  100. }
  101. return d, nil
  102. }))
  103. }
  104. // Init initializes DefaultDispatcher.
  105. func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error {
  106. d.ohm = om
  107. d.router = router
  108. d.policy = pm
  109. d.stats = sm
  110. return nil
  111. }
  112. // Type implements common.HasType.
  113. func (*DefaultDispatcher) Type() interface{} {
  114. return routing.DispatcherType()
  115. }
  116. // Start implements common.Runnable.
  117. func (*DefaultDispatcher) Start() error {
  118. return nil
  119. }
  120. // Close implements common.Closable.
  121. func (*DefaultDispatcher) Close() error { return nil }
  122. func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
  123. opt := pipe.OptionsFromContext(ctx)
  124. uplinkReader, uplinkWriter := pipe.New(opt...)
  125. downlinkReader, downlinkWriter := pipe.New(opt...)
  126. inboundLink := &transport.Link{
  127. Reader: downlinkReader,
  128. Writer: uplinkWriter,
  129. }
  130. outboundLink := &transport.Link{
  131. Reader: uplinkReader,
  132. Writer: downlinkWriter,
  133. }
  134. sessionInbound := session.InboundFromContext(ctx)
  135. var user *protocol.MemoryUser
  136. if sessionInbound != nil {
  137. user = sessionInbound.User
  138. }
  139. if user != nil && len(user.Email) > 0 {
  140. p := d.policy.ForLevel(user.Level)
  141. if p.Stats.UserUplink {
  142. name := "user>>>" + user.Email + ">>>traffic>>>uplink"
  143. if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
  144. inboundLink.Writer = &SizeStatWriter{
  145. Counter: c,
  146. Writer: inboundLink.Writer,
  147. }
  148. }
  149. }
  150. if p.Stats.UserDownlink {
  151. name := "user>>>" + user.Email + ">>>traffic>>>downlink"
  152. if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
  153. outboundLink.Writer = &SizeStatWriter{
  154. Counter: c,
  155. Writer: outboundLink.Writer,
  156. }
  157. }
  158. }
  159. if p.Stats.UserOnline {
  160. name := "user>>>" + user.Email + ">>>online"
  161. if om, _ := stats.GetOrRegisterOnlineMap(d.stats, name); om != nil {
  162. sessionInbounds := session.InboundFromContext(ctx)
  163. userIP := sessionInbounds.Source.Address.String()
  164. om.AddIP(userIP)
  165. // log Online user with ips
  166. // errors.LogDebug(ctx, "user>>>" + user.Email + ">>>online", om.Count(), om.List())
  167. }
  168. }
  169. }
  170. return inboundLink, outboundLink
  171. }
  172. func (d *DefaultDispatcher) WrapLink(ctx context.Context, link *transport.Link) *transport.Link {
  173. sessionInbound := session.InboundFromContext(ctx)
  174. var user *protocol.MemoryUser
  175. if sessionInbound != nil {
  176. user = sessionInbound.User
  177. }
  178. link.Reader = &buf.TimeoutWrapperReader{Reader: link.Reader}
  179. if user != nil && len(user.Email) > 0 {
  180. p := d.policy.ForLevel(user.Level)
  181. if p.Stats.UserUplink {
  182. name := "user>>>" + user.Email + ">>>traffic>>>uplink"
  183. if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
  184. link.Reader.(*buf.TimeoutWrapperReader).Counter = c
  185. }
  186. }
  187. if p.Stats.UserDownlink {
  188. name := "user>>>" + user.Email + ">>>traffic>>>downlink"
  189. if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
  190. link.Writer = &SizeStatWriter{
  191. Counter: c,
  192. Writer: link.Writer,
  193. }
  194. }
  195. }
  196. if p.Stats.UserOnline {
  197. name := "user>>>" + user.Email + ">>>online"
  198. if om, _ := stats.GetOrRegisterOnlineMap(d.stats, name); om != nil {
  199. sessionInbounds := session.InboundFromContext(ctx)
  200. userIP := sessionInbounds.Source.Address.String()
  201. om.AddIP(userIP)
  202. // log Online user with ips
  203. // errors.LogDebug(ctx, "user>>>" + user.Email + ">>>online", om.Count(), om.List())
  204. }
  205. }
  206. }
  207. return link
  208. }
  209. func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
  210. domain := result.Domain()
  211. if domain == "" {
  212. return false
  213. }
  214. for _, d := range request.ExcludeForDomain {
  215. if strings.HasPrefix(d, "regexp:") {
  216. pattern := d[7:]
  217. re, err := regexp.Compile(pattern)
  218. if err != nil {
  219. errors.LogInfo(ctx, "Unable to compile regex")
  220. continue
  221. }
  222. if re.MatchString(domain) {
  223. return false
  224. }
  225. } else {
  226. if strings.ToLower(domain) == d {
  227. return false
  228. }
  229. }
  230. }
  231. protocolString := result.Protocol()
  232. if resComp, ok := result.(SnifferResultComposite); ok {
  233. protocolString = resComp.ProtocolForDomainResult()
  234. }
  235. for _, p := range request.OverrideDestinationForProtocol {
  236. if strings.HasPrefix(protocolString, p) || strings.HasPrefix(p, protocolString) {
  237. return true
  238. }
  239. if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
  240. fkr0.IsIPInIPPool(destination.Address) {
  241. errors.LogInfo(ctx, "Using sniffer ", protocolString, " since the fake DNS missed")
  242. return true
  243. }
  244. if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok {
  245. if resultSubset.IsProtoSubsetOf(p) {
  246. return true
  247. }
  248. }
  249. }
  250. return false
  251. }
  252. // Dispatch implements routing.Dispatcher.
  253. func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) {
  254. if !destination.IsValid() {
  255. panic("Dispatcher: Invalid destination.")
  256. }
  257. outbounds := session.OutboundsFromContext(ctx)
  258. if len(outbounds) == 0 {
  259. outbounds = []*session.Outbound{{}}
  260. ctx = session.ContextWithOutbounds(ctx, outbounds)
  261. }
  262. ob := outbounds[len(outbounds)-1]
  263. ob.OriginalTarget = destination
  264. ob.Target = destination
  265. content := session.ContentFromContext(ctx)
  266. if content == nil {
  267. content = new(session.Content)
  268. ctx = session.ContextWithContent(ctx, content)
  269. }
  270. sniffingRequest := content.SniffingRequest
  271. inbound, outbound := d.getLink(ctx)
  272. if !sniffingRequest.Enabled {
  273. go d.routedDispatch(ctx, outbound, destination)
  274. } else {
  275. go func() {
  276. cReader := &cachedReader{
  277. reader: outbound.Reader.(*pipe.Reader),
  278. }
  279. outbound.Reader = cReader
  280. result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
  281. if err == nil {
  282. content.Protocol = result.Protocol()
  283. }
  284. if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
  285. domain := result.Domain()
  286. errors.LogInfo(ctx, "sniffed domain: ", domain)
  287. destination.Address = net.ParseAddress(domain)
  288. protocol := result.Protocol()
  289. if resComp, ok := result.(SnifferResultComposite); ok {
  290. protocol = resComp.ProtocolForDomainResult()
  291. }
  292. isFakeIP := false
  293. if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(ob.Target.Address) {
  294. isFakeIP = true
  295. }
  296. if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
  297. ob.RouteTarget = destination
  298. } else {
  299. ob.Target = destination
  300. }
  301. }
  302. d.routedDispatch(ctx, outbound, destination)
  303. }()
  304. }
  305. return inbound, nil
  306. }
  307. // DispatchLink implements routing.Dispatcher.
  308. func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
  309. if !destination.IsValid() {
  310. return errors.New("Dispatcher: Invalid destination.")
  311. }
  312. outbounds := session.OutboundsFromContext(ctx)
  313. if len(outbounds) == 0 {
  314. outbounds = []*session.Outbound{{}}
  315. ctx = session.ContextWithOutbounds(ctx, outbounds)
  316. }
  317. ob := outbounds[len(outbounds)-1]
  318. ob.OriginalTarget = destination
  319. ob.Target = destination
  320. content := session.ContentFromContext(ctx)
  321. if content == nil {
  322. content = new(session.Content)
  323. ctx = session.ContextWithContent(ctx, content)
  324. }
  325. outbound = d.WrapLink(ctx, outbound)
  326. sniffingRequest := content.SniffingRequest
  327. if !sniffingRequest.Enabled {
  328. d.routedDispatch(ctx, outbound, destination)
  329. } else {
  330. cReader := &cachedReader{
  331. reader: outbound.Reader.(buf.TimeoutReader),
  332. }
  333. outbound.Reader = cReader
  334. result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
  335. if err == nil {
  336. content.Protocol = result.Protocol()
  337. }
  338. if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
  339. domain := result.Domain()
  340. errors.LogInfo(ctx, "sniffed domain: ", domain)
  341. destination.Address = net.ParseAddress(domain)
  342. protocol := result.Protocol()
  343. if resComp, ok := result.(SnifferResultComposite); ok {
  344. protocol = resComp.ProtocolForDomainResult()
  345. }
  346. isFakeIP := false
  347. if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(ob.Target.Address) {
  348. isFakeIP = true
  349. }
  350. if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
  351. ob.RouteTarget = destination
  352. } else {
  353. ob.Target = destination
  354. }
  355. }
  356. d.routedDispatch(ctx, outbound, destination)
  357. }
  358. return nil
  359. }
  360. func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
  361. payload := buf.NewWithSize(32767)
  362. defer payload.Release()
  363. sniffer := NewSniffer(ctx)
  364. metaresult, metadataErr := sniffer.SniffMetadata(ctx)
  365. if metadataOnly {
  366. return metaresult, metadataErr
  367. }
  368. contentResult, contentErr := func() (SniffResult, error) {
  369. cacheDeadline := 200 * time.Millisecond
  370. totalAttempt := 0
  371. for {
  372. select {
  373. case <-ctx.Done():
  374. return nil, ctx.Err()
  375. default:
  376. cachingStartingTimeStamp := time.Now()
  377. err := cReader.Cache(payload, cacheDeadline)
  378. if err != nil {
  379. return nil, err
  380. }
  381. cachingTimeElapsed := time.Since(cachingStartingTimeStamp)
  382. cacheDeadline -= cachingTimeElapsed
  383. if !payload.IsEmpty() {
  384. result, err := sniffer.Sniff(ctx, payload.Bytes(), network)
  385. switch err {
  386. case common.ErrNoClue: // No Clue: protocol not matches, and sniffer cannot determine whether there will be a match or not
  387. totalAttempt++
  388. case protocol.ErrProtoNeedMoreData: // Protocol Need More Data: protocol matches, but need more data to complete sniffing
  389. // in this case, do not add totalAttempt(allow to read until timeout)
  390. default:
  391. return result, err
  392. }
  393. } else {
  394. totalAttempt++
  395. }
  396. if totalAttempt >= 2 || cacheDeadline <= 0 {
  397. return nil, errSniffingTimeout
  398. }
  399. }
  400. }
  401. }()
  402. if contentErr != nil && metadataErr == nil {
  403. return metaresult, nil
  404. }
  405. if contentErr == nil && metadataErr == nil {
  406. return CompositeResult(metaresult, contentResult), nil
  407. }
  408. return contentResult, contentErr
  409. }
  410. func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
  411. outbounds := session.OutboundsFromContext(ctx)
  412. ob := outbounds[len(outbounds)-1]
  413. var handler outbound.Handler
  414. routingLink := routing_session.AsRoutingContext(ctx)
  415. inTag := routingLink.GetInboundTag()
  416. isPickRoute := 0
  417. if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
  418. ctx = session.SetForcedOutboundTagToContext(ctx, "")
  419. if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
  420. isPickRoute = 1
  421. errors.LogInfo(ctx, "taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]")
  422. handler = h
  423. } else {
  424. errors.LogError(ctx, "non existing tag for platform initialized detour: ", forcedOutboundTag)
  425. common.Close(link.Writer)
  426. common.Interrupt(link.Reader)
  427. return
  428. }
  429. } else if d.router != nil {
  430. if route, err := d.router.PickRoute(routingLink); err == nil {
  431. outTag := route.GetOutboundTag()
  432. if h := d.ohm.GetHandler(outTag); h != nil {
  433. isPickRoute = 2
  434. if route.GetRuleTag() == "" {
  435. errors.LogInfo(ctx, "taking detour [", outTag, "] for [", destination, "]")
  436. } else {
  437. errors.LogInfo(ctx, "Hit route rule: [", route.GetRuleTag(), "] so taking detour [", outTag, "] for [", destination, "]")
  438. }
  439. handler = h
  440. } else {
  441. errors.LogWarning(ctx, "non existing outTag: ", outTag)
  442. common.Close(link.Writer)
  443. common.Interrupt(link.Reader)
  444. return // DO NOT CHANGE: the traffic shouldn't be processed by default outbound if the specified outbound tag doesn't exist (yet), e.g., VLESS Reverse Proxy
  445. }
  446. } else {
  447. if !go_errors.Is(err, common.ErrNoClue) {
  448. errors.LogWarningInner(ctx, err, "get error during route pick ")
  449. common.Close(link.Writer)
  450. common.Interrupt(link.Reader)
  451. return
  452. }
  453. errors.LogInfo(ctx, "default route for ", destination)
  454. }
  455. }
  456. if handler == nil {
  457. handler = d.ohm.GetDefaultHandler()
  458. }
  459. if handler == nil {
  460. errors.LogInfo(ctx, "default outbound handler not exist")
  461. common.Close(link.Writer)
  462. common.Interrupt(link.Reader)
  463. return
  464. }
  465. ob.Tag = handler.Tag()
  466. if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
  467. if tag := handler.Tag(); tag != "" {
  468. if inTag == "" {
  469. accessMessage.Detour = tag
  470. } else if isPickRoute == 1 {
  471. accessMessage.Detour = inTag + " ==> " + tag
  472. } else if isPickRoute == 2 {
  473. accessMessage.Detour = inTag + " -> " + tag
  474. } else {
  475. accessMessage.Detour = inTag + " >> " + tag
  476. }
  477. }
  478. log.Record(accessMessage)
  479. }
  480. handler.Dispatch(ctx, link)
  481. }