router.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. package router
  2. import (
  3. "context"
  4. "sync"
  5. "github.com/xtls/xray-core/common"
  6. "github.com/xtls/xray-core/common/errors"
  7. "github.com/xtls/xray-core/common/serial"
  8. "github.com/xtls/xray-core/core"
  9. "github.com/xtls/xray-core/features/dns"
  10. "github.com/xtls/xray-core/features/outbound"
  11. "github.com/xtls/xray-core/features/routing"
  12. routing_dns "github.com/xtls/xray-core/features/routing/dns"
  13. )
  // Router is an implementation of routing.Router.
//
// It holds an ordered list of rules and a set of named balancers built from
// a Config. PickRoute turns the first rule that matches a routing context
// into a Route.
type Router struct {
	// domainStrategy decides when pickRouteInternal attaches a DNS client to
	// the routing context: Config_IpOnDemand before the first matching pass,
	// Config_IpIfNonMatch only for a second pass after nothing matched.
	domainStrategy Config_DomainStrategy
	// rules are evaluated in slice order by pickRouteInternal; the first
	// rule whose Apply returns true wins.
	rules []*Rule
	// balancers maps a balancing-rule tag to its built Balancer; rules are
	// linked to entries here through their balancing tag.
	balancers map[string]*Balancer
	// dns is attached to routing contexts via routing_dns.ContextWithDNSClient.
	dns dns.Client
	// ctx is the context given to Init; it is injected into every balancer.
	ctx context.Context
	// ohm and dispatcher are passed to BalancingRule.Build.
	ohm        outbound.Manager
	dispatcher routing.Dispatcher
	// mu protects rules and balancers; ReloadRules, RemoveRule and ListRule
	// hold it while touching them.
	mu sync.Mutex
}
  // Route is an implementation of routing.Route.
//
// It is the outcome of a routing decision: the routing context that was
// matched (embedded) plus the chosen outbound tag and the tag of the rule
// that produced it.
type Route struct {
	// Context is set by PickRoute to the context the rule matched on. Routes
	// built by ListRule leave it nil, so calling routing.Context methods on
	// them would panic.
	routing.Context
	// outboundGroupTags is returned by GetOutboundGroupTags.
	// NOTE(review): nothing in the code shown here populates it — confirm
	// whether another file does.
	outboundGroupTags []string
	// outboundTag is the tag of the selected outbound.
	outboundTag string
	// ruleTag is the RuleTag of the rule that produced this route.
	ruleTag string
}
  32. // Init initializes the Router.
  33. func (r *Router) Init(ctx context.Context, config *Config, d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error {
  34. r.domainStrategy = config.DomainStrategy
  35. r.dns = d
  36. r.ctx = ctx
  37. r.ohm = ohm
  38. r.dispatcher = dispatcher
  39. r.balancers = make(map[string]*Balancer, len(config.BalancingRule))
  40. for _, rule := range config.BalancingRule {
  41. balancer, err := rule.Build(ohm, dispatcher)
  42. if err != nil {
  43. return err
  44. }
  45. balancer.InjectContext(ctx)
  46. r.balancers[rule.Tag] = balancer
  47. }
  48. r.rules = make([]*Rule, 0, len(config.Rule))
  49. for _, rule := range config.Rule {
  50. cond, err := rule.BuildCondition()
  51. if err != nil {
  52. return err
  53. }
  54. rr := &Rule{
  55. Condition: cond,
  56. Tag: rule.GetTag(),
  57. RuleTag: rule.GetRuleTag(),
  58. }
  59. btag := rule.GetBalancingTag()
  60. if len(btag) > 0 {
  61. brule, found := r.balancers[btag]
  62. if !found {
  63. return errors.New("balancer ", btag, " not found")
  64. }
  65. rr.Balancer = brule
  66. }
  67. r.rules = append(r.rules, rr)
  68. }
  69. return nil
  70. }
  71. // PickRoute implements routing.Router.
  72. func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) {
  73. rule, ctx, err := r.pickRouteInternal(ctx)
  74. if err != nil {
  75. return nil, err
  76. }
  77. tag, err := rule.GetTag()
  78. if err != nil {
  79. return nil, err
  80. }
  81. return &Route{Context: ctx, outboundTag: tag, ruleTag: rule.RuleTag}, nil
  82. }
  83. // AddRule implements routing.Router.
  84. func (r *Router) AddRule(config *serial.TypedMessage, shouldAppend bool) error {
  85. inst, err := config.GetInstance()
  86. if err != nil {
  87. return err
  88. }
  89. if c, ok := inst.(*Config); ok {
  90. return r.ReloadRules(c, shouldAppend)
  91. }
  92. return errors.New("AddRule: config type error")
  93. }
  94. func (r *Router) ReloadRules(config *Config, shouldAppend bool) error {
  95. r.mu.Lock()
  96. defer r.mu.Unlock()
  97. if !shouldAppend {
  98. r.balancers = make(map[string]*Balancer, len(config.BalancingRule))
  99. r.rules = make([]*Rule, 0, len(config.Rule))
  100. }
  101. for _, rule := range config.BalancingRule {
  102. _, found := r.balancers[rule.Tag]
  103. if found {
  104. return errors.New("duplicate balancer tag")
  105. }
  106. balancer, err := rule.Build(r.ohm, r.dispatcher)
  107. if err != nil {
  108. return err
  109. }
  110. balancer.InjectContext(r.ctx)
  111. r.balancers[rule.Tag] = balancer
  112. }
  113. for _, rule := range config.Rule {
  114. if r.RuleExists(rule.GetRuleTag()) {
  115. return errors.New("duplicate ruleTag ", rule.GetRuleTag())
  116. }
  117. cond, err := rule.BuildCondition()
  118. if err != nil {
  119. return err
  120. }
  121. rr := &Rule{
  122. Condition: cond,
  123. Tag: rule.GetTag(),
  124. RuleTag: rule.GetRuleTag(),
  125. }
  126. btag := rule.GetBalancingTag()
  127. if len(btag) > 0 {
  128. brule, found := r.balancers[btag]
  129. if !found {
  130. return errors.New("balancer ", btag, " not found")
  131. }
  132. rr.Balancer = brule
  133. }
  134. r.rules = append(r.rules, rr)
  135. }
  136. return nil
  137. }
  138. func (r *Router) RuleExists(tag string) bool {
  139. if tag != "" {
  140. for _, rule := range r.rules {
  141. if rule.RuleTag == tag {
  142. return true
  143. }
  144. }
  145. }
  146. return false
  147. }
  148. // RemoveRule implements routing.Router.
  149. func (r *Router) RemoveRule(tag string) error {
  150. r.mu.Lock()
  151. defer r.mu.Unlock()
  152. newRules := []*Rule{}
  153. if tag != "" {
  154. for _, rule := range r.rules {
  155. if rule.RuleTag != tag {
  156. newRules = append(newRules, rule)
  157. }
  158. }
  159. r.rules = newRules
  160. return nil
  161. }
  162. return errors.New("empty tag name!")
  163. }
  164. // ListRule implements routing.Router
  165. func (r *Router) ListRule() []routing.Route {
  166. r.mu.Lock()
  167. defer r.mu.Unlock()
  168. ruleList := make([]routing.Route, 0)
  169. for _, rule := range r.rules {
  170. ruleList = append(ruleList, &Route{
  171. outboundTag: rule.Tag,
  172. ruleTag: rule.RuleTag,
  173. })
  174. }
  175. return ruleList
  176. }
  177. func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, routing.Context, error) {
  178. // SkipDNSResolve is set from DNS module.
  179. // the DOH remote server maybe a domain name,
  180. // this prevents cycle resolving dead loop
  181. skipDNSResolve := ctx.GetSkipDNSResolve()
  182. if r.domainStrategy == Config_IpOnDemand && !skipDNSResolve {
  183. ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
  184. }
  185. for _, rule := range r.rules {
  186. if rule.Apply(ctx) {
  187. return rule, ctx, nil
  188. }
  189. }
  190. if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 || skipDNSResolve {
  191. return nil, ctx, common.ErrNoClue
  192. }
  193. ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
  194. // Try applying rules again if we have IPs.
  195. for _, rule := range r.rules {
  196. if rule.Apply(ctx) {
  197. return rule, ctx, nil
  198. }
  199. }
  200. return nil, ctx, common.ErrNoClue
  201. }
  202. // Start implements common.Runnable.
  203. func (r *Router) Start() error {
  204. return nil
  205. }
  206. // Close implements common.Closable.
  207. func (r *Router) Close() error {
  208. return nil
  209. }
  210. // Type implements common.HasType.
  211. func (*Router) Type() interface{} {
  212. return routing.RouterType()
  213. }
  214. // GetOutboundGroupTags implements routing.Route.
  215. func (r *Route) GetOutboundGroupTags() []string {
  216. return r.outboundGroupTags
  217. }
  218. // GetOutboundTag implements routing.Route.
  219. func (r *Route) GetOutboundTag() string {
  220. return r.outboundTag
  221. }
  222. func (r *Route) GetRuleTag() string {
  223. return r.ruleTag
  224. }
  225. func init() {
  226. common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
  227. r := new(Router)
  228. if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error {
  229. return r.Init(ctx, config.(*Config), d, ohm, dispatcher)
  230. }); err != nil {
  231. return nil, err
  232. }
  233. return r, nil
  234. }))
  235. }