rule.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. package rtnetlink
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "net"
  7. "github.com/jsimonetti/rtnetlink/internal/unix"
  8. "github.com/mdlayher/netlink"
  9. )
  var (
	// errInvalidRuleMessage is returned when a serialized RuleMessage is
	// shorter than its fixed 12-byte header or otherwise malformed.
	errInvalidRuleMessage = errors.New("rtnetlink RuleMessage is invalid or too short")

	// errInvalidRuleAttribute is returned when decoding meets an attribute
	// type that RuleAttributes has no handling for.
	errInvalidRuleAttribute = errors.New("rtnetlink RuleMessage contains an unknown Attribute")
)
  16. var _ Message = &RuleMessage{}
  17. // A RuleMessage is a route netlink link message.
  18. type RuleMessage struct {
  19. // Address family
  20. Family uint8
  21. // Length of destination prefix
  22. DstLength uint8
  23. // Length of source prefix
  24. SrcLength uint8
  25. // Rule TOS
  26. TOS uint8
  27. // Routing table identifier
  28. Table uint8
  29. // Rule action
  30. Action uint8
  31. // Rule flags
  32. Flags uint32
  33. // Attributes List
  34. Attributes *RuleAttributes
  35. }
  36. // MarshalBinary marshals a LinkMessage into a byte slice.
  37. func (m *RuleMessage) MarshalBinary() ([]byte, error) {
  38. b := make([]byte, 12)
  39. // fib_rule_hdr
  40. b[0] = m.Family
  41. b[1] = m.DstLength
  42. b[2] = m.SrcLength
  43. b[3] = m.TOS
  44. b[4] = m.Table
  45. b[7] = m.Action
  46. nativeEndian.PutUint32(b[8:12], m.Flags)
  47. if m.Attributes != nil {
  48. ae := netlink.NewAttributeEncoder()
  49. ae.ByteOrder = nativeEndian
  50. err := m.Attributes.encode(ae)
  51. if err != nil {
  52. return nil, err
  53. }
  54. a, err := ae.Encode()
  55. if err != nil {
  56. return nil, err
  57. }
  58. return append(b, a...), nil
  59. }
  60. return b, nil
  61. }
  62. // UnmarshalBinary unmarshals the contents of a byte slice into a LinkMessage.
  63. func (m *RuleMessage) UnmarshalBinary(b []byte) error {
  64. l := len(b)
  65. if l < 12 {
  66. return errInvalidRuleMessage
  67. }
  68. m.Family = b[0]
  69. m.DstLength = b[1]
  70. m.SrcLength = b[2]
  71. m.TOS = b[3]
  72. m.Table = b[4]
  73. // b[5] and b[6] are reserved fields
  74. m.Action = b[7]
  75. m.Flags = nativeEndian.Uint32(b[8:12])
  76. if l > 12 {
  77. m.Attributes = &RuleAttributes{}
  78. ad, err := netlink.NewAttributeDecoder(b[12:])
  79. if err != nil {
  80. return err
  81. }
  82. ad.ByteOrder = nativeEndian
  83. return m.Attributes.decode(ad)
  84. }
  85. return nil
  86. }
  87. // rtMessage is an empty method to sattisfy the Message interface.
  88. func (*RuleMessage) rtMessage() {}
  89. // RuleService is used to retrieve rtnetlink family information.
  90. type RuleService struct {
  91. c *Conn
  92. }
  93. func (r *RuleService) execute(m Message, family uint16, flags netlink.HeaderFlags) ([]RuleMessage, error) {
  94. msgs, err := r.c.Execute(m, family, flags)
  95. rules := make([]RuleMessage, len(msgs))
  96. for i := range msgs {
  97. rules[i] = *msgs[i].(*RuleMessage)
  98. }
  99. return rules, err
  100. }
  101. // Add new rule
  102. func (r *RuleService) Add(req *RuleMessage) error {
  103. flags := netlink.Request | netlink.Create | netlink.Acknowledge | netlink.Excl
  104. _, err := r.c.Execute(req, unix.RTM_NEWRULE, flags)
  105. return err
  106. }
  107. // Replace or add new rule
  108. func (r *RuleService) Replace(req *RuleMessage) error {
  109. flags := netlink.Request | netlink.Create | netlink.Replace | netlink.Acknowledge
  110. _, err := r.c.Execute(req, unix.RTM_NEWRULE, flags)
  111. return err
  112. }
  113. // Delete existing rule
  114. func (r *RuleService) Delete(req *RuleMessage) error {
  115. flags := netlink.Request | netlink.Acknowledge
  116. _, err := r.c.Execute(req, unix.RTM_DELRULE, flags)
  117. return err
  118. }
  119. // Get Rule(s)
  120. func (r *RuleService) Get(req *RuleMessage) ([]RuleMessage, error) {
  121. flags := netlink.Request | netlink.DumpFiltered
  122. return r.execute(req, unix.RTM_GETRULE, flags)
  123. }
  124. // List all rules
  125. func (r *RuleService) List() ([]RuleMessage, error) {
  126. flags := netlink.Request | netlink.Dump
  127. return r.execute(&RuleMessage{}, unix.RTM_GETRULE, flags)
  128. }
  129. // RuleAttributes contains all attributes for a rule.
  130. type RuleAttributes struct {
  131. Src, Dst *net.IP
  132. IIFName, OIFName *string
  133. Goto *uint32
  134. Priority *uint32
  135. FwMark, FwMask *uint32
  136. SrcRealm *uint16
  137. DstRealm *uint16
  138. TunID *uint64
  139. Table *uint32
  140. L3MDev *uint8
  141. Protocol *uint8
  142. IPProto *uint8
  143. SuppressPrefixLen *uint32
  144. SuppressIFGroup *uint32
  145. UIDRange *RuleUIDRange
  146. SPortRange *RulePortRange
  147. DPortRange *RulePortRange
  148. }
  149. // unmarshalBinary unmarshals the contents of a byte slice into a RuleMessage.
  150. func (r *RuleAttributes) decode(ad *netlink.AttributeDecoder) error {
  151. for ad.Next() {
  152. switch ad.Type() {
  153. case unix.FRA_UNSPEC:
  154. // unused
  155. continue
  156. case unix.FRA_DST:
  157. r.Dst = &net.IP{}
  158. ad.Do(decodeIP(r.Dst))
  159. case unix.FRA_SRC:
  160. r.Src = &net.IP{}
  161. ad.Do(decodeIP(r.Src))
  162. case unix.FRA_IIFNAME:
  163. v := ad.String()
  164. r.IIFName = &v
  165. case unix.FRA_GOTO:
  166. v := ad.Uint32()
  167. r.Goto = &v
  168. case unix.FRA_UNUSED2:
  169. // unused
  170. continue
  171. case unix.FRA_PRIORITY:
  172. v := ad.Uint32()
  173. r.Priority = &v
  174. case unix.FRA_UNUSED3:
  175. // unused
  176. continue
  177. case unix.FRA_UNUSED4:
  178. // unused
  179. continue
  180. case unix.FRA_UNUSED5:
  181. // unused
  182. continue
  183. case unix.FRA_FWMARK:
  184. v := ad.Uint32()
  185. r.FwMark = &v
  186. case unix.FRA_FLOW:
  187. dst32 := ad.Uint32()
  188. src32 := uint32(dst32 >> 16)
  189. src32 &= 0xFFFF
  190. dst32 &= 0xFFFF
  191. src16 := uint16(src32)
  192. dst16 := uint16(dst32)
  193. r.SrcRealm = &src16
  194. r.DstRealm = &dst16
  195. case unix.FRA_TUN_ID:
  196. v := ad.Uint64()
  197. r.TunID = &v
  198. case unix.FRA_SUPPRESS_IFGROUP:
  199. v := ad.Uint32()
  200. r.SuppressIFGroup = &v
  201. case unix.FRA_SUPPRESS_PREFIXLEN:
  202. v := ad.Uint32()
  203. r.SuppressPrefixLen = &v
  204. case unix.FRA_TABLE:
  205. v := ad.Uint32()
  206. r.Table = &v
  207. case unix.FRA_FWMASK:
  208. v := ad.Uint32()
  209. r.FwMask = &v
  210. case unix.FRA_OIFNAME:
  211. v := ad.String()
  212. r.OIFName = &v
  213. case unix.FRA_PAD:
  214. // unused
  215. continue
  216. case unix.FRA_L3MDEV:
  217. v := ad.Uint8()
  218. r.L3MDev = &v
  219. case unix.FRA_UID_RANGE:
  220. r.UIDRange = &RuleUIDRange{}
  221. err := r.UIDRange.unmarshalBinary(ad.Bytes())
  222. if err != nil {
  223. return err
  224. }
  225. case unix.FRA_PROTOCOL:
  226. v := ad.Uint8()
  227. r.Protocol = &v
  228. case unix.FRA_IP_PROTO:
  229. v := ad.Uint8()
  230. r.IPProto = &v
  231. case unix.FRA_SPORT_RANGE:
  232. r.SPortRange = &RulePortRange{}
  233. err := r.SPortRange.unmarshalBinary(ad.Bytes())
  234. if err != nil {
  235. return err
  236. }
  237. case unix.FRA_DPORT_RANGE:
  238. r.DPortRange = &RulePortRange{}
  239. err := r.DPortRange.unmarshalBinary(ad.Bytes())
  240. if err != nil {
  241. return err
  242. }
  243. default:
  244. return errInvalidRuleAttribute
  245. }
  246. }
  247. return ad.Err()
  248. }
  249. // MarshalBinary marshals a RuleAttributes into a byte slice.
  250. func (r *RuleAttributes) encode(ae *netlink.AttributeEncoder) error {
  251. if r.Table != nil {
  252. ae.Uint32(unix.FRA_TABLE, *r.Table)
  253. }
  254. if r.Protocol != nil {
  255. ae.Uint8(unix.FRA_PROTOCOL, *r.Protocol)
  256. }
  257. if r.Src != nil {
  258. ae.Do(unix.FRA_SRC, encodeIP(*r.Src))
  259. }
  260. if r.Dst != nil {
  261. ae.Do(unix.FRA_DST, encodeIP(*r.Dst))
  262. }
  263. if r.IIFName != nil {
  264. ae.String(unix.FRA_IIFNAME, *r.IIFName)
  265. }
  266. if r.OIFName != nil {
  267. ae.String(unix.FRA_OIFNAME, *r.OIFName)
  268. }
  269. if r.Goto != nil {
  270. ae.Uint32(unix.FRA_GOTO, *r.Goto)
  271. }
  272. if r.Priority != nil {
  273. ae.Uint32(unix.FRA_PRIORITY, *r.Priority)
  274. }
  275. if r.FwMark != nil {
  276. ae.Uint32(unix.FRA_FWMARK, *r.FwMark)
  277. }
  278. if r.FwMask != nil {
  279. ae.Uint32(unix.FRA_FWMASK, *r.FwMask)
  280. }
  281. if r.DstRealm != nil {
  282. value := uint32(*r.DstRealm)
  283. if r.SrcRealm != nil {
  284. value |= (uint32(*r.SrcRealm&0xFFFF) << 16)
  285. }
  286. ae.Uint32(unix.FRA_FLOW, value)
  287. }
  288. if r.TunID != nil {
  289. ae.Uint64(unix.FRA_TUN_ID, *r.TunID)
  290. }
  291. if r.L3MDev != nil {
  292. ae.Uint8(unix.FRA_L3MDEV, *r.L3MDev)
  293. }
  294. if r.IPProto != nil {
  295. ae.Uint8(unix.FRA_IP_PROTO, *r.IPProto)
  296. }
  297. if r.SuppressIFGroup != nil {
  298. ae.Uint32(unix.FRA_SUPPRESS_IFGROUP, *r.SuppressIFGroup)
  299. }
  300. if r.SuppressPrefixLen != nil {
  301. ae.Uint32(unix.FRA_SUPPRESS_PREFIXLEN, *r.SuppressPrefixLen)
  302. }
  303. if r.UIDRange != nil {
  304. data, err := marshalRuleUIDRange(*r.UIDRange)
  305. if err != nil {
  306. return err
  307. }
  308. ae.Bytes(unix.FRA_UID_RANGE, data)
  309. }
  310. if r.SPortRange != nil {
  311. data, err := marshalRulePortRange(*r.SPortRange)
  312. if err != nil {
  313. return err
  314. }
  315. ae.Bytes(unix.FRA_SPORT_RANGE, data)
  316. }
  317. if r.DPortRange != nil {
  318. data, err := marshalRulePortRange(*r.DPortRange)
  319. if err != nil {
  320. return err
  321. }
  322. ae.Bytes(unix.FRA_DPORT_RANGE, data)
  323. }
  324. return nil
  325. }
  // RulePortRange defines the start and end ports of a rule port match
// (FRA_SPORT_RANGE / FRA_DPORT_RANGE).
type RulePortRange struct {
	Start uint16
	End   uint16
}
  330. func (r *RulePortRange) unmarshalBinary(data []byte) error {
  331. b := bytes.NewReader(data)
  332. return binary.Read(b, nativeEndian, r)
  333. }
  334. func marshalRulePortRange(s RulePortRange) ([]byte, error) {
  335. var buf bytes.Buffer
  336. err := binary.Write(&buf, nativeEndian, s)
  337. return buf.Bytes(), err
  338. }
  // RuleUIDRange defines the start and end of a UID match (FRA_UID_RANGE).
//
// NOTE: the kernel's struct fib_rule_uid_range stores start and end as
// 32-bit values, while this type holds 16 bits each, so UIDs above 65535
// cannot be represented.
type RuleUIDRange struct {
	Start uint16
	End   uint16
}
  343. func (r *RuleUIDRange) unmarshalBinary(data []byte) error {
  344. b := bytes.NewReader(data)
  345. return binary.Read(b, nativeEndian, r)
  346. }
  347. func marshalRuleUIDRange(s RuleUIDRange) ([]byte, error) {
  348. var buf bytes.Buffer
  349. err := binary.Write(&buf, nativeEndian, s)
  350. return buf.Bytes(), err
  351. }