nltest.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. // Package nltest provides utilities for netlink testing.
  2. package nltest
  3. import (
  4. "fmt"
  5. "io"
  6. "os"
  7. "github.com/mdlayher/netlink"
  8. "github.com/mdlayher/netlink/nlenc"
  9. )
  10. // PID is the netlink header PID value assigned by nltest.
  11. const PID = 1
  12. // MustMarshalAttributes marshals a slice of netlink.Attributes to their binary
  13. // format, but panics if any errors occur.
  14. func MustMarshalAttributes(attrs []netlink.Attribute) []byte {
  15. b, err := netlink.MarshalAttributes(attrs)
  16. if err != nil {
  17. panic(fmt.Sprintf("failed to marshal attributes to binary: %v", err))
  18. }
  19. return b
  20. }
  21. // Multipart sends a slice of netlink.Messages to the caller as a
  22. // netlink multi-part message. If less than two messages are present,
  23. // the messages are not altered.
  24. func Multipart(msgs []netlink.Message) ([]netlink.Message, error) {
  25. if len(msgs) < 2 {
  26. return msgs, nil
  27. }
  28. for i := range msgs {
  29. // Last message has header type "done" in addition to multi-part flag.
  30. if i == len(msgs)-1 {
  31. msgs[i].Header.Type = netlink.Done
  32. }
  33. msgs[i].Header.Flags |= netlink.Multi
  34. }
  35. return msgs, nil
  36. }
  37. // Error returns a netlink error to the caller with the specified error
  38. // number, in the body of the specified request message.
  39. func Error(number int, reqs []netlink.Message) ([]netlink.Message, error) {
  40. req := reqs[0]
  41. req.Header.Length += 4
  42. req.Header.Type = netlink.Error
  43. errno := -1 * int32(number)
  44. req.Data = append(nlenc.Int32Bytes(errno), req.Data...)
  45. return []netlink.Message{req}, nil
  46. }
  47. // A Func is a function that can be used to test netlink.Conn interactions.
  48. // The function can choose to return zero or more netlink messages, or an
  49. // error if needed.
  50. //
  51. // For a netlink request/response interaction, a request req is populated by
  52. // netlink.Conn.Send and passed to the function.
  53. //
  54. // For multicast interactions, an empty request req is passed to the function
  55. // when netlink.Conn.Receive is called.
  56. //
  57. // If a Func returns an error, the error will be returned as-is to the caller.
  58. // If no messages and io.EOF are returned, no messages and no error will be
  59. // returned to the caller, simulating a multi-part message with no data.
  60. type Func func(req []netlink.Message) ([]netlink.Message, error)
  61. // Dial sets up a netlink.Conn for testing using the specified Func. All requests
  62. // sent from the connection will be passed to the Func. The connection should be
  63. // closed as usual when it is no longer needed.
  64. func Dial(fn Func) *netlink.Conn {
  65. sock := &socket{
  66. fn: fn,
  67. }
  68. return netlink.NewConn(sock, PID)
  69. }
  70. // CheckRequest returns a Func that verifies that each message in an incoming
  71. // request has the specified netlink header type and flags in the same slice
  72. // position index, and then passes the request through to fn.
  73. //
  74. // The length of the types and flags slices must match the number of requests
  75. // passed to the returned Func, or CheckRequest will panic.
  76. //
  77. // As an example:
  78. // - types[0] and flags[0] will be checked against reqs[0]
  79. // - types[1] and flags[1] will be checked against reqs[1]
  80. // - ... and so on
  81. //
  82. // If an element of types or flags is set to the zero value, that check will
  83. // be skipped for the request message that occurs at the same index.
  84. //
  85. // As an example, if types[0] is 0 and reqs[0].Header.Type is 1, the check will
  86. // succeed because types[0] was not specified.
  87. func CheckRequest(types []netlink.HeaderType, flags []netlink.HeaderFlags, fn Func) Func {
  88. if len(types) != len(flags) {
  89. panicf("nltest: CheckRequest called with mismatched types and flags slice lengths: %d != %d",
  90. len(types), len(flags))
  91. }
  92. return func(req []netlink.Message) ([]netlink.Message, error) {
  93. if len(types) != len(req) {
  94. panicf("nltest: CheckRequest function invoked types/flags and request message slice lengths: %d != %d",
  95. len(types), len(req))
  96. }
  97. for i := range req {
  98. if want, got := types[i], req[i].Header.Type; types[i] != 0 && want != got {
  99. return nil, fmt.Errorf("nltest: unexpected netlink header type: %s, want: %s", got, want)
  100. }
  101. if want, got := flags[i], req[i].Header.Flags; flags[i] != 0 && want != got {
  102. return nil, fmt.Errorf("nltest: unexpected netlink header flags: %s, want: %s", got, want)
  103. }
  104. }
  105. return fn(req)
  106. }
  107. }
  108. // A socket is a netlink.Socket used for testing.
  109. type socket struct {
  110. fn Func
  111. msgs []netlink.Message
  112. err error
  113. }
  114. func (c *socket) Close() error { return nil }
  115. func (c *socket) SendMessages(messages []netlink.Message) error {
  116. msgs, err := c.fn(messages)
  117. c.msgs = append(c.msgs, msgs...)
  118. c.err = err
  119. return nil
  120. }
  121. func (c *socket) Send(m netlink.Message) error {
  122. c.msgs, c.err = c.fn([]netlink.Message{m})
  123. return nil
  124. }
  125. func (c *socket) Receive() ([]netlink.Message, error) {
  126. // No messages set by Send means that we are emulating a
  127. // multicast response or an error occurred.
  128. if len(c.msgs) == 0 {
  129. switch c.err {
  130. case nil:
  131. // No error, simulate multicast, but also return EOF to simulate
  132. // no replies if needed.
  133. msgs, err := c.fn(nil)
  134. if err == io.EOF {
  135. err = nil
  136. }
  137. return msgs, err
  138. case io.EOF:
  139. // EOF, simulate no replies in multi-part message.
  140. return nil, nil
  141. }
  142. // If the error is a system call error, wrap it in os.NewSyscallError
  143. // to simulate what the Linux netlink.Conn does.
  144. if isSyscallError(c.err) {
  145. return nil, os.NewSyscallError("recvmsg", c.err)
  146. }
  147. // Some generic error occurred and should be passed to the caller.
  148. return nil, c.err
  149. }
  150. // Detect multi-part messages.
  151. var multi bool
  152. for _, m := range c.msgs {
  153. if m.Header.Flags&netlink.Multi != 0 && m.Header.Type != netlink.Done {
  154. multi = true
  155. }
  156. }
  157. // When a multi-part message is detected, return all messages except for the
  158. // final "multi-part done", so that a second call to Receive from netlink.Conn
  159. // will drain that message.
  160. if multi {
  161. last := c.msgs[len(c.msgs)-1]
  162. ret := c.msgs[:len(c.msgs)-1]
  163. c.msgs = []netlink.Message{last}
  164. return ret, c.err
  165. }
  166. msgs, err := c.msgs, c.err
  167. c.msgs, c.err = nil, nil
  168. return msgs, err
  169. }
  170. func panicf(format string, a ...interface{}) {
  171. panic(fmt.Sprintf(format, a...))
  172. }