obj.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. // Copyright 2018 Google LLC. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package nftables
  15. import (
  16. "encoding/binary"
  17. "fmt"
  18. "github.com/mdlayher/netlink"
  19. "golang.org/x/sys/unix"
  20. )
  21. var objHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ)
  22. // Obj represents a netfilter stateful object. See also
  23. // https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects
  24. type Obj interface {
  25. table() *Table
  26. family() TableFamily
  27. unmarshal(*netlink.AttributeDecoder) error
  28. marshal(data bool) ([]byte, error)
  29. }
  30. // AddObject adds the specified Obj. Alias of AddObj.
  31. func (cc *Conn) AddObject(o Obj) Obj {
  32. return cc.AddObj(o)
  33. }
  34. // AddObj adds the specified Obj. See also
  35. // https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects
  36. func (cc *Conn) AddObj(o Obj) Obj {
  37. cc.mu.Lock()
  38. defer cc.mu.Unlock()
  39. data, err := o.marshal(true)
  40. if err != nil {
  41. cc.setErr(err)
  42. return nil
  43. }
  44. cc.messages = append(cc.messages, netlink.Message{
  45. Header: netlink.Header{
  46. Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ),
  47. Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
  48. },
  49. Data: append(extraHeader(uint8(o.family()), 0), data...),
  50. })
  51. return o
  52. }
  53. // DeleteObject deletes the specified Obj
  54. func (cc *Conn) DeleteObject(o Obj) {
  55. cc.mu.Lock()
  56. defer cc.mu.Unlock()
  57. data, err := o.marshal(false)
  58. if err != nil {
  59. cc.setErr(err)
  60. return
  61. }
  62. data = append(data, cc.marshalAttr([]netlink.Attribute{{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA}})...)
  63. cc.messages = append(cc.messages, netlink.Message{
  64. Header: netlink.Header{
  65. Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ),
  66. Flags: netlink.Request | netlink.Acknowledge,
  67. },
  68. Data: append(extraHeader(uint8(o.family()), 0), data...),
  69. })
  70. }
  71. // GetObj is a legacy method that return all Obj that belongs
  72. // to the same table as the given one
  73. func (cc *Conn) GetObj(o Obj) ([]Obj, error) {
  74. return cc.getObj(nil, o.table(), unix.NFT_MSG_GETOBJ)
  75. }
  76. // GetObjReset is a legacy method that reset all Obj that belongs
  77. // the same table as the given one
  78. func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) {
  79. return cc.getObj(nil, o.table(), unix.NFT_MSG_GETOBJ_RESET)
  80. }
  81. // GetObject gets the specified Object
  82. func (cc *Conn) GetObject(o Obj) (Obj, error) {
  83. objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ)
  84. if len(objs) == 0 {
  85. return nil, err
  86. }
  87. return objs[0], err
  88. }
  89. // GetObjects get all the Obj that belongs to the given table
  90. func (cc *Conn) GetObjects(t *Table) ([]Obj, error) {
  91. return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ)
  92. }
  93. // ResetObject reset the given Obj
  94. func (cc *Conn) ResetObject(o Obj) (Obj, error) {
  95. objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ_RESET)
  96. if len(objs) == 0 {
  97. return nil, err
  98. }
  99. return objs[0], err
  100. }
  101. // ResetObjects reset all the Obj that belongs to the given table
  102. func (cc *Conn) ResetObjects(t *Table) ([]Obj, error) {
  103. return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ_RESET)
  104. }
  105. func objFromMsg(msg netlink.Message) (Obj, error) {
  106. if got, want := msg.Header.Type, objHeaderType; got != want {
  107. return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
  108. }
  109. ad, err := netlink.NewAttributeDecoder(msg.Data[4:])
  110. if err != nil {
  111. return nil, err
  112. }
  113. ad.ByteOrder = binary.BigEndian
  114. var (
  115. table *Table
  116. name string
  117. objectType uint32
  118. )
  119. const NFT_OBJECT_COUNTER = 1 // TODO: get into x/sys/unix
  120. for ad.Next() {
  121. switch ad.Type() {
  122. case unix.NFTA_OBJ_TABLE:
  123. table = &Table{Name: ad.String(), Family: TableFamily(msg.Data[0])}
  124. case unix.NFTA_OBJ_NAME:
  125. name = ad.String()
  126. case unix.NFTA_OBJ_TYPE:
  127. objectType = ad.Uint32()
  128. case unix.NFTA_OBJ_DATA:
  129. switch objectType {
  130. case NFT_OBJECT_COUNTER:
  131. o := CounterObj{
  132. Table: table,
  133. Name: name,
  134. }
  135. ad.Do(func(b []byte) error {
  136. ad, err := netlink.NewAttributeDecoder(b)
  137. if err != nil {
  138. return err
  139. }
  140. ad.ByteOrder = binary.BigEndian
  141. return o.unmarshal(ad)
  142. })
  143. return &o, ad.Err()
  144. }
  145. }
  146. }
  147. if err := ad.Err(); err != nil {
  148. return nil, err
  149. }
  150. return nil, fmt.Errorf("malformed stateful object")
  151. }
  152. func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) {
  153. conn, closer, err := cc.netlinkConn()
  154. if err != nil {
  155. return nil, err
  156. }
  157. defer func() { _ = closer() }()
  158. var data []byte
  159. var flags netlink.HeaderFlags
  160. if o != nil {
  161. data, err = o.marshal(false)
  162. } else {
  163. flags = netlink.Dump
  164. data, err = netlink.MarshalAttributes([]netlink.Attribute{
  165. {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
  166. })
  167. }
  168. if err != nil {
  169. return nil, err
  170. }
  171. message := netlink.Message{
  172. Header: netlink.Header{
  173. Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | msgType),
  174. Flags: netlink.Request | netlink.Acknowledge | flags,
  175. },
  176. Data: append(extraHeader(uint8(t.Family), 0), data...),
  177. }
  178. if _, err := conn.SendMessages([]netlink.Message{message}); err != nil {
  179. return nil, fmt.Errorf("SendMessages: %v", err)
  180. }
  181. reply, err := receiveAckAware(conn, message.Header.Flags)
  182. if err != nil {
  183. return nil, fmt.Errorf("Receive: %v", err)
  184. }
  185. var objs []Obj
  186. for _, msg := range reply {
  187. o, err := objFromMsg(msg)
  188. if err != nil {
  189. return nil, err
  190. }
  191. objs = append(objs, o)
  192. }
  193. return objs, nil
  194. }