server.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. // Package turn contains the public API for pion/turn, a toolkit for building TURN clients and servers
  4. package turn
  5. import (
  6. "errors"
  7. "fmt"
  8. "net"
  9. "sync"
  10. "time"
  11. "github.com/pion/logging"
  12. "github.com/pion/turn/v2/internal/allocation"
  13. "github.com/pion/turn/v2/internal/proto"
  14. "github.com/pion/turn/v2/internal/server"
  15. )
  16. const (
  17. defaultInboundMTU = 1600
  18. )
  19. // Server is an instance of the Pion TURN Server
  20. type Server struct {
  21. log logging.LeveledLogger
  22. authHandler AuthHandler
  23. realm string
  24. channelBindTimeout time.Duration
  25. nonces *sync.Map
  26. packetConnConfigs []PacketConnConfig
  27. listenerConfigs []ListenerConfig
  28. allocationManagers []*allocation.Manager
  29. inboundMTU int
  30. }
  31. // NewServer creates the Pion TURN server
  32. //
  33. //nolint:gocognit
  34. func NewServer(config ServerConfig) (*Server, error) {
  35. if err := config.validate(); err != nil {
  36. return nil, err
  37. }
  38. loggerFactory := config.LoggerFactory
  39. if loggerFactory == nil {
  40. loggerFactory = logging.NewDefaultLoggerFactory()
  41. }
  42. mtu := defaultInboundMTU
  43. if config.InboundMTU != 0 {
  44. mtu = config.InboundMTU
  45. }
  46. s := &Server{
  47. log: loggerFactory.NewLogger("turn"),
  48. authHandler: config.AuthHandler,
  49. realm: config.Realm,
  50. channelBindTimeout: config.ChannelBindTimeout,
  51. packetConnConfigs: config.PacketConnConfigs,
  52. listenerConfigs: config.ListenerConfigs,
  53. nonces: &sync.Map{},
  54. inboundMTU: mtu,
  55. }
  56. if s.channelBindTimeout == 0 {
  57. s.channelBindTimeout = proto.DefaultLifetime
  58. }
  59. for _, cfg := range s.packetConnConfigs {
  60. am, err := s.createAllocationManager(cfg.RelayAddressGenerator, cfg.PermissionHandler)
  61. if err != nil {
  62. return nil, fmt.Errorf("failed to create AllocationManager: %w", err)
  63. }
  64. go func(cfg PacketConnConfig, am *allocation.Manager) {
  65. s.readLoop(cfg.PacketConn, am)
  66. if err := am.Close(); err != nil {
  67. s.log.Errorf("Failed to close AllocationManager: %s", err)
  68. }
  69. }(cfg, am)
  70. }
  71. for _, cfg := range s.listenerConfigs {
  72. am, err := s.createAllocationManager(cfg.RelayAddressGenerator, cfg.PermissionHandler)
  73. if err != nil {
  74. return nil, fmt.Errorf("failed to create AllocationManager: %w", err)
  75. }
  76. go func(cfg ListenerConfig, am *allocation.Manager) {
  77. s.readListener(cfg.Listener, am)
  78. if err := am.Close(); err != nil {
  79. s.log.Errorf("Failed to close AllocationManager: %s", err)
  80. }
  81. }(cfg, am)
  82. }
  83. return s, nil
  84. }
  85. // AllocationCount returns the number of active allocations. It can be used to drain the server before closing
  86. func (s *Server) AllocationCount() int {
  87. allocs := 0
  88. for _, am := range s.allocationManagers {
  89. allocs += am.AllocationCount()
  90. }
  91. return allocs
  92. }
  93. // Close stops the TURN Server. It cleans up any associated state and closes all connections it is managing
  94. func (s *Server) Close() error {
  95. var errors []error
  96. for _, cfg := range s.packetConnConfigs {
  97. if err := cfg.PacketConn.Close(); err != nil {
  98. errors = append(errors, err)
  99. }
  100. }
  101. for _, cfg := range s.listenerConfigs {
  102. if err := cfg.Listener.Close(); err != nil {
  103. errors = append(errors, err)
  104. }
  105. }
  106. if len(errors) == 0 {
  107. return nil
  108. }
  109. err := errFailedToClose
  110. for _, e := range errors {
  111. err = fmt.Errorf("%s; close error (%w) ", err, e) //nolint:errorlint
  112. }
  113. return err
  114. }
  115. func (s *Server) readListener(l net.Listener, am *allocation.Manager) {
  116. for {
  117. conn, err := l.Accept()
  118. if err != nil {
  119. s.log.Debugf("Failed to accept: %s", err)
  120. return
  121. }
  122. go func() {
  123. s.readLoop(NewSTUNConn(conn), am)
  124. if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
  125. s.log.Errorf("Failed to close conn: %s", err)
  126. }
  127. }()
  128. }
  129. }
  130. func (s *Server) createAllocationManager(addrGenerator RelayAddressGenerator, handler PermissionHandler) (*allocation.Manager, error) {
  131. if handler == nil {
  132. handler = DefaultPermissionHandler
  133. }
  134. am, err := allocation.NewManager(allocation.ManagerConfig{
  135. AllocatePacketConn: addrGenerator.AllocatePacketConn,
  136. AllocateConn: addrGenerator.AllocateConn,
  137. PermissionHandler: handler,
  138. LeveledLogger: s.log,
  139. })
  140. if err != nil {
  141. return am, err
  142. }
  143. s.allocationManagers = append(s.allocationManagers, am)
  144. return am, err
  145. }
  146. func (s *Server) readLoop(p net.PacketConn, allocationManager *allocation.Manager) {
  147. buf := make([]byte, s.inboundMTU)
  148. for {
  149. n, addr, err := p.ReadFrom(buf)
  150. switch {
  151. case err != nil:
  152. s.log.Debugf("Exit read loop on error: %s", err)
  153. return
  154. case n >= s.inboundMTU:
  155. s.log.Debugf("Read bytes exceeded MTU, packet is possibly truncated")
  156. }
  157. if err := server.HandleRequest(server.Request{
  158. Conn: p,
  159. SrcAddr: addr,
  160. Buff: buf[:n],
  161. Log: s.log,
  162. AuthHandler: s.authHandler,
  163. Realm: s.realm,
  164. AllocationManager: allocationManager,
  165. ChannelBindTimeout: s.channelBindTimeout,
  166. Nonces: s.nonces,
  167. }); err != nil {
  168. s.log.Errorf("Failed to handle datagram: %v", err)
  169. }
  170. }
  171. }