conn_linux.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. //go:build linux
  2. // +build linux
  3. package netlink
  4. import (
  5. "os"
  6. "runtime"
  7. "syscall"
  8. "time"
  9. "unsafe"
  10. "github.com/mdlayher/socket"
  11. "golang.org/x/net/bpf"
  12. "golang.org/x/sys/unix"
  13. )
  14. var _ Socket = &conn{}
  15. // A conn is the Linux implementation of a netlink sockets connection.
  16. type conn struct {
  17. s *socket.Conn
  18. }
  19. // dial is the entry point for Dial. dial opens a netlink socket using
  20. // system calls, and returns its PID.
  21. func dial(family int, config *Config) (*conn, uint32, error) {
  22. if config == nil {
  23. config = &Config{}
  24. }
  25. // The caller has indicated it wants the netlink socket to be created
  26. // inside another network namespace.
  27. if config.NetNS != 0 {
  28. runtime.LockOSThread()
  29. defer runtime.UnlockOSThread()
  30. // Retrieve and store the calling OS thread's network namespace so
  31. // the thread can be reassigned to it after creating a socket in another
  32. // network namespace.
  33. threadNS, err := threadNetNS()
  34. if err != nil {
  35. return nil, 0, err
  36. }
  37. // Always close the netns handle created above.
  38. defer threadNS.Close()
  39. // Assign the current OS thread the goroutine is locked to to the given
  40. // network namespace.
  41. if err := threadNS.Set(config.NetNS); err != nil {
  42. return nil, 0, err
  43. }
  44. // Thread's namespace has been successfully set. Return the thread
  45. // back to its original namespace after attempting to create the
  46. // netlink socket.
  47. defer threadNS.Restore()
  48. }
  49. // Prepare the netlink socket.
  50. s, err := socket.Socket(unix.AF_NETLINK, unix.SOCK_RAW, family, "netlink")
  51. if err != nil {
  52. return nil, 0, err
  53. }
  54. return newConn(s, config)
  55. }
  56. // newConn binds a connection to netlink using the input *socket.Conn.
  57. func newConn(s *socket.Conn, config *Config) (*conn, uint32, error) {
  58. if config == nil {
  59. config = &Config{}
  60. }
  61. addr := &unix.SockaddrNetlink{
  62. Family: unix.AF_NETLINK,
  63. Groups: config.Groups,
  64. }
  65. // Socket must be closed in the event of any system call errors, to avoid
  66. // leaking file descriptors.
  67. if err := s.Bind(addr); err != nil {
  68. _ = s.Close()
  69. return nil, 0, err
  70. }
  71. sa, err := s.Getsockname()
  72. if err != nil {
  73. _ = s.Close()
  74. return nil, 0, err
  75. }
  76. return &conn{
  77. s: s,
  78. }, sa.(*unix.SockaddrNetlink).Pid, nil
  79. }
  80. // SendMessages serializes multiple Messages and sends them to netlink.
  81. func (c *conn) SendMessages(messages []Message) error {
  82. var buf []byte
  83. for _, m := range messages {
  84. b, err := m.MarshalBinary()
  85. if err != nil {
  86. return err
  87. }
  88. buf = append(buf, b...)
  89. }
  90. sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK}
  91. return c.s.Sendmsg(buf, nil, sa, 0)
  92. }
  93. // Send sends a single Message to netlink.
  94. func (c *conn) Send(m Message) error {
  95. b, err := m.MarshalBinary()
  96. if err != nil {
  97. return err
  98. }
  99. sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK}
  100. return c.s.Sendmsg(b, nil, sa, 0)
  101. }
  102. // Receive receives one or more Messages from netlink.
  103. func (c *conn) Receive() ([]Message, error) {
  104. b := make([]byte, os.Getpagesize())
  105. for {
  106. // Peek at the buffer to see how many bytes are available.
  107. //
  108. // TODO(mdlayher): deal with OOB message data if available, such as
  109. // when PacketInfo ConnOption is true.
  110. n, _, _, _, err := c.s.Recvmsg(b, nil, unix.MSG_PEEK)
  111. if err != nil {
  112. return nil, err
  113. }
  114. // Break when we can read all messages
  115. if n < len(b) {
  116. break
  117. }
  118. // Double in size if not enough bytes
  119. b = make([]byte, len(b)*2)
  120. }
  121. // Read out all available messages
  122. n, _, _, _, err := c.s.Recvmsg(b, nil, 0)
  123. if err != nil {
  124. return nil, err
  125. }
  126. raw, err := syscall.ParseNetlinkMessage(b[:nlmsgAlign(n)])
  127. if err != nil {
  128. return nil, err
  129. }
  130. msgs := make([]Message, 0, len(raw))
  131. for _, r := range raw {
  132. m := Message{
  133. Header: sysToHeader(r.Header),
  134. Data: r.Data,
  135. }
  136. msgs = append(msgs, m)
  137. }
  138. return msgs, nil
  139. }
  140. // Close closes the connection.
  141. func (c *conn) Close() error { return c.s.Close() }
  142. // JoinGroup joins a multicast group by ID.
  143. func (c *conn) JoinGroup(group uint32) error {
  144. return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_ADD_MEMBERSHIP, int(group))
  145. }
  146. // LeaveGroup leaves a multicast group by ID.
  147. func (c *conn) LeaveGroup(group uint32) error {
  148. return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_DROP_MEMBERSHIP, int(group))
  149. }
  150. // SetBPF attaches an assembled BPF program to a conn.
  151. func (c *conn) SetBPF(filter []bpf.RawInstruction) error { return c.s.SetBPF(filter) }
  152. // RemoveBPF removes a BPF filter from a conn.
  153. func (c *conn) RemoveBPF() error { return c.s.RemoveBPF() }
  154. // SetOption enables or disables a netlink socket option for the Conn.
  155. func (c *conn) SetOption(option ConnOption, enable bool) error {
  156. o, ok := linuxOption(option)
  157. if !ok {
  158. // Return the typical Linux error for an unknown ConnOption.
  159. return os.NewSyscallError("setsockopt", unix.ENOPROTOOPT)
  160. }
  161. var v int
  162. if enable {
  163. v = 1
  164. }
  165. return c.s.SetsockoptInt(unix.SOL_NETLINK, o, v)
  166. }
  167. func (c *conn) SetDeadline(t time.Time) error { return c.s.SetDeadline(t) }
  168. func (c *conn) SetReadDeadline(t time.Time) error { return c.s.SetReadDeadline(t) }
  169. func (c *conn) SetWriteDeadline(t time.Time) error { return c.s.SetWriteDeadline(t) }
  170. // SetReadBuffer sets the size of the operating system's receive buffer
  171. // associated with the Conn.
  172. func (c *conn) SetReadBuffer(bytes int) error { return c.s.SetReadBuffer(bytes) }
  173. // SetReadBuffer sets the size of the operating system's transmit buffer
  174. // associated with the Conn.
  175. func (c *conn) SetWriteBuffer(bytes int) error { return c.s.SetWriteBuffer(bytes) }
  176. // SyscallConn returns a raw network connection.
  177. func (c *conn) SyscallConn() (syscall.RawConn, error) { return c.s.SyscallConn() }
  178. // linuxOption converts a ConnOption to its Linux value.
  179. func linuxOption(o ConnOption) (int, bool) {
  180. switch o {
  181. case PacketInfo:
  182. return unix.NETLINK_PKTINFO, true
  183. case BroadcastError:
  184. return unix.NETLINK_BROADCAST_ERROR, true
  185. case NoENOBUFS:
  186. return unix.NETLINK_NO_ENOBUFS, true
  187. case ListenAllNSID:
  188. return unix.NETLINK_LISTEN_ALL_NSID, true
  189. case CapAcknowledge:
  190. return unix.NETLINK_CAP_ACK, true
  191. case ExtendedAcknowledge:
  192. return unix.NETLINK_EXT_ACK, true
  193. case GetStrictCheck:
  194. return unix.NETLINK_GET_STRICT_CHK, true
  195. default:
  196. return 0, false
  197. }
  198. }
  199. // sysToHeader converts a syscall.NlMsghdr to a Header.
  200. func sysToHeader(r syscall.NlMsghdr) Header {
  201. // NB: the memory layout of Header and syscall.NlMsgHdr must be
  202. // exactly the same for this unsafe cast to work
  203. return *(*Header)(unsafe.Pointer(&r))
  204. }
  205. // newError converts an error number from netlink into the appropriate
  206. // system call error for Linux.
  207. func newError(errno int) error {
  208. return syscall.Errno(errno)
  209. }