conn.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package mdns
  4. import (
  5. "context"
  6. "errors"
  7. "fmt"
  8. "net"
  9. "sync"
  10. "time"
  11. "github.com/pion/logging"
  12. "golang.org/x/net/dns/dnsmessage"
  13. "golang.org/x/net/ipv4"
  14. )
  15. // Conn represents a mDNS Server
  16. type Conn struct {
  17. mu sync.RWMutex
  18. log logging.LeveledLogger
  19. socket *ipv4.PacketConn
  20. dstAddr *net.UDPAddr
  21. queryInterval time.Duration
  22. localNames []string
  23. queries []*query
  24. ifaces []net.Interface
  25. closed chan interface{}
  26. }
  27. type query struct {
  28. nameWithSuffix string
  29. queryResultChan chan queryResult
  30. }
  31. type queryResult struct {
  32. answer dnsmessage.ResourceHeader
  33. addr net.Addr
  34. }
  35. const (
  36. defaultQueryInterval = time.Second
  37. destinationAddress = "224.0.0.251:5353"
  38. maxMessageRecords = 3
  39. responseTTL = 120
  40. // maxPacketSize is the maximum size of a mdns packet.
  41. // From RFC 6762:
  42. // Even when fragmentation is used, a Multicast DNS packet, including IP
  43. // and UDP headers, MUST NOT exceed 9000 bytes.
  44. // https://datatracker.ietf.org/doc/html/rfc6762#section-17
  45. maxPacketSize = 9000
  46. )
  47. var errNoPositiveMTUFound = errors.New("no positive MTU found")
  48. // Server establishes a mDNS connection over an existing conn.
  49. //
  50. // Currently, the server only supports listening on an IPv4 connection, but internally
  51. // it supports answering with IPv6 AAAA records if this were ever to change.
  52. func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) {
  53. if config == nil {
  54. return nil, errNilConfig
  55. }
  56. ifaces := config.Interfaces
  57. if ifaces == nil {
  58. var err error
  59. ifaces, err = net.Interfaces()
  60. if err != nil {
  61. return nil, err
  62. }
  63. }
  64. inboundBufferSize := 0
  65. joinErrCount := 0
  66. ifacesToUse := make([]net.Interface, 0, len(ifaces))
  67. for i, ifc := range ifaces {
  68. if !config.IncludeLoopback && ifc.Flags&net.FlagLoopback == net.FlagLoopback {
  69. continue
  70. }
  71. if err := conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil {
  72. joinErrCount++
  73. continue
  74. }
  75. ifcCopy := ifc
  76. ifacesToUse = append(ifacesToUse, ifcCopy)
  77. if ifaces[i].MTU > inboundBufferSize {
  78. inboundBufferSize = ifaces[i].MTU
  79. }
  80. }
  81. if inboundBufferSize == 0 {
  82. return nil, errNoPositiveMTUFound
  83. }
  84. if inboundBufferSize > maxPacketSize {
  85. inboundBufferSize = maxPacketSize
  86. }
  87. if joinErrCount >= len(ifaces) {
  88. return nil, errJoiningMulticastGroup
  89. }
  90. dstAddr, err := net.ResolveUDPAddr("udp", destinationAddress)
  91. if err != nil {
  92. return nil, err
  93. }
  94. loggerFactory := config.LoggerFactory
  95. if loggerFactory == nil {
  96. loggerFactory = logging.NewDefaultLoggerFactory()
  97. }
  98. localNames := []string{}
  99. for _, l := range config.LocalNames {
  100. localNames = append(localNames, l+".")
  101. }
  102. c := &Conn{
  103. queryInterval: defaultQueryInterval,
  104. queries: []*query{},
  105. socket: conn,
  106. dstAddr: dstAddr,
  107. localNames: localNames,
  108. ifaces: ifacesToUse,
  109. log: loggerFactory.NewLogger("mdns"),
  110. closed: make(chan interface{}),
  111. }
  112. if config.QueryInterval != 0 {
  113. c.queryInterval = config.QueryInterval
  114. }
  115. if err := conn.SetControlMessage(ipv4.FlagInterface, true); err != nil {
  116. c.log.Warnf("Failed to SetControlMessage on PacketConn %v", err)
  117. }
  118. if config.IncludeLoopback {
  119. // this is an efficient way for us to send ourselves a message faster instead of it going
  120. // further out into the network stack.
  121. if err := conn.SetMulticastLoopback(true); err != nil {
  122. c.log.Warnf("Failed to SetMulticastLoopback(true) on PacketConn %v; this may cause inefficient network path communications", err)
  123. }
  124. }
  125. // https://www.rfc-editor.org/rfc/rfc6762.html#section-17
  126. // Multicast DNS messages carried by UDP may be up to the IP MTU of the
  127. // physical interface, less the space required for the IP header (20
  128. // bytes for IPv4; 40 bytes for IPv6) and the UDP header (8 bytes).
  129. go c.start(inboundBufferSize-20-8, config)
  130. return c, nil
  131. }
  132. // Close closes the mDNS Conn
  133. func (c *Conn) Close() error {
  134. select {
  135. case <-c.closed:
  136. return nil
  137. default:
  138. }
  139. if err := c.socket.Close(); err != nil {
  140. return err
  141. }
  142. <-c.closed
  143. return nil
  144. }
  145. // Query sends mDNS Queries for the following name until
  146. // either the Context is canceled/expires or we get a result
  147. func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeader, net.Addr, error) {
  148. select {
  149. case <-c.closed:
  150. return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
  151. default:
  152. }
  153. nameWithSuffix := name + "."
  154. queryChan := make(chan queryResult, 1)
  155. query := &query{nameWithSuffix, queryChan}
  156. c.mu.Lock()
  157. c.queries = append(c.queries, query)
  158. c.mu.Unlock()
  159. defer func() {
  160. c.mu.Lock()
  161. defer c.mu.Unlock()
  162. for i := len(c.queries) - 1; i >= 0; i-- {
  163. if c.queries[i] == query {
  164. c.queries = append(c.queries[:i], c.queries[i+1:]...)
  165. }
  166. }
  167. }()
  168. ticker := time.NewTicker(c.queryInterval)
  169. defer ticker.Stop()
  170. c.sendQuestion(nameWithSuffix)
  171. for {
  172. select {
  173. case <-ticker.C:
  174. c.sendQuestion(nameWithSuffix)
  175. case <-c.closed:
  176. return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
  177. case res := <-queryChan:
  178. // Given https://datatracker.ietf.org/doc/html/draft-ietf-mmusic-mdns-ice-candidates#section-3.2.2-2
  179. // An ICE agent SHOULD ignore candidates where the hostname resolution returns more than one IP address.
  180. //
  181. // We will take the first we receive which could result in a race between two suitable addresses where
  182. // one is better than the other (e.g. localhost vs LAN).
  183. return res.answer, res.addr, nil
  184. case <-ctx.Done():
  185. return dnsmessage.ResourceHeader{}, nil, errContextElapsed
  186. }
  187. }
  188. }
  189. type ipToBytesError struct {
  190. ip net.IP
  191. expectedType string
  192. }
  193. func (err ipToBytesError) Error() string {
  194. return fmt.Sprintf("ip (%s) is not %s", err.ip, err.expectedType)
  195. }
  196. func ipv4ToBytes(ip net.IP) ([4]byte, error) {
  197. rawIP := ip.To4()
  198. if rawIP == nil {
  199. return [4]byte{}, ipToBytesError{ip, "IPv4"}
  200. }
  201. // net.IPs are stored in big endian / network byte order
  202. var out [4]byte
  203. copy(out[:], rawIP[:])
  204. return out, nil
  205. }
  206. func ipv6ToBytes(ip net.IP) ([16]byte, error) {
  207. rawIP := ip.To16()
  208. if rawIP == nil {
  209. return [16]byte{}, ipToBytesError{ip, "IPv6"}
  210. }
  211. // net.IPs are stored in big endian / network byte order
  212. var out [16]byte
  213. copy(out[:], rawIP[:])
  214. return out, nil
  215. }
  216. func interfaceForRemote(remote string) (net.IP, error) {
  217. conn, err := net.Dial("udp", remote)
  218. if err != nil {
  219. return nil, err
  220. }
  221. localAddr, ok := conn.LocalAddr().(*net.UDPAddr)
  222. if !ok {
  223. return nil, errFailedCast
  224. }
  225. if err := conn.Close(); err != nil {
  226. return nil, err
  227. }
  228. return localAddr.IP, nil
  229. }
  230. func (c *Conn) sendQuestion(name string) {
  231. packedName, err := dnsmessage.NewName(name)
  232. if err != nil {
  233. c.log.Warnf("Failed to construct mDNS packet %v", err)
  234. return
  235. }
  236. msg := dnsmessage.Message{
  237. Header: dnsmessage.Header{},
  238. Questions: []dnsmessage.Question{
  239. {
  240. Type: dnsmessage.TypeA,
  241. Class: dnsmessage.ClassINET,
  242. Name: packedName,
  243. },
  244. },
  245. }
  246. rawQuery, err := msg.Pack()
  247. if err != nil {
  248. c.log.Warnf("Failed to construct mDNS packet %v", err)
  249. return
  250. }
  251. c.writeToSocket(0, rawQuery, false)
  252. }
  253. func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool) {
  254. if ifIndex != 0 {
  255. ifc, err := net.InterfaceByIndex(ifIndex)
  256. if err != nil {
  257. c.log.Warnf("Failed to get interface for %d: %v", ifIndex, err)
  258. return
  259. }
  260. if srcIfcIsLoopback && ifc.Flags&net.FlagLoopback == 0 {
  261. // avoid accidentally tricking the destination that itself is the same as us
  262. c.log.Warnf("Interface is not loopback %d", ifIndex)
  263. return
  264. }
  265. if err := c.socket.SetMulticastInterface(ifc); err != nil {
  266. c.log.Warnf("Failed to set multicast interface for %d: %v", ifIndex, err)
  267. } else {
  268. if _, err := c.socket.WriteTo(b, nil, c.dstAddr); err != nil {
  269. c.log.Warnf("Failed to send mDNS packet on interface %d: %v", ifIndex, err)
  270. }
  271. }
  272. return
  273. }
  274. for ifcIdx := range c.ifaces {
  275. if srcIfcIsLoopback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 {
  276. // avoid accidentally tricking the destination that itself is the same as us
  277. continue
  278. }
  279. if err := c.socket.SetMulticastInterface(&c.ifaces[ifcIdx]); err != nil {
  280. c.log.Warnf("Failed to set multicast interface for %d: %v", c.ifaces[ifcIdx].Index, err)
  281. } else {
  282. if _, err := c.socket.WriteTo(b, nil, c.dstAddr); err != nil {
  283. c.log.Warnf("Failed to send mDNS packet on interface %d: %v", c.ifaces[ifcIdx].Index, err)
  284. }
  285. }
  286. }
  287. }
  288. func createAnswer(name string, addr net.IP) (dnsmessage.Message, error) {
  289. packedName, err := dnsmessage.NewName(name)
  290. if err != nil {
  291. return dnsmessage.Message{}, err
  292. }
  293. msg := dnsmessage.Message{
  294. Header: dnsmessage.Header{
  295. Response: true,
  296. Authoritative: true,
  297. },
  298. Answers: []dnsmessage.Resource{
  299. {
  300. Header: dnsmessage.ResourceHeader{
  301. Type: dnsmessage.TypeA,
  302. Class: dnsmessage.ClassINET,
  303. Name: packedName,
  304. TTL: responseTTL,
  305. },
  306. },
  307. },
  308. }
  309. if ip4 := addr.To4(); ip4 != nil {
  310. ipBuf, err := ipv4ToBytes(addr)
  311. if err != nil {
  312. return dnsmessage.Message{}, err
  313. }
  314. msg.Answers[0].Body = &dnsmessage.AResource{
  315. A: ipBuf,
  316. }
  317. } else {
  318. ipBuf, err := ipv6ToBytes(addr)
  319. if err != nil {
  320. return dnsmessage.Message{}, err
  321. }
  322. msg.Answers[0].Body = &dnsmessage.AAAAResource{
  323. AAAA: ipBuf,
  324. }
  325. }
  326. return msg, nil
  327. }
  328. func (c *Conn) sendAnswer(name string, ifIndex int, addr net.IP) {
  329. answer, err := createAnswer(name, addr)
  330. if err != nil {
  331. c.log.Warnf("Failed to create mDNS answer %v", err)
  332. return
  333. }
  334. rawAnswer, err := answer.Pack()
  335. if err != nil {
  336. c.log.Warnf("Failed to construct mDNS packet %v", err)
  337. return
  338. }
  339. c.writeToSocket(ifIndex, rawAnswer, addr.IsLoopback())
  340. }
  341. func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit
  342. defer func() {
  343. c.mu.Lock()
  344. defer c.mu.Unlock()
  345. close(c.closed)
  346. }()
  347. b := make([]byte, inboundBufferSize)
  348. p := dnsmessage.Parser{}
  349. for {
  350. n, cm, src, err := c.socket.ReadFrom(b)
  351. if err != nil {
  352. if errors.Is(err, net.ErrClosed) {
  353. return
  354. }
  355. c.log.Warnf("Failed to ReadFrom %q %v", src, err)
  356. continue
  357. }
  358. var ifIndex int
  359. if cm != nil {
  360. ifIndex = cm.IfIndex
  361. }
  362. var srcIP net.IP
  363. switch addr := src.(type) {
  364. case *net.UDPAddr:
  365. srcIP = addr.IP
  366. case *net.TCPAddr:
  367. srcIP = addr.IP
  368. default:
  369. c.log.Warnf("Failed to determine address type %T for source address %s", src, src)
  370. continue
  371. }
  372. srcIsIPv4 := srcIP.To4() != nil
  373. func() {
  374. c.mu.RLock()
  375. defer c.mu.RUnlock()
  376. if _, err := p.Start(b[:n]); err != nil {
  377. c.log.Warnf("Failed to parse mDNS packet %v", err)
  378. return
  379. }
  380. for i := 0; i <= maxMessageRecords; i++ {
  381. q, err := p.Question()
  382. if errors.Is(err, dnsmessage.ErrSectionDone) {
  383. break
  384. } else if err != nil {
  385. c.log.Warnf("Failed to parse mDNS packet %v", err)
  386. return
  387. }
  388. for _, localName := range c.localNames {
  389. if localName == q.Name.String() {
  390. if config.LocalAddress != nil {
  391. c.sendAnswer(q.Name.String(), ifIndex, config.LocalAddress)
  392. } else {
  393. var localAddress net.IP
  394. // prefer the address of the interface if we know its index, but otherwise
  395. // derive it from the address we read from. We do this because even if
  396. // multicast loopback is in use or we send from a loopback interface,
  397. // there are still cases where the IP packet will contain the wrong
  398. // source IP (e.g. a LAN interface).
  399. // For example, we can have a packet that has:
  400. // Source: 192.168.65.3
  401. // Destination: 224.0.0.251
  402. // Interface Index: 1
  403. // Interface Addresses @ 1: [127.0.0.1/8 ::1/128]
  404. if ifIndex != 0 {
  405. ifc, netErr := net.InterfaceByIndex(ifIndex)
  406. if netErr != nil {
  407. c.log.Warnf("Failed to get interface for %d: %v", ifIndex, netErr)
  408. continue
  409. }
  410. addrs, addrsErr := ifc.Addrs()
  411. if addrsErr != nil {
  412. c.log.Warnf("Failed to get addresses for interface %d: %v", ifIndex, addrsErr)
  413. continue
  414. }
  415. if len(addrs) == 0 {
  416. c.log.Warnf("Expected more than one address for interface %d", ifIndex)
  417. continue
  418. }
  419. var selectedIP net.IP
  420. for _, addr := range addrs {
  421. var ip net.IP
  422. switch addr := addr.(type) {
  423. case *net.IPNet:
  424. ip = addr.IP
  425. case *net.IPAddr:
  426. ip = addr.IP
  427. default:
  428. c.log.Warnf("Failed to determine address type %T from interface %d", addr, ifIndex)
  429. continue
  430. }
  431. // match up respective IP types
  432. if ipv4 := ip.To4(); ipv4 == nil {
  433. if srcIsIPv4 {
  434. continue
  435. } else if !isSupportedIPv6(ip) {
  436. continue
  437. }
  438. } else if !srcIsIPv4 {
  439. continue
  440. }
  441. selectedIP = ip
  442. break
  443. }
  444. if selectedIP == nil {
  445. c.log.Warnf("Failed to find suitable IP for interface %d; deriving address from source address instead", ifIndex)
  446. } else {
  447. localAddress = selectedIP
  448. }
  449. } else if ifIndex == 0 || localAddress == nil {
  450. localAddress, err = interfaceForRemote(src.String())
  451. if err != nil {
  452. c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err)
  453. continue
  454. }
  455. }
  456. c.sendAnswer(q.Name.String(), ifIndex, localAddress)
  457. }
  458. }
  459. }
  460. }
  461. for i := 0; i <= maxMessageRecords; i++ {
  462. a, err := p.AnswerHeader()
  463. if errors.Is(err, dnsmessage.ErrSectionDone) {
  464. return
  465. }
  466. if err != nil {
  467. c.log.Warnf("Failed to parse mDNS packet %v", err)
  468. return
  469. }
  470. if a.Type != dnsmessage.TypeA && a.Type != dnsmessage.TypeAAAA {
  471. continue
  472. }
  473. for i := len(c.queries) - 1; i >= 0; i-- {
  474. if c.queries[i].nameWithSuffix == a.Name.String() {
  475. ip, err := ipFromAnswerHeader(a, p)
  476. if err != nil {
  477. c.log.Warnf("Failed to parse mDNS answer %v", err)
  478. return
  479. }
  480. c.queries[i].queryResultChan <- queryResult{a, &net.IPAddr{
  481. IP: ip,
  482. }}
  483. c.queries = append(c.queries[:i], c.queries[i+1:]...)
  484. }
  485. }
  486. }
  487. }()
  488. }
  489. }
  490. func ipFromAnswerHeader(a dnsmessage.ResourceHeader, p dnsmessage.Parser) (ip []byte, err error) {
  491. if a.Type == dnsmessage.TypeA {
  492. resource, err := p.AResource()
  493. if err != nil {
  494. return nil, err
  495. }
  496. ip = resource.A[:]
  497. } else {
  498. resource, err := p.AAAAResource()
  499. if err != nil {
  500. return nil, err
  501. }
  502. ip = resource.AAAA[:]
  503. }
  504. return
  505. }
  506. // The conditions of invalidation written below are defined in
  507. // https://tools.ietf.org/html/rfc8445#section-5.1.1.1
  508. func isSupportedIPv6(ip net.IP) bool {
  509. if len(ip) != net.IPv6len ||
  510. isZeros(ip[0:12]) || // !(IPv4-compatible IPv6)
  511. ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 || // !(IPv6 site-local unicast)
  512. ip.IsLinkLocalUnicast() ||
  513. ip.IsLinkLocalMulticast() {
  514. return false
  515. }
  516. return true
  517. }
  518. func isZeros(ip net.IP) bool {
  519. for i := 0; i < len(ip); i++ {
  520. if ip[i] != 0 {
  521. return false
  522. }
  523. }
  524. return true
  525. }