conn_linux.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. //go:build linux
  2. // +build linux
  3. package netlink
  4. import (
  5. "context"
  6. "os"
  7. "syscall"
  8. "time"
  9. "unsafe"
  10. "github.com/mdlayher/socket"
  11. "golang.org/x/net/bpf"
  12. "golang.org/x/sys/unix"
  13. )
  14. var _ Socket = &conn{}
  15. // A conn is the Linux implementation of a netlink sockets connection.
  16. type conn struct {
  17. s *socket.Conn
  18. }
  19. // dial is the entry point for Dial. dial opens a netlink socket using
  20. // system calls, and returns its PID.
  21. func dial(family int, config *Config) (*conn, uint32, error) {
  22. if config == nil {
  23. config = &Config{}
  24. }
  25. // Prepare the netlink socket.
  26. s, err := socket.Socket(
  27. unix.AF_NETLINK,
  28. unix.SOCK_RAW,
  29. family,
  30. "netlink",
  31. &socket.Config{NetNS: config.NetNS},
  32. )
  33. if err != nil {
  34. return nil, 0, err
  35. }
  36. return newConn(s, config)
  37. }
  38. // newConn binds a connection to netlink using the input *socket.Conn.
  39. func newConn(s *socket.Conn, config *Config) (*conn, uint32, error) {
  40. if config == nil {
  41. config = &Config{}
  42. }
  43. addr := &unix.SockaddrNetlink{
  44. Family: unix.AF_NETLINK,
  45. Groups: config.Groups,
  46. Pid: config.PID,
  47. }
  48. // Socket must be closed in the event of any system call errors, to avoid
  49. // leaking file descriptors.
  50. if err := s.Bind(addr); err != nil {
  51. _ = s.Close()
  52. return nil, 0, err
  53. }
  54. sa, err := s.Getsockname()
  55. if err != nil {
  56. _ = s.Close()
  57. return nil, 0, err
  58. }
  59. c := &conn{s: s}
  60. if config.Strict {
  61. // The caller has requested the strict option set. Historically we have
  62. // recommended checking for ENOPROTOOPT if the kernel does not support
  63. // the option in question, but that may result in a silent failure and
  64. // unexpected behavior for the user.
  65. //
  66. // Treat any error here as a fatal error, and require the caller to deal
  67. // with it.
  68. for _, o := range []ConnOption{ExtendedAcknowledge, GetStrictCheck} {
  69. if err := c.SetOption(o, true); err != nil {
  70. _ = c.Close()
  71. return nil, 0, err
  72. }
  73. }
  74. }
  75. return c, sa.(*unix.SockaddrNetlink).Pid, nil
  76. }
  77. // SendMessages serializes multiple Messages and sends them to netlink.
  78. func (c *conn) SendMessages(messages []Message) error {
  79. var buf []byte
  80. for _, m := range messages {
  81. b, err := m.MarshalBinary()
  82. if err != nil {
  83. return err
  84. }
  85. buf = append(buf, b...)
  86. }
  87. sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK}
  88. _, err := c.s.Sendmsg(context.Background(), buf, nil, sa, 0)
  89. return err
  90. }
  91. // Send sends a single Message to netlink.
  92. func (c *conn) Send(m Message) error {
  93. b, err := m.MarshalBinary()
  94. if err != nil {
  95. return err
  96. }
  97. sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK}
  98. _, err = c.s.Sendmsg(context.Background(), b, nil, sa, 0)
  99. return err
  100. }
  101. // Receive receives one or more Messages from netlink.
  102. func (c *conn) Receive() ([]Message, error) {
  103. b := make([]byte, os.Getpagesize())
  104. for {
  105. // Peek at the buffer to see how many bytes are available.
  106. //
  107. // TODO(mdlayher): deal with OOB message data if available, such as
  108. // when PacketInfo ConnOption is true.
  109. n, _, _, _, err := c.s.Recvmsg(context.Background(), b, nil, unix.MSG_PEEK)
  110. if err != nil {
  111. return nil, err
  112. }
  113. // Break when we can read all messages
  114. if n < len(b) {
  115. break
  116. }
  117. // Double in size if not enough bytes
  118. b = make([]byte, len(b)*2)
  119. }
  120. // Read out all available messages
  121. n, _, _, _, err := c.s.Recvmsg(context.Background(), b, nil, 0)
  122. if err != nil {
  123. return nil, err
  124. }
  125. raw, err := syscall.ParseNetlinkMessage(b[:nlmsgAlign(n)])
  126. if err != nil {
  127. return nil, err
  128. }
  129. msgs := make([]Message, 0, len(raw))
  130. for _, r := range raw {
  131. m := Message{
  132. Header: sysToHeader(r.Header),
  133. Data: r.Data,
  134. }
  135. msgs = append(msgs, m)
  136. }
  137. return msgs, nil
  138. }
  139. // Close closes the connection.
  140. func (c *conn) Close() error { return c.s.Close() }
  141. // JoinGroup joins a multicast group by ID.
  142. func (c *conn) JoinGroup(group uint32) error {
  143. return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_ADD_MEMBERSHIP, int(group))
  144. }
  145. // LeaveGroup leaves a multicast group by ID.
  146. func (c *conn) LeaveGroup(group uint32) error {
  147. return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_DROP_MEMBERSHIP, int(group))
  148. }
  149. // SetBPF attaches an assembled BPF program to a conn.
  150. func (c *conn) SetBPF(filter []bpf.RawInstruction) error { return c.s.SetBPF(filter) }
  151. // RemoveBPF removes a BPF filter from a conn.
  152. func (c *conn) RemoveBPF() error { return c.s.RemoveBPF() }
  153. // SetOption enables or disables a netlink socket option for the Conn.
  154. func (c *conn) SetOption(option ConnOption, enable bool) error {
  155. o, ok := linuxOption(option)
  156. if !ok {
  157. // Return the typical Linux error for an unknown ConnOption.
  158. return os.NewSyscallError("setsockopt", unix.ENOPROTOOPT)
  159. }
  160. var v int
  161. if enable {
  162. v = 1
  163. }
  164. return c.s.SetsockoptInt(unix.SOL_NETLINK, o, v)
  165. }
  166. func (c *conn) SetDeadline(t time.Time) error { return c.s.SetDeadline(t) }
  167. func (c *conn) SetReadDeadline(t time.Time) error { return c.s.SetReadDeadline(t) }
  168. func (c *conn) SetWriteDeadline(t time.Time) error { return c.s.SetWriteDeadline(t) }
  169. // SetReadBuffer sets the size of the operating system's receive buffer
  170. // associated with the Conn.
  171. func (c *conn) SetReadBuffer(bytes int) error { return c.s.SetReadBuffer(bytes) }
  172. // SetReadBuffer sets the size of the operating system's transmit buffer
  173. // associated with the Conn.
  174. func (c *conn) SetWriteBuffer(bytes int) error { return c.s.SetWriteBuffer(bytes) }
  175. // SyscallConn returns a raw network connection.
  176. func (c *conn) SyscallConn() (syscall.RawConn, error) { return c.s.SyscallConn() }
  177. // linuxOption converts a ConnOption to its Linux value.
  178. func linuxOption(o ConnOption) (int, bool) {
  179. switch o {
  180. case PacketInfo:
  181. return unix.NETLINK_PKTINFO, true
  182. case BroadcastError:
  183. return unix.NETLINK_BROADCAST_ERROR, true
  184. case NoENOBUFS:
  185. return unix.NETLINK_NO_ENOBUFS, true
  186. case ListenAllNSID:
  187. return unix.NETLINK_LISTEN_ALL_NSID, true
  188. case CapAcknowledge:
  189. return unix.NETLINK_CAP_ACK, true
  190. case ExtendedAcknowledge:
  191. return unix.NETLINK_EXT_ACK, true
  192. case GetStrictCheck:
  193. return unix.NETLINK_GET_STRICT_CHK, true
  194. default:
  195. return 0, false
  196. }
  197. }
  198. // sysToHeader converts a syscall.NlMsghdr to a Header.
  199. func sysToHeader(r syscall.NlMsghdr) Header {
  200. // NB: the memory layout of Header and syscall.NlMsgHdr must be
  201. // exactly the same for this unsafe cast to work
  202. return *(*Header)(unsafe.Pointer(&r))
  203. }
  204. // newError converts an error number from netlink into the appropriate
  205. // system call error for Linux.
  206. func newError(errno int) error {
  207. return syscall.Errno(errno)
  208. }