netns_linux.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. //go:build linux
  2. // +build linux
  3. package socket
  4. import (
  5. "errors"
  6. "fmt"
  7. "os"
  8. "runtime"
  9. "golang.org/x/sync/errgroup"
  10. "golang.org/x/sys/unix"
  11. )
  12. // errNetNSDisabled is returned when network namespaces are unavailable on
  13. // a given system.
  14. var errNetNSDisabled = errors.New("socket: Linux network namespaces are not enabled on this system")
  15. // withNetNS invokes fn within the context of the network namespace specified by
  16. // fd, while also managing the logic required to safely do so by manipulating
  17. // thread-local state.
  18. func withNetNS(fd int, fn func() (*Conn, error)) (*Conn, error) {
  19. var (
  20. eg errgroup.Group
  21. conn *Conn
  22. )
  23. eg.Go(func() error {
  24. // Retrieve and store the calling OS thread's network namespace so the
  25. // thread can be reassigned to it after creating a socket in another network
  26. // namespace.
  27. runtime.LockOSThread()
  28. ns, err := threadNetNS()
  29. if err != nil {
  30. // No thread-local manipulation, unlock.
  31. runtime.UnlockOSThread()
  32. return err
  33. }
  34. defer ns.Close()
  35. // Beyond this point, the thread's network namespace is poisoned. Do not
  36. // unlock the OS thread until all network namespace manipulation completes
  37. // to avoid returning to the caller with altered thread-local state.
  38. // Assign the current OS thread the goroutine is locked to to the given
  39. // network namespace.
  40. if err := ns.Set(fd); err != nil {
  41. return err
  42. }
  43. // Attempt Conn creation and unconditionally restore the original namespace.
  44. c, err := fn()
  45. if nerr := ns.Restore(); nerr != nil {
  46. // Failed to restore original namespace. Return an error and allow the
  47. // runtime to terminate the thread.
  48. if err == nil {
  49. _ = c.Close()
  50. }
  51. return nerr
  52. }
  53. // No more thread-local state manipulation; return the new Conn.
  54. runtime.UnlockOSThread()
  55. conn = c
  56. return nil
  57. })
  58. if err := eg.Wait(); err != nil {
  59. return nil, err
  60. }
  61. return conn, nil
  62. }
  63. // A netNS is a handle that can manipulate network namespaces.
  64. //
  65. // Operations performed on a netNS must use runtime.LockOSThread before
  66. // manipulating any network namespaces.
  67. type netNS struct {
  68. // The handle to a network namespace.
  69. f *os.File
  70. // Indicates if network namespaces are disabled on this system, and thus
  71. // operations should become a no-op or return errors.
  72. disabled bool
  73. }
  74. // threadNetNS constructs a netNS using the network namespace of the calling
  75. // thread. If the namespace is not the default namespace, runtime.LockOSThread
  76. // should be invoked first.
  77. func threadNetNS() (*netNS, error) {
  78. return fileNetNS(fmt.Sprintf("/proc/self/task/%d/ns/net", unix.Gettid()))
  79. }
  80. // fileNetNS opens file and creates a netNS. fileNetNS should only be called
  81. // directly in tests.
  82. func fileNetNS(file string) (*netNS, error) {
  83. f, err := os.Open(file)
  84. switch {
  85. case err == nil:
  86. return &netNS{f: f}, nil
  87. case os.IsNotExist(err):
  88. // Network namespaces are not enabled on this system. Use this signal
  89. // to return errors elsewhere if the caller explicitly asks for a
  90. // network namespace to be set.
  91. return &netNS{disabled: true}, nil
  92. default:
  93. return nil, err
  94. }
  95. }
  96. // Close releases the handle to a network namespace.
  97. func (n *netNS) Close() error {
  98. return n.do(func() error { return n.f.Close() })
  99. }
  100. // FD returns a file descriptor which represents the network namespace.
  101. func (n *netNS) FD() int {
  102. if n.disabled {
  103. // No reasonable file descriptor value in this case, so specify a
  104. // non-existent one.
  105. return -1
  106. }
  107. return int(n.f.Fd())
  108. }
  109. // Restore restores the original network namespace for the calling thread.
  110. func (n *netNS) Restore() error {
  111. return n.do(func() error { return n.Set(n.FD()) })
  112. }
  113. // Set sets a new network namespace for the current thread using fd.
  114. func (n *netNS) Set(fd int) error {
  115. return n.do(func() error {
  116. return os.NewSyscallError("setns", unix.Setns(fd, unix.CLONE_NEWNET))
  117. })
  118. }
  119. // do runs fn if network namespaces are enabled on this system.
  120. func (n *netNS) do(fn func() error) error {
  121. if n.disabled {
  122. return errNetNSDisabled
  123. }
  124. return fn()
  125. }