udp_mux.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package ice
  4. import (
  5. "errors"
  6. "io"
  7. "net"
  8. "os"
  9. "strings"
  10. "sync"
  11. "github.com/pion/logging"
  12. "github.com/pion/stun"
  13. "github.com/pion/transport/v2"
  14. "github.com/pion/transport/v2/stdnet"
  15. )
  16. // UDPMux allows multiple connections to go over a single UDP port
  17. type UDPMux interface {
  18. io.Closer
  19. GetConn(ufrag string, addr net.Addr) (net.PacketConn, error)
  20. RemoveConnByUfrag(ufrag string)
  21. GetListenAddresses() []net.Addr
  22. }
  23. // UDPMuxDefault is an implementation of the interface
  24. type UDPMuxDefault struct {
  25. params UDPMuxParams
  26. closedChan chan struct{}
  27. closeOnce sync.Once
  28. // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
  29. connsIPv4, connsIPv6 map[string]*udpMuxedConn
  30. addressMapMu sync.RWMutex
  31. addressMap map[udpMuxedConnAddr]*udpMuxedConn
  32. // Buffer pool to recycle buffers for net.UDPAddr encodes/decodes
  33. pool *sync.Pool
  34. mu sync.Mutex
  35. // For UDP connection listen at unspecified address
  36. localAddrsForUnspecified []net.Addr
  37. }
  38. // UDPMuxParams are parameters for UDPMux.
  39. type UDPMuxParams struct {
  40. Logger logging.LeveledLogger
  41. UDPConn net.PacketConn
  42. // Required for gathering local addresses
  43. // in case a un UDPConn is passed which does not
  44. // bind to a specific local address.
  45. Net transport.Net
  46. }
  47. // NewUDPMuxDefault creates an implementation of UDPMux
  48. func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
  49. if params.Logger == nil {
  50. params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
  51. }
  52. var localAddrsForUnspecified []net.Addr
  53. if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
  54. params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
  55. } else if ok && addr.IP.IsUnspecified() {
  56. // For unspecified addresses, the correct behavior is to return errListenUnspecified, but
  57. // it will break the applications that are already using unspecified UDP connection
  58. // with UDPMuxDefault, so print a warn log and create a local address list for mux.
  59. params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
  60. var networks []NetworkType
  61. switch {
  62. case addr.IP.To4() != nil:
  63. networks = []NetworkType{NetworkTypeUDP4}
  64. case addr.IP.To16() != nil:
  65. networks = []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}
  66. default:
  67. params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
  68. }
  69. if len(networks) > 0 {
  70. if params.Net == nil {
  71. var err error
  72. if params.Net, err = stdnet.NewNet(); err != nil {
  73. params.Logger.Errorf("Failed to get create network: %v", err)
  74. }
  75. }
  76. ips, err := localInterfaces(params.Net, nil, nil, networks, true)
  77. if err == nil {
  78. for _, ip := range ips {
  79. localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
  80. }
  81. } else {
  82. params.Logger.Errorf("Failed to get local interfaces for unspecified addr: %v", err)
  83. }
  84. }
  85. }
  86. m := &UDPMuxDefault{
  87. addressMap: map[udpMuxedConnAddr]*udpMuxedConn{},
  88. params: params,
  89. connsIPv4: make(map[string]*udpMuxedConn),
  90. connsIPv6: make(map[string]*udpMuxedConn),
  91. closedChan: make(chan struct{}, 1),
  92. pool: &sync.Pool{
  93. New: func() interface{} {
  94. // Big enough buffer to fit both packet and address
  95. return newBufferHolder(receiveMTU)
  96. },
  97. },
  98. localAddrsForUnspecified: localAddrsForUnspecified,
  99. }
  100. // [Psiphon]
  101. //
  102. // - Currently, pion/ice code produces the following race condition due to
  103. // NewUDPMuxDefault launching go m.connWorker() before the called
  104. // assigns m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams), while
  105. // connWorker may access fields in the
  106. // UniversalUDPMuxDefault.UDPMuxDefault embedded struct.
  107. //
  108. // - For Psiphon's use case, the simple workaround is to delay launching
  109. // go m.connWorker() until after the assignment in
  110. // NewUniversalUDPMuxDefault. This isn't a general purpose fix since it
  111. // means NewUDPMuxDefault by itself won't work.
  112. //
  113. // - Note that the IsPsiphon flag/check added for
  114. // gatherCandidatesSrflxUDPMux also checks that this fix is in place.
  115. //
  116. //
  117. // ==================
  118. // WARNING: DATA RACE
  119. // Read at 0x00c000ee28c0 by goroutine 22319:
  120. // github.com/pion/ice/v2.(*UniversalUDPMuxDefault).isXORMappedResponse()
  121. // /pion/ice/v2/udp_mux_universal.go:136 +0x40
  122. // github.com/pion/ice/v2.(*udpConn).ReadFrom()
  123. // /pion/ice/v2/udp_mux_universal.go:122 +0x234
  124. // github.com/pion/ice/v2.(*UDPMuxDefault).connWorker()
  125. // /pion/ice/v2/udp_mux.go:286 +0xd4
  126. // github.com/pion/ice/v2.NewUDPMuxDefault.func2()
  127. // /pion/ice/v2/udp_mux.go:122 +0x34
  128. //
  129. // Previous write at 0x00c000ee28c0 by goroutine 22315:
  130. // github.com/pion/ice/v2.NewUniversalUDPMuxDefault()
  131. // /pion/ice/v2/udp_mux_universal.go:73 +0x354
  132. // github.com/pion/webrtc/v3.NewICEUniversalUDPMux()
  133. // /pion/webrtc/v3/icemux.go:39 +0x2b4
  134. // ==================
  135. //go m.connWorker()
  136. return m
  137. }
  138. // LocalAddr returns the listening address of this UDPMuxDefault
  139. func (m *UDPMuxDefault) LocalAddr() net.Addr {
  140. return m.params.UDPConn.LocalAddr()
  141. }
  142. // GetListenAddresses returns the list of addresses that this mux is listening on
  143. func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
  144. if len(m.localAddrsForUnspecified) > 0 {
  145. return m.localAddrsForUnspecified
  146. }
  147. return []net.Addr{m.LocalAddr()}
  148. }
  149. // GetConn returns a PacketConn given the connection's ufrag and network address
  150. // creates the connection if an existing one can't be found
  151. func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
  152. // don't check addr for mux using unspecified address
  153. if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
  154. return nil, errInvalidAddress
  155. }
  156. var isIPv6 bool
  157. if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
  158. isIPv6 = true
  159. }
  160. m.mu.Lock()
  161. defer m.mu.Unlock()
  162. if m.IsClosed() {
  163. return nil, io.ErrClosedPipe
  164. }
  165. if conn, ok := m.getConn(ufrag, isIPv6); ok {
  166. return conn, nil
  167. }
  168. c := m.createMuxedConn(ufrag)
  169. go func() {
  170. <-c.CloseChannel()
  171. m.RemoveConnByUfrag(ufrag)
  172. }()
  173. if isIPv6 {
  174. m.connsIPv6[ufrag] = c
  175. } else {
  176. m.connsIPv4[ufrag] = c
  177. }
  178. return c, nil
  179. }
  180. // RemoveConnByUfrag stops and removes the muxed packet connection
  181. func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
  182. removedConns := make([]*udpMuxedConn, 0, 2)
  183. // Keep lock section small to avoid deadlock with conn lock
  184. m.mu.Lock()
  185. if c, ok := m.connsIPv4[ufrag]; ok {
  186. delete(m.connsIPv4, ufrag)
  187. removedConns = append(removedConns, c)
  188. }
  189. if c, ok := m.connsIPv6[ufrag]; ok {
  190. delete(m.connsIPv6, ufrag)
  191. removedConns = append(removedConns, c)
  192. }
  193. m.mu.Unlock()
  194. if len(removedConns) == 0 {
  195. // No need to lock if no connection was found
  196. return
  197. }
  198. m.addressMapMu.Lock()
  199. defer m.addressMapMu.Unlock()
  200. for _, c := range removedConns {
  201. addresses := c.getAddresses()
  202. for _, addr := range addresses {
  203. delete(m.addressMap, addr)
  204. }
  205. }
  206. }
  207. // IsClosed returns true if the mux had been closed
  208. func (m *UDPMuxDefault) IsClosed() bool {
  209. select {
  210. case <-m.closedChan:
  211. return true
  212. default:
  213. return false
  214. }
  215. }
  216. // Close the mux, no further connections could be created
  217. func (m *UDPMuxDefault) Close() error {
  218. var err error
  219. m.closeOnce.Do(func() {
  220. m.mu.Lock()
  221. defer m.mu.Unlock()
  222. for _, c := range m.connsIPv4 {
  223. _ = c.Close()
  224. }
  225. for _, c := range m.connsIPv6 {
  226. _ = c.Close()
  227. }
  228. m.connsIPv4 = make(map[string]*udpMuxedConn)
  229. m.connsIPv6 = make(map[string]*udpMuxedConn)
  230. close(m.closedChan)
  231. _ = m.params.UDPConn.Close()
  232. })
  233. return err
  234. }
  235. func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
  236. return m.params.UDPConn.WriteTo(buf, rAddr)
  237. }
  238. func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr udpMuxedConnAddr) {
  239. if m.IsClosed() {
  240. return
  241. }
  242. m.addressMapMu.Lock()
  243. defer m.addressMapMu.Unlock()
  244. existing, ok := m.addressMap[addr]
  245. if ok {
  246. existing.removeAddress(addr)
  247. }
  248. m.addressMap[addr] = conn
  249. m.params.Logger.Debugf("Registered %s for %s", addr, conn.params.Key)
  250. }
  251. func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
  252. c := newUDPMuxedConn(&udpMuxedConnParams{
  253. Mux: m,
  254. Key: key,
  255. AddrPool: m.pool,
  256. LocalAddr: m.LocalAddr(),
  257. Logger: m.params.Logger,
  258. })
  259. return c
  260. }
  261. func (m *UDPMuxDefault) connWorker() {
  262. logger := m.params.Logger
  263. defer func() {
  264. _ = m.Close()
  265. }()
  266. buf := make([]byte, receiveMTU)
  267. for {
  268. n, addr, err := m.params.UDPConn.ReadFrom(buf)
  269. if m.IsClosed() {
  270. return
  271. } else if err != nil {
  272. if os.IsTimeout(err) {
  273. continue
  274. } else if !errors.Is(err, io.EOF) {
  275. logger.Errorf("Failed to read UDP packet: %v", err)
  276. }
  277. return
  278. }
  279. udpAddr, ok := addr.(*net.UDPAddr)
  280. if !ok {
  281. logger.Errorf("Underlying PacketConn did not return a UDPAddr")
  282. return
  283. }
  284. // If we have already seen this address dispatch to the appropriate destination
  285. m.addressMapMu.Lock()
  286. destinationConn := m.addressMap[newUDPMuxedConnAddr(udpAddr)]
  287. m.addressMapMu.Unlock()
  288. // If we haven't seen this address before but is a STUN packet lookup by ufrag
  289. if destinationConn == nil && stun.IsMessage(buf[:n]) {
  290. msg := &stun.Message{
  291. Raw: append([]byte{}, buf[:n]...),
  292. }
  293. if err = msg.Decode(); err != nil {
  294. m.params.Logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err)
  295. continue
  296. }
  297. attr, stunAttrErr := msg.Get(stun.AttrUsername)
  298. if stunAttrErr != nil {
  299. m.params.Logger.Warnf("No Username attribute in STUN message from %s", addr.String())
  300. continue
  301. }
  302. ufrag := strings.Split(string(attr), ":")[0]
  303. isIPv6 := udpAddr.IP.To4() == nil
  304. m.mu.Lock()
  305. destinationConn, _ = m.getConn(ufrag, isIPv6)
  306. m.mu.Unlock()
  307. }
  308. if destinationConn == nil {
  309. m.params.Logger.Tracef("Dropping packet from %s, addr: %s", udpAddr, addr)
  310. continue
  311. }
  312. if err = destinationConn.writePacket(buf[:n], udpAddr); err != nil {
  313. m.params.Logger.Errorf("Failed to write packet: %v", err)
  314. }
  315. }
  316. }
  317. func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
  318. if isIPv6 {
  319. val, ok = m.connsIPv6[ufrag]
  320. } else {
  321. val, ok = m.connsIPv4[ufrag]
  322. }
  323. return
  324. }
  325. type bufferHolder struct {
  326. next *bufferHolder
  327. buf []byte
  328. addr *net.UDPAddr
  329. }
  330. func newBufferHolder(size int) *bufferHolder {
  331. return &bufferHolder{
  332. buf: make([]byte, size),
  333. }
  334. }
  335. func (b *bufferHolder) reset() {
  336. b.next = nil
  337. b.addr = nil
  338. }