| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394 |
- package rtnetlink
- import (
- "bytes"
- "encoding/binary"
- "errors"
- "net"
- "github.com/jsimonetti/rtnetlink/internal/unix"
- "github.com/mdlayher/netlink"
- )
- var (
- // errInvalidRuleMessage is returned when a RuleMessage is malformed.
- errInvalidRuleMessage = errors.New("rtnetlink RuleMessage is invalid or too short")
- // errInvalidRuleAttribute is returned when a RuleMessage contains an unknown attribute.
- errInvalidRuleAttribute = errors.New("rtnetlink RuleMessage contains an unknown Attribute")
- )
- var _ Message = &RuleMessage{}
- // A RuleMessage is a route netlink link message.
- type RuleMessage struct {
- // Address family
- Family uint8
- // Length of destination prefix
- DstLength uint8
- // Length of source prefix
- SrcLength uint8
- // Rule TOS
- TOS uint8
- // Routing table identifier
- Table uint8
- // Rule action
- Action uint8
- // Rule flags
- Flags uint32
- // Attributes List
- Attributes *RuleAttributes
- }
- // MarshalBinary marshals a LinkMessage into a byte slice.
- func (m *RuleMessage) MarshalBinary() ([]byte, error) {
- b := make([]byte, 12)
- // fib_rule_hdr
- b[0] = m.Family
- b[1] = m.DstLength
- b[2] = m.SrcLength
- b[3] = m.TOS
- b[4] = m.Table
- b[7] = m.Action
- nativeEndian.PutUint32(b[8:12], m.Flags)
- if m.Attributes != nil {
- ae := netlink.NewAttributeEncoder()
- ae.ByteOrder = nativeEndian
- err := m.Attributes.encode(ae)
- if err != nil {
- return nil, err
- }
- a, err := ae.Encode()
- if err != nil {
- return nil, err
- }
- return append(b, a...), nil
- }
- return b, nil
- }
- // UnmarshalBinary unmarshals the contents of a byte slice into a LinkMessage.
- func (m *RuleMessage) UnmarshalBinary(b []byte) error {
- l := len(b)
- if l < 12 {
- return errInvalidRuleMessage
- }
- m.Family = b[0]
- m.DstLength = b[1]
- m.SrcLength = b[2]
- m.TOS = b[3]
- m.Table = b[4]
- // b[5] and b[6] are reserved fields
- m.Action = b[7]
- m.Flags = nativeEndian.Uint32(b[8:12])
- if l > 12 {
- m.Attributes = &RuleAttributes{}
- ad, err := netlink.NewAttributeDecoder(b[12:])
- if err != nil {
- return err
- }
- ad.ByteOrder = nativeEndian
- return m.Attributes.decode(ad)
- }
- return nil
- }
- // rtMessage is an empty method to sattisfy the Message interface.
- func (*RuleMessage) rtMessage() {}
- // RuleService is used to retrieve rtnetlink family information.
- type RuleService struct {
- c *Conn
- }
- func (r *RuleService) execute(m Message, family uint16, flags netlink.HeaderFlags) ([]RuleMessage, error) {
- msgs, err := r.c.Execute(m, family, flags)
- rules := make([]RuleMessage, len(msgs))
- for i := range msgs {
- rules[i] = *msgs[i].(*RuleMessage)
- }
- return rules, err
- }
- // Add new rule
- func (r *RuleService) Add(req *RuleMessage) error {
- flags := netlink.Request | netlink.Create | netlink.Acknowledge | netlink.Excl
- _, err := r.c.Execute(req, unix.RTM_NEWRULE, flags)
- return err
- }
- // Replace or add new rule
- func (r *RuleService) Replace(req *RuleMessage) error {
- flags := netlink.Request | netlink.Create | netlink.Replace | netlink.Acknowledge
- _, err := r.c.Execute(req, unix.RTM_NEWRULE, flags)
- return err
- }
- // Delete existing rule
- func (r *RuleService) Delete(req *RuleMessage) error {
- flags := netlink.Request | netlink.Acknowledge
- _, err := r.c.Execute(req, unix.RTM_DELRULE, flags)
- return err
- }
- // Get Rule(s)
- func (r *RuleService) Get(req *RuleMessage) ([]RuleMessage, error) {
- flags := netlink.Request | netlink.DumpFiltered
- return r.execute(req, unix.RTM_GETRULE, flags)
- }
- // List all rules
- func (r *RuleService) List() ([]RuleMessage, error) {
- flags := netlink.Request | netlink.Dump
- return r.execute(&RuleMessage{}, unix.RTM_GETRULE, flags)
- }
- // RuleAttributes contains all attributes for a rule.
- type RuleAttributes struct {
- Src, Dst *net.IP
- IIFName, OIFName *string
- Goto *uint32
- Priority *uint32
- FwMark, FwMask *uint32
- SrcRealm *uint16
- DstRealm *uint16
- TunID *uint64
- Table *uint32
- L3MDev *uint8
- Protocol *uint8
- IPProto *uint8
- SuppressPrefixLen *uint32
- SuppressIFGroup *uint32
- UIDRange *RuleUIDRange
- SPortRange *RulePortRange
- DPortRange *RulePortRange
- }
- // unmarshalBinary unmarshals the contents of a byte slice into a RuleMessage.
- func (r *RuleAttributes) decode(ad *netlink.AttributeDecoder) error {
- for ad.Next() {
- switch ad.Type() {
- case unix.FRA_UNSPEC:
- // unused
- continue
- case unix.FRA_DST:
- r.Dst = &net.IP{}
- ad.Do(decodeIP(r.Dst))
- case unix.FRA_SRC:
- r.Src = &net.IP{}
- ad.Do(decodeIP(r.Src))
- case unix.FRA_IIFNAME:
- v := ad.String()
- r.IIFName = &v
- case unix.FRA_GOTO:
- v := ad.Uint32()
- r.Goto = &v
- case unix.FRA_UNUSED2:
- // unused
- continue
- case unix.FRA_PRIORITY:
- v := ad.Uint32()
- r.Priority = &v
- case unix.FRA_UNUSED3:
- // unused
- continue
- case unix.FRA_UNUSED4:
- // unused
- continue
- case unix.FRA_UNUSED5:
- // unused
- continue
- case unix.FRA_FWMARK:
- v := ad.Uint32()
- r.FwMark = &v
- case unix.FRA_FLOW:
- dst32 := ad.Uint32()
- src32 := uint32(dst32 >> 16)
- src32 &= 0xFFFF
- dst32 &= 0xFFFF
- src16 := uint16(src32)
- dst16 := uint16(dst32)
- r.SrcRealm = &src16
- r.DstRealm = &dst16
- case unix.FRA_TUN_ID:
- v := ad.Uint64()
- r.TunID = &v
- case unix.FRA_SUPPRESS_IFGROUP:
- v := ad.Uint32()
- r.SuppressIFGroup = &v
- case unix.FRA_SUPPRESS_PREFIXLEN:
- v := ad.Uint32()
- r.SuppressPrefixLen = &v
- case unix.FRA_TABLE:
- v := ad.Uint32()
- r.Table = &v
- case unix.FRA_FWMASK:
- v := ad.Uint32()
- r.FwMask = &v
- case unix.FRA_OIFNAME:
- v := ad.String()
- r.OIFName = &v
- case unix.FRA_PAD:
- // unused
- continue
- case unix.FRA_L3MDEV:
- v := ad.Uint8()
- r.L3MDev = &v
- case unix.FRA_UID_RANGE:
- r.UIDRange = &RuleUIDRange{}
- err := r.UIDRange.unmarshalBinary(ad.Bytes())
- if err != nil {
- return err
- }
- case unix.FRA_PROTOCOL:
- v := ad.Uint8()
- r.Protocol = &v
- case unix.FRA_IP_PROTO:
- v := ad.Uint8()
- r.IPProto = &v
- case unix.FRA_SPORT_RANGE:
- r.SPortRange = &RulePortRange{}
- err := r.SPortRange.unmarshalBinary(ad.Bytes())
- if err != nil {
- return err
- }
- case unix.FRA_DPORT_RANGE:
- r.DPortRange = &RulePortRange{}
- err := r.DPortRange.unmarshalBinary(ad.Bytes())
- if err != nil {
- return err
- }
- default:
- return errInvalidRuleAttribute
- }
- }
- return ad.Err()
- }
- // MarshalBinary marshals a RuleAttributes into a byte slice.
- func (r *RuleAttributes) encode(ae *netlink.AttributeEncoder) error {
- if r.Table != nil {
- ae.Uint32(unix.FRA_TABLE, *r.Table)
- }
- if r.Protocol != nil {
- ae.Uint8(unix.FRA_PROTOCOL, *r.Protocol)
- }
- if r.Src != nil {
- ae.Do(unix.FRA_SRC, encodeIP(*r.Src))
- }
- if r.Dst != nil {
- ae.Do(unix.FRA_DST, encodeIP(*r.Dst))
- }
- if r.IIFName != nil {
- ae.String(unix.FRA_IIFNAME, *r.IIFName)
- }
- if r.OIFName != nil {
- ae.String(unix.FRA_OIFNAME, *r.OIFName)
- }
- if r.Goto != nil {
- ae.Uint32(unix.FRA_GOTO, *r.Goto)
- }
- if r.Priority != nil {
- ae.Uint32(unix.FRA_PRIORITY, *r.Priority)
- }
- if r.FwMark != nil {
- ae.Uint32(unix.FRA_FWMARK, *r.FwMark)
- }
- if r.FwMask != nil {
- ae.Uint32(unix.FRA_FWMASK, *r.FwMask)
- }
- if r.DstRealm != nil {
- value := uint32(*r.DstRealm)
- if r.SrcRealm != nil {
- value |= (uint32(*r.SrcRealm&0xFFFF) << 16)
- }
- ae.Uint32(unix.FRA_FLOW, value)
- }
- if r.TunID != nil {
- ae.Uint64(unix.FRA_TUN_ID, *r.TunID)
- }
- if r.L3MDev != nil {
- ae.Uint8(unix.FRA_L3MDEV, *r.L3MDev)
- }
- if r.IPProto != nil {
- ae.Uint8(unix.FRA_IP_PROTO, *r.IPProto)
- }
- if r.SuppressIFGroup != nil {
- ae.Uint32(unix.FRA_SUPPRESS_IFGROUP, *r.SuppressIFGroup)
- }
- if r.SuppressPrefixLen != nil {
- ae.Uint32(unix.FRA_SUPPRESS_PREFIXLEN, *r.SuppressPrefixLen)
- }
- if r.UIDRange != nil {
- data, err := marshalRuleUIDRange(*r.UIDRange)
- if err != nil {
- return err
- }
- ae.Bytes(unix.FRA_UID_RANGE, data)
- }
- if r.SPortRange != nil {
- data, err := marshalRulePortRange(*r.SPortRange)
- if err != nil {
- return err
- }
- ae.Bytes(unix.FRA_SPORT_RANGE, data)
- }
- if r.DPortRange != nil {
- data, err := marshalRulePortRange(*r.DPortRange)
- if err != nil {
- return err
- }
- ae.Bytes(unix.FRA_DPORT_RANGE, data)
- }
- return nil
- }
- // RulePortRange defines start and end ports for a rule
- type RulePortRange struct {
- Start, End uint16
- }
- func (r *RulePortRange) unmarshalBinary(data []byte) error {
- b := bytes.NewReader(data)
- return binary.Read(b, nativeEndian, r)
- }
- func marshalRulePortRange(s RulePortRange) ([]byte, error) {
- var buf bytes.Buffer
- err := binary.Write(&buf, nativeEndian, s)
- return buf.Bytes(), err
- }
- // RuleUIDRange defines the start and end for UID matches
- type RuleUIDRange struct {
- Start, End uint16
- }
- func (r *RuleUIDRange) unmarshalBinary(data []byte) error {
- b := bytes.NewReader(data)
- return binary.Read(b, nativeEndian, r)
- }
- func marshalRuleUIDRange(s RuleUIDRange) ([]byte, error) {
- var buf bytes.Buffer
- err := binary.Write(&buf, nativeEndian, s)
- return buf.Bytes(), err
- }
|