conn.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. // Copyright 2018 Google LLC. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package nftables
  15. import (
  16. "errors"
  17. "fmt"
  18. "sync"
  19. "github.com/google/nftables/binaryutil"
  20. "github.com/google/nftables/expr"
  21. "github.com/mdlayher/netlink"
  22. "github.com/mdlayher/netlink/nltest"
  23. "golang.org/x/sys/unix"
  24. )
  25. // A Conn represents a netlink connection of the nftables family.
  26. //
  27. // All methods return their input, so that variables can be defined from string
  28. // literals when desired.
  29. //
  30. // Commands are buffered. Flush sends all buffered commands in a single batch.
  31. type Conn struct {
  32. TestDial nltest.Func // for testing only; passed to nltest.Dial
  33. NetNS int // fd referencing the network namespace netlink will interact with.
  34. lasting bool // establish a lasting connection to be used across multiple netlink operations.
  35. mu sync.Mutex // protects the following state
  36. messages []netlink.Message
  37. err error
  38. nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol.
  39. }
  40. // ConnOption is an option to change the behavior of the nftables Conn returned by Open.
  41. type ConnOption func(*Conn)
  42. // New returns a netlink connection for querying and modifying nftables. Some
  43. // aspects of the new netlink connection can be configured using the options
  44. // WithNetNSFd, WithTestDial, and AsLasting.
  45. //
  46. // A lasting netlink connection should be closed by calling CloseLasting() to
  47. // close the underlying lasting netlink connection, cancelling all pending
  48. // operations using this connection.
  49. func New(opts ...ConnOption) (*Conn, error) {
  50. cc := &Conn{}
  51. for _, opt := range opts {
  52. opt(cc)
  53. }
  54. if !cc.lasting {
  55. return cc, nil
  56. }
  57. nlconn, err := cc.dialNetlink()
  58. if err != nil {
  59. return nil, err
  60. }
  61. cc.nlconn = nlconn
  62. return cc, nil
  63. }
  64. // AsLasting creates the new netlink connection as a lasting connection that is
  65. // reused across multiple netlink operations, instead of opening and closing the
  66. // underlying netlink connection only for the duration of a single netlink
  67. // operation.
  68. func AsLasting() ConnOption {
  69. return func(cc *Conn) {
  70. // We cannot create the underlying connection yet, as we are called
  71. // anywhere in the option processing chain and there might be later
  72. // options still modifying connection behavior.
  73. cc.lasting = true
  74. }
  75. }
  76. // WithNetNSFd sets the network namespace to create a new netlink connection to:
  77. // the fd must reference a network namespace.
  78. func WithNetNSFd(fd int) ConnOption {
  79. return func(cc *Conn) {
  80. cc.NetNS = fd
  81. }
  82. }
  83. // WithTestDial sets the specified nltest.Func when creating a new netlink
  84. // connection.
  85. func WithTestDial(f nltest.Func) ConnOption {
  86. return func(cc *Conn) {
  87. cc.TestDial = f
  88. }
  89. }
  90. // netlinkCloser is returned by netlinkConn(UnderLock) and must be called after
  91. // being done with the returned netlink connection in order to properly close
  92. // this connection, if necessary.
  93. type netlinkCloser func() error
  94. // netlinkConn returns a netlink connection together with a netlinkCloser that
  95. // later must be called by the caller when it doesn't need the returned netlink
  96. // connection anymore. The netlinkCloser will close the netlink connection when
  97. // necessary. If New has been told to create a lasting connection, then this
  98. // lasting netlink connection will be returned, otherwise a new "transient"
  99. // netlink connection will be opened and returned instead. netlinkConn must not
  100. // be called while the Conn.mu lock is currently helt (this will cause a
  101. // deadlock). Use netlinkConnUnderLock instead in such situations.
  102. func (cc *Conn) netlinkConn() (*netlink.Conn, netlinkCloser, error) {
  103. cc.mu.Lock()
  104. defer cc.mu.Unlock()
  105. return cc.netlinkConnUnderLock()
  106. }
  107. // netlinkConnUnderLock works like netlinkConn but must be called while holding
  108. // the Conn.mu lock.
  109. func (cc *Conn) netlinkConnUnderLock() (*netlink.Conn, netlinkCloser, error) {
  110. if cc.nlconn != nil {
  111. return cc.nlconn, func() error { return nil }, nil
  112. }
  113. nlconn, err := cc.dialNetlink()
  114. if err != nil {
  115. return nil, nil, err
  116. }
  117. return nlconn, func() error { return nlconn.Close() }, nil
  118. }
  119. func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]netlink.Message, error) {
  120. if nlconn == nil {
  121. return nil, errors.New("netlink conn is not initialized")
  122. }
  123. // first receive will be the message that we expect
  124. reply, err := nlconn.Receive()
  125. if err != nil {
  126. return nil, err
  127. }
  128. if (sentMsgFlags & netlink.Acknowledge) == 0 {
  129. // we did not request an ack
  130. return reply, nil
  131. }
  132. if (sentMsgFlags & netlink.Dump) == netlink.Dump {
  133. // sent message has Dump flag set, there will be no acks
  134. // https://github.com/torvalds/linux/blob/7e062cda7d90543ac8c7700fc7c5527d0c0f22ad/net/netlink/af_netlink.c#L2387-L2390
  135. return reply, nil
  136. }
  137. // Dump flag is not set, we expect an ack
  138. ack, err := nlconn.Receive()
  139. if err != nil {
  140. return nil, err
  141. }
  142. if len(ack) == 0 {
  143. return nil, errors.New("received an empty ack")
  144. }
  145. msg := ack[0]
  146. if msg.Header.Type != netlink.Error {
  147. // acks should be delivered as NLMSG_ERROR
  148. return nil, fmt.Errorf("expected header %v, but got %v", netlink.Error, msg.Header.Type)
  149. }
  150. if binaryutil.BigEndian.Uint32(msg.Data[:4]) != 0 {
  151. // if errno field is not set to 0 (success), this is an error
  152. return nil, fmt.Errorf("error delivered in message: %v", msg.Data)
  153. }
  154. return reply, nil
  155. }
  156. // CloseLasting closes the lasting netlink connection that has been opened using
  157. // AsLasting option when creating this connection. If either no lasting netlink
  158. // connection has been opened or the lasting connection is already in the
  159. // process of closing or has been closed, CloseLasting will immediately return
  160. // without any error.
  161. //
  162. // CloseLasting will terminate all pending netlink operations using the lasting
  163. // connection.
  164. //
  165. // After closing a lasting connection, the connection will revert to using
  166. // on-demand transient netlink connections when calling further netlink
  167. // operations (such as GetTables).
  168. func (cc *Conn) CloseLasting() error {
  169. // Don't acquire the lock for the whole duration of the CloseLasting
  170. // operation, but instead only so long as to make sure to only run the
  171. // netlink socket close on the first time with a lasting netlink socket. As
  172. // there is only the New() constructor, but no Open() method, it's
  173. // impossible to reopen a lasting connection.
  174. cc.mu.Lock()
  175. nlconn := cc.nlconn
  176. cc.nlconn = nil
  177. cc.mu.Unlock()
  178. if nlconn != nil {
  179. return nlconn.Close()
  180. }
  181. return nil
  182. }
  183. // Flush sends all buffered commands in a single batch to nftables.
  184. func (cc *Conn) Flush() error {
  185. cc.mu.Lock()
  186. defer func() {
  187. cc.messages = nil
  188. cc.mu.Unlock()
  189. }()
  190. if len(cc.messages) == 0 {
  191. // Messages were already programmed, returning nil
  192. return nil
  193. }
  194. if cc.err != nil {
  195. return cc.err // serialization error
  196. }
  197. conn, closer, err := cc.netlinkConnUnderLock()
  198. if err != nil {
  199. return err
  200. }
  201. defer func() { _ = closer() }()
  202. if _, err := conn.SendMessages(batch(cc.messages)); err != nil {
  203. return fmt.Errorf("SendMessages: %w", err)
  204. }
  205. // Fetch the requested acknowledgement for each message we sent.
  206. for _, msg := range cc.messages {
  207. if msg.Header.Flags&netlink.Acknowledge == 0 {
  208. continue // message did not request an acknowledgement
  209. }
  210. if _, err := conn.Receive(); err != nil {
  211. return fmt.Errorf("conn.Receive: %w", err)
  212. }
  213. }
  214. return nil
  215. }
  216. // FlushRuleset flushes the entire ruleset. See also
  217. // https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level
  218. func (cc *Conn) FlushRuleset() {
  219. cc.mu.Lock()
  220. defer cc.mu.Unlock()
  221. cc.messages = append(cc.messages, netlink.Message{
  222. Header: netlink.Header{
  223. Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
  224. Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
  225. },
  226. Data: extraHeader(0, 0),
  227. })
  228. }
  229. func (cc *Conn) dialNetlink() (*netlink.Conn, error) {
  230. if cc.TestDial != nil {
  231. return nltest.Dial(cc.TestDial), nil
  232. }
  233. return netlink.Dial(unix.NETLINK_NETFILTER, &netlink.Config{NetNS: cc.NetNS})
  234. }
  235. func (cc *Conn) setErr(err error) {
  236. if cc.err != nil {
  237. return
  238. }
  239. cc.err = err
  240. }
  241. func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
  242. b, err := netlink.MarshalAttributes(attrs)
  243. if err != nil {
  244. cc.setErr(err)
  245. return nil
  246. }
  247. return b
  248. }
  249. func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
  250. b, err := expr.Marshal(fam, e)
  251. if err != nil {
  252. cc.setErr(err)
  253. return nil
  254. }
  255. return b
  256. }
  257. func batch(messages []netlink.Message) []netlink.Message {
  258. batch := []netlink.Message{
  259. {
  260. Header: netlink.Header{
  261. Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
  262. Flags: netlink.Request,
  263. },
  264. Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
  265. },
  266. }
  267. batch = append(batch, messages...)
  268. batch = append(batch, netlink.Message{
  269. Header: netlink.Header{
  270. Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
  271. Flags: netlink.Request,
  272. },
  273. Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
  274. })
  275. return batch
  276. }