conn.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package mdns
  4. import (
  5. "context"
  6. "errors"
  7. "math/big"
  8. "net"
  9. "sync"
  10. "time"
  11. "github.com/pion/logging"
  12. "golang.org/x/net/dns/dnsmessage"
  13. "golang.org/x/net/ipv4"
  14. )
// Conn represents a mDNS Server.
//
// It is created by Server, which also starts the read loop (start) that
// answers questions for the configured local names and delivers answers to
// pending Query calls.
type Conn struct {
	// mu guards queries (and queryInterval as read by Query). The read loop
	// holds it while handling each packet and while closing closed.
	mu sync.RWMutex

	// log receives all warnings emitted by the connection.
	log logging.LeveledLogger

	// socket is the PacketConn passed to Server; Close closes it.
	socket *ipv4.PacketConn

	// dstAddr is the resolved destinationAddress that every packet is written to.
	dstAddr *net.UDPAddr

	// queryInterval is how often Query re-sends its question.
	queryInterval time.Duration

	// localNames are the names this server answers for, each with a trailing ".".
	localNames []string

	// queries holds the Query calls still waiting for an answer.
	queries []query

	// ifaces are the interfaces on which joining the mDNS multicast group succeeded.
	ifaces []net.Interface

	// closed is closed by the read loop when it exits.
	closed chan interface{}
}
// query is a pending Query call registered in Conn.queries.
type query struct {
	// nameWithSuffix is the queried name with a trailing ".", compared
	// against the names of received answer records.
	nameWithSuffix string

	// queryResultChan receives the matching answer; Query creates it with a
	// buffer of one so the read loop's send does not block.
	queryResultChan chan queryResult
}
// queryResult is what the read loop delivers to a matching query.
type queryResult struct {
	// answer is the header of the matching A or AAAA record.
	answer dnsmessage.ResourceHeader

	// addr is the address carried by that record, as a *net.IPAddr.
	addr net.Addr
}
  35. const (
  36. defaultQueryInterval = time.Second
  37. destinationAddress = "224.0.0.251:5353"
  38. maxMessageRecords = 3
  39. responseTTL = 120
  40. )
  41. var errNoPositiveMTUFound = errors.New("no positive MTU found")
  42. // Server establishes a mDNS connection over an existing conn
  43. func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) {
  44. if config == nil {
  45. return nil, errNilConfig
  46. }
  47. ifaces, err := net.Interfaces()
  48. if err != nil {
  49. return nil, err
  50. }
  51. inboundBufferSize := 0
  52. joinErrCount := 0
  53. ifacesToUse := make([]net.Interface, 0, len(ifaces))
  54. for i, ifc := range ifaces {
  55. if err = conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil {
  56. joinErrCount++
  57. continue
  58. }
  59. ifcCopy := ifc
  60. ifacesToUse = append(ifacesToUse, ifcCopy)
  61. if ifaces[i].MTU > inboundBufferSize {
  62. inboundBufferSize = ifaces[i].MTU
  63. }
  64. }
  65. if inboundBufferSize == 0 {
  66. return nil, errNoPositiveMTUFound
  67. }
  68. if joinErrCount >= len(ifaces) {
  69. return nil, errJoiningMulticastGroup
  70. }
  71. dstAddr, err := net.ResolveUDPAddr("udp", destinationAddress)
  72. if err != nil {
  73. return nil, err
  74. }
  75. loggerFactory := config.LoggerFactory
  76. if loggerFactory == nil {
  77. loggerFactory = logging.NewDefaultLoggerFactory()
  78. }
  79. localNames := []string{}
  80. for _, l := range config.LocalNames {
  81. localNames = append(localNames, l+".")
  82. }
  83. c := &Conn{
  84. queryInterval: defaultQueryInterval,
  85. queries: []query{},
  86. socket: conn,
  87. dstAddr: dstAddr,
  88. localNames: localNames,
  89. ifaces: ifacesToUse,
  90. log: loggerFactory.NewLogger("mdns"),
  91. closed: make(chan interface{}),
  92. }
  93. if config.QueryInterval != 0 {
  94. c.queryInterval = config.QueryInterval
  95. }
  96. if err := conn.SetControlMessage(ipv4.FlagInterface, true); err != nil {
  97. c.log.Warnf("Failed to SetControlMessage on PacketConn %v", err)
  98. }
  99. // https://www.rfc-editor.org/rfc/rfc6762.html#section-17
  100. // Multicast DNS messages carried by UDP may be up to the IP MTU of the
  101. // physical interface, less the space required for the IP header (20
  102. // bytes for IPv4; 40 bytes for IPv6) and the UDP header (8 bytes).
  103. go c.start(inboundBufferSize-20-8, config)
  104. return c, nil
  105. }
  106. // Close closes the mDNS Conn
  107. func (c *Conn) Close() error {
  108. select {
  109. case <-c.closed:
  110. return nil
  111. default:
  112. }
  113. if err := c.socket.Close(); err != nil {
  114. return err
  115. }
  116. <-c.closed
  117. return nil
  118. }
  119. // Query sends mDNS Queries for the following name until
  120. // either the Context is canceled/expires or we get a result
  121. func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeader, net.Addr, error) {
  122. select {
  123. case <-c.closed:
  124. return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
  125. default:
  126. }
  127. nameWithSuffix := name + "."
  128. queryChan := make(chan queryResult, 1)
  129. c.mu.Lock()
  130. c.queries = append(c.queries, query{nameWithSuffix, queryChan})
  131. ticker := time.NewTicker(c.queryInterval)
  132. c.mu.Unlock()
  133. defer ticker.Stop()
  134. c.sendQuestion(nameWithSuffix)
  135. for {
  136. select {
  137. case <-ticker.C:
  138. c.sendQuestion(nameWithSuffix)
  139. case <-c.closed:
  140. return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
  141. case res := <-queryChan:
  142. return res.answer, res.addr, nil
  143. case <-ctx.Done():
  144. return dnsmessage.ResourceHeader{}, nil, errContextElapsed
  145. }
  146. }
  147. }
  148. func ipToBytes(ip net.IP) (out [4]byte) {
  149. rawIP := ip.To4()
  150. if rawIP == nil {
  151. return
  152. }
  153. ipInt := big.NewInt(0)
  154. ipInt.SetBytes(rawIP)
  155. copy(out[:], ipInt.Bytes())
  156. return
  157. }
  158. func interfaceForRemote(remote string) (net.IP, error) {
  159. conn, err := net.Dial("udp", remote)
  160. if err != nil {
  161. return nil, err
  162. }
  163. localAddr, ok := conn.LocalAddr().(*net.UDPAddr)
  164. if !ok {
  165. return nil, errFailedCast
  166. }
  167. if err := conn.Close(); err != nil {
  168. return nil, err
  169. }
  170. return localAddr.IP, nil
  171. }
  172. func (c *Conn) sendQuestion(name string) {
  173. packedName, err := dnsmessage.NewName(name)
  174. if err != nil {
  175. c.log.Warnf("Failed to construct mDNS packet %v", err)
  176. return
  177. }
  178. msg := dnsmessage.Message{
  179. Header: dnsmessage.Header{},
  180. Questions: []dnsmessage.Question{
  181. {
  182. Type: dnsmessage.TypeA,
  183. Class: dnsmessage.ClassINET,
  184. Name: packedName,
  185. },
  186. },
  187. }
  188. rawQuery, err := msg.Pack()
  189. if err != nil {
  190. c.log.Warnf("Failed to construct mDNS packet %v", err)
  191. return
  192. }
  193. c.writeToSocket(0, rawQuery, false)
  194. }
  195. func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) {
  196. if ifIndex != 0 {
  197. ifc, err := net.InterfaceByIndex(ifIndex)
  198. if err != nil {
  199. c.log.Warnf("Failed to get interface interface for %d: %v", ifIndex, err)
  200. return
  201. }
  202. if onlyLooback && ifc.Flags&net.FlagLoopback == 0 {
  203. // avoid accidentally tricking the destination that itself is the same as us
  204. c.log.Warnf("Interface is not loopback %d", ifIndex)
  205. return
  206. }
  207. if err := c.socket.SetMulticastInterface(ifc); err != nil {
  208. c.log.Warnf("Failed to set multicast interface for %d: %v", ifIndex, err)
  209. } else {
  210. if _, err := c.socket.WriteTo(b, nil, c.dstAddr); err != nil {
  211. c.log.Warnf("Failed to send mDNS packet on interface %d: %v", ifIndex, err)
  212. }
  213. }
  214. return
  215. }
  216. for ifcIdx := range c.ifaces {
  217. if onlyLooback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 {
  218. // avoid accidentally tricking the destination that itself is the same as us
  219. continue
  220. }
  221. if err := c.socket.SetMulticastInterface(&c.ifaces[ifcIdx]); err != nil {
  222. c.log.Warnf("Failed to set multicast interface for %d: %v", c.ifaces[ifcIdx].Index, err)
  223. } else {
  224. if _, err := c.socket.WriteTo(b, nil, c.dstAddr); err != nil {
  225. c.log.Warnf("Failed to send mDNS packet on interface %d: %v", c.ifaces[ifcIdx].Index, err)
  226. }
  227. }
  228. }
  229. }
  230. func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) {
  231. packedName, err := dnsmessage.NewName(name)
  232. if err != nil {
  233. c.log.Warnf("Failed to construct mDNS packet %v", err)
  234. return
  235. }
  236. msg := dnsmessage.Message{
  237. Header: dnsmessage.Header{
  238. Response: true,
  239. Authoritative: true,
  240. },
  241. Answers: []dnsmessage.Resource{
  242. {
  243. Header: dnsmessage.ResourceHeader{
  244. Type: dnsmessage.TypeA,
  245. Class: dnsmessage.ClassINET,
  246. Name: packedName,
  247. TTL: responseTTL,
  248. },
  249. Body: &dnsmessage.AResource{
  250. A: ipToBytes(dst),
  251. },
  252. },
  253. },
  254. }
  255. rawAnswer, err := msg.Pack()
  256. if err != nil {
  257. c.log.Warnf("Failed to construct mDNS packet %v", err)
  258. return
  259. }
  260. c.writeToSocket(ifIndex, rawAnswer, dst.IsLoopback())
  261. }
  262. func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit
  263. defer func() {
  264. c.mu.Lock()
  265. defer c.mu.Unlock()
  266. close(c.closed)
  267. }()
  268. b := make([]byte, inboundBufferSize)
  269. p := dnsmessage.Parser{}
  270. for {
  271. n, cm, src, err := c.socket.ReadFrom(b)
  272. if err != nil {
  273. if errors.Is(err, net.ErrClosed) {
  274. return
  275. }
  276. c.log.Warnf("Failed to ReadFrom %q %v", src, err)
  277. continue
  278. }
  279. var ifIndex int
  280. if cm != nil {
  281. ifIndex = cm.IfIndex
  282. }
  283. func() {
  284. c.mu.RLock()
  285. defer c.mu.RUnlock()
  286. if _, err := p.Start(b[:n]); err != nil {
  287. c.log.Warnf("Failed to parse mDNS packet %v", err)
  288. return
  289. }
  290. for i := 0; i <= maxMessageRecords; i++ {
  291. q, err := p.Question()
  292. if errors.Is(err, dnsmessage.ErrSectionDone) {
  293. break
  294. } else if err != nil {
  295. c.log.Warnf("Failed to parse mDNS packet %v", err)
  296. return
  297. }
  298. for _, localName := range c.localNames {
  299. if localName == q.Name.String() {
  300. if config.LocalAddress != nil {
  301. c.sendAnswer(q.Name.String(), ifIndex, config.LocalAddress)
  302. } else {
  303. localAddress, err := interfaceForRemote(src.String())
  304. if err != nil {
  305. c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err)
  306. continue
  307. }
  308. c.sendAnswer(q.Name.String(), ifIndex, localAddress)
  309. }
  310. }
  311. }
  312. }
  313. for i := 0; i <= maxMessageRecords; i++ {
  314. a, err := p.AnswerHeader()
  315. if errors.Is(err, dnsmessage.ErrSectionDone) {
  316. return
  317. }
  318. if err != nil {
  319. c.log.Warnf("Failed to parse mDNS packet %v", err)
  320. return
  321. }
  322. if a.Type != dnsmessage.TypeA && a.Type != dnsmessage.TypeAAAA {
  323. continue
  324. }
  325. for i := len(c.queries) - 1; i >= 0; i-- {
  326. if c.queries[i].nameWithSuffix == a.Name.String() {
  327. ip, err := ipFromAnswerHeader(a, p)
  328. if err != nil {
  329. c.log.Warnf("Failed to parse mDNS answer %v", err)
  330. return
  331. }
  332. c.queries[i].queryResultChan <- queryResult{a, &net.IPAddr{
  333. IP: ip,
  334. }}
  335. c.queries = append(c.queries[:i], c.queries[i+1:]...)
  336. }
  337. }
  338. }
  339. }()
  340. }
  341. }
  342. func ipFromAnswerHeader(a dnsmessage.ResourceHeader, p dnsmessage.Parser) (ip []byte, err error) {
  343. if a.Type == dnsmessage.TypeA {
  344. resource, err := p.AResource()
  345. if err != nil {
  346. return nil, err
  347. }
  348. ip = net.IP(resource.A[:])
  349. } else {
  350. resource, err := p.AAAAResource()
  351. if err != nil {
  352. return nil, err
  353. }
  354. ip = resource.AAAA[:]
  355. }
  356. return
  357. }