| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315 |
- // Copyright 2018 Google LLC. All Rights Reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package nftables
- import (
- "errors"
- "fmt"
- "sync"
- "github.com/google/nftables/binaryutil"
- "github.com/google/nftables/expr"
- "github.com/mdlayher/netlink"
- "github.com/mdlayher/netlink/nltest"
- "golang.org/x/sys/unix"
- )
- // A Conn represents a netlink connection of the nftables family.
- //
- // All methods return their input, so that variables can be defined from string
- // literals when desired.
- //
- // Commands are buffered. Flush sends all buffered commands in a single batch.
- type Conn struct {
- TestDial nltest.Func // for testing only; passed to nltest.Dial
- NetNS int // fd referencing the network namespace netlink will interact with.
- lasting bool // establish a lasting connection to be used across multiple netlink operations.
- mu sync.Mutex // protects the following state
- messages []netlink.Message
- err error
- nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol.
- }
- // ConnOption is an option to change the behavior of the nftables Conn returned by Open.
- type ConnOption func(*Conn)
- // New returns a netlink connection for querying and modifying nftables. Some
- // aspects of the new netlink connection can be configured using the options
- // WithNetNSFd, WithTestDial, and AsLasting.
- //
- // A lasting netlink connection should be closed by calling CloseLasting() to
- // close the underlying lasting netlink connection, cancelling all pending
- // operations using this connection.
- func New(opts ...ConnOption) (*Conn, error) {
- cc := &Conn{}
- for _, opt := range opts {
- opt(cc)
- }
- if !cc.lasting {
- return cc, nil
- }
- nlconn, err := cc.dialNetlink()
- if err != nil {
- return nil, err
- }
- cc.nlconn = nlconn
- return cc, nil
- }
- // AsLasting creates the new netlink connection as a lasting connection that is
- // reused across multiple netlink operations, instead of opening and closing the
- // underlying netlink connection only for the duration of a single netlink
- // operation.
- func AsLasting() ConnOption {
- return func(cc *Conn) {
- // We cannot create the underlying connection yet, as we are called
- // anywhere in the option processing chain and there might be later
- // options still modifying connection behavior.
- cc.lasting = true
- }
- }
- // WithNetNSFd sets the network namespace to create a new netlink connection to:
- // the fd must reference a network namespace.
- func WithNetNSFd(fd int) ConnOption {
- return func(cc *Conn) {
- cc.NetNS = fd
- }
- }
- // WithTestDial sets the specified nltest.Func when creating a new netlink
- // connection.
- func WithTestDial(f nltest.Func) ConnOption {
- return func(cc *Conn) {
- cc.TestDial = f
- }
- }
// netlinkCloser is the cleanup function handed out by netlinkConn and
// netlinkConnUnderLock together with a netlink connection. Callers must invoke
// it once they are done with that connection, so that the connection gets
// closed if that is necessary.
type netlinkCloser func() error
- // netlinkConn returns a netlink connection together with a netlinkCloser that
- // later must be called by the caller when it doesn't need the returned netlink
- // connection anymore. The netlinkCloser will close the netlink connection when
- // necessary. If New has been told to create a lasting connection, then this
- // lasting netlink connection will be returned, otherwise a new "transient"
- // netlink connection will be opened and returned instead. netlinkConn must not
- // be called while the Conn.mu lock is currently helt (this will cause a
- // deadlock). Use netlinkConnUnderLock instead in such situations.
- func (cc *Conn) netlinkConn() (*netlink.Conn, netlinkCloser, error) {
- cc.mu.Lock()
- defer cc.mu.Unlock()
- return cc.netlinkConnUnderLock()
- }
- // netlinkConnUnderLock works like netlinkConn but must be called while holding
- // the Conn.mu lock.
- func (cc *Conn) netlinkConnUnderLock() (*netlink.Conn, netlinkCloser, error) {
- if cc.nlconn != nil {
- return cc.nlconn, func() error { return nil }, nil
- }
- nlconn, err := cc.dialNetlink()
- if err != nil {
- return nil, nil, err
- }
- return nlconn, func() error { return nlconn.Close() }, nil
- }
- func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]netlink.Message, error) {
- if nlconn == nil {
- return nil, errors.New("netlink conn is not initialized")
- }
- // first receive will be the message that we expect
- reply, err := nlconn.Receive()
- if err != nil {
- return nil, err
- }
- if (sentMsgFlags & netlink.Acknowledge) == 0 {
- // we did not request an ack
- return reply, nil
- }
- if (sentMsgFlags & netlink.Dump) == netlink.Dump {
- // sent message has Dump flag set, there will be no acks
- // https://github.com/torvalds/linux/blob/7e062cda7d90543ac8c7700fc7c5527d0c0f22ad/net/netlink/af_netlink.c#L2387-L2390
- return reply, nil
- }
- // Dump flag is not set, we expect an ack
- ack, err := nlconn.Receive()
- if err != nil {
- return nil, err
- }
- if len(ack) == 0 {
- return nil, errors.New("received an empty ack")
- }
- msg := ack[0]
- if msg.Header.Type != netlink.Error {
- // acks should be delivered as NLMSG_ERROR
- return nil, fmt.Errorf("expected header %v, but got %v", netlink.Error, msg.Header.Type)
- }
- if binaryutil.BigEndian.Uint32(msg.Data[:4]) != 0 {
- // if errno field is not set to 0 (success), this is an error
- return nil, fmt.Errorf("error delivered in message: %v", msg.Data)
- }
- return reply, nil
- }
- // CloseLasting closes the lasting netlink connection that has been opened using
- // AsLasting option when creating this connection. If either no lasting netlink
- // connection has been opened or the lasting connection is already in the
- // process of closing or has been closed, CloseLasting will immediately return
- // without any error.
- //
- // CloseLasting will terminate all pending netlink operations using the lasting
- // connection.
- //
- // After closing a lasting connection, the connection will revert to using
- // on-demand transient netlink connections when calling further netlink
- // operations (such as GetTables).
- func (cc *Conn) CloseLasting() error {
- // Don't acquire the lock for the whole duration of the CloseLasting
- // operation, but instead only so long as to make sure to only run the
- // netlink socket close on the first time with a lasting netlink socket. As
- // there is only the New() constructor, but no Open() method, it's
- // impossible to reopen a lasting connection.
- cc.mu.Lock()
- nlconn := cc.nlconn
- cc.nlconn = nil
- cc.mu.Unlock()
- if nlconn != nil {
- return nlconn.Close()
- }
- return nil
- }
- // Flush sends all buffered commands in a single batch to nftables.
- func (cc *Conn) Flush() error {
- cc.mu.Lock()
- defer func() {
- cc.messages = nil
- cc.mu.Unlock()
- }()
- if len(cc.messages) == 0 {
- // Messages were already programmed, returning nil
- return nil
- }
- if cc.err != nil {
- return cc.err // serialization error
- }
- conn, closer, err := cc.netlinkConnUnderLock()
- if err != nil {
- return err
- }
- defer func() { _ = closer() }()
- if _, err := conn.SendMessages(batch(cc.messages)); err != nil {
- return fmt.Errorf("SendMessages: %w", err)
- }
- // Fetch the requested acknowledgement for each message we sent.
- for _, msg := range cc.messages {
- if msg.Header.Flags&netlink.Acknowledge == 0 {
- continue // message did not request an acknowledgement
- }
- if _, err := conn.Receive(); err != nil {
- return fmt.Errorf("conn.Receive: %w", err)
- }
- }
- return nil
- }
- // FlushRuleset flushes the entire ruleset. See also
- // https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level
- func (cc *Conn) FlushRuleset() {
- cc.mu.Lock()
- defer cc.mu.Unlock()
- cc.messages = append(cc.messages, netlink.Message{
- Header: netlink.Header{
- Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
- Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
- },
- Data: extraHeader(0, 0),
- })
- }
- func (cc *Conn) dialNetlink() (*netlink.Conn, error) {
- if cc.TestDial != nil {
- return nltest.Dial(cc.TestDial), nil
- }
- return netlink.Dial(unix.NETLINK_NETFILTER, &netlink.Config{NetNS: cc.NetNS})
- }
- func (cc *Conn) setErr(err error) {
- if cc.err != nil {
- return
- }
- cc.err = err
- }
- func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
- b, err := netlink.MarshalAttributes(attrs)
- if err != nil {
- cc.setErr(err)
- return nil
- }
- return b
- }
- func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
- b, err := expr.Marshal(fam, e)
- if err != nil {
- cc.setErr(err)
- return nil
- }
- return b
- }
- func batch(messages []netlink.Message) []netlink.Message {
- batch := []netlink.Message{
- {
- Header: netlink.Header{
- Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
- Flags: netlink.Request,
- },
- Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
- },
- }
- batch = append(batch, messages...)
- batch = append(batch, netlink.Message{
- Header: netlink.Header{
- Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
- Flags: netlink.Request,
- },
- Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
- })
- return batch
- }
|