sctptransport.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. //go:build !js
  4. // +build !js
  5. package webrtc
  6. import (
  7. "errors"
  8. "io"
  9. "math"
  10. "sync"
  11. "time"
  12. "github.com/pion/datachannel"
  13. "github.com/pion/logging"
  14. "github.com/pion/sctp"
  15. "github.com/pion/webrtc/v3/pkg/rtcerr"
  16. )
  17. const sctpMaxChannels = uint16(65535)
  18. // SCTPTransport provides details about the SCTP transport.
  19. type SCTPTransport struct {
  20. lock sync.RWMutex
  21. dtlsTransport *DTLSTransport
  22. // State represents the current state of the SCTP transport.
  23. state SCTPTransportState
  24. // SCTPTransportState doesn't have an enum to distinguish between New/Connecting
  25. // so we need a dedicated field
  26. isStarted bool
  27. // MaxMessageSize represents the maximum size of data that can be passed to
  28. // DataChannel's send() method.
  29. maxMessageSize float64
  30. // MaxChannels represents the maximum amount of DataChannel's that can
  31. // be used simultaneously.
  32. maxChannels *uint16
  33. // OnStateChange func()
  34. onErrorHandler func(error)
  35. sctpAssociation *sctp.Association
  36. onDataChannelHandler func(*DataChannel)
  37. onDataChannelOpenedHandler func(*DataChannel)
  38. // DataChannels
  39. dataChannels []*DataChannel
  40. dataChannelsOpened uint32
  41. dataChannelsRequested uint32
  42. dataChannelsAccepted uint32
  43. api *API
  44. log logging.LeveledLogger
  45. }
  46. // NewSCTPTransport creates a new SCTPTransport.
  47. // This constructor is part of the ORTC API. It is not
  48. // meant to be used together with the basic WebRTC API.
  49. func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
  50. res := &SCTPTransport{
  51. dtlsTransport: dtls,
  52. state: SCTPTransportStateConnecting,
  53. api: api,
  54. log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
  55. }
  56. res.updateMessageSize()
  57. res.updateMaxChannels()
  58. return res
  59. }
  60. // Transport returns the DTLSTransport instance the SCTPTransport is sending over.
  61. func (r *SCTPTransport) Transport() *DTLSTransport {
  62. r.lock.RLock()
  63. defer r.lock.RUnlock()
  64. return r.dtlsTransport
  65. }
  66. // GetCapabilities returns the SCTPCapabilities of the SCTPTransport.
  67. func (r *SCTPTransport) GetCapabilities() SCTPCapabilities {
  68. return SCTPCapabilities{
  69. MaxMessageSize: 0,
  70. }
  71. }
  72. // Start the SCTPTransport. Since both local and remote parties must mutually
  73. // create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish
  74. // a connection over SCTP.
  75. func (r *SCTPTransport) Start(SCTPCapabilities) error {
  76. if r.isStarted {
  77. return nil
  78. }
  79. r.isStarted = true
  80. dtlsTransport := r.Transport()
  81. if dtlsTransport == nil || dtlsTransport.conn == nil {
  82. return errSCTPTransportDTLS
  83. }
  84. sctpAssociation, err := sctp.Client(sctp.Config{
  85. NetConn: dtlsTransport.conn,
  86. MaxReceiveBufferSize: r.api.settingEngine.sctp.maxReceiveBufferSize,
  87. LoggerFactory: r.api.settingEngine.LoggerFactory,
  88. })
  89. if err != nil {
  90. return err
  91. }
  92. r.lock.Lock()
  93. r.sctpAssociation = sctpAssociation
  94. r.state = SCTPTransportStateConnected
  95. dataChannels := append([]*DataChannel{}, r.dataChannels...)
  96. r.lock.Unlock()
  97. var openedDCCount uint32
  98. for _, d := range dataChannels {
  99. if d.ReadyState() == DataChannelStateConnecting {
  100. err := d.open(r)
  101. if err != nil {
  102. r.log.Warnf("failed to open data channel: %s", err)
  103. continue
  104. }
  105. openedDCCount++
  106. }
  107. }
  108. r.lock.Lock()
  109. r.dataChannelsOpened += openedDCCount
  110. r.lock.Unlock()
  111. go r.acceptDataChannels(sctpAssociation)
  112. return nil
  113. }
  114. // Stop stops the SCTPTransport
  115. func (r *SCTPTransport) Stop() error {
  116. r.lock.Lock()
  117. defer r.lock.Unlock()
  118. if r.sctpAssociation == nil {
  119. return nil
  120. }
  121. err := r.sctpAssociation.Close()
  122. if err != nil {
  123. return err
  124. }
  125. r.sctpAssociation = nil
  126. r.state = SCTPTransportStateClosed
  127. return nil
  128. }
  129. func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
  130. r.lock.RLock()
  131. dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels))
  132. for _, dc := range r.dataChannels {
  133. dc.mu.Lock()
  134. isNil := dc.dataChannel == nil
  135. dc.mu.Unlock()
  136. if isNil {
  137. continue
  138. }
  139. dataChannels = append(dataChannels, dc.dataChannel)
  140. }
  141. r.lock.RUnlock()
  142. ACCEPT:
  143. for {
  144. dc, err := datachannel.Accept(a, &datachannel.Config{
  145. LoggerFactory: r.api.settingEngine.LoggerFactory,
  146. }, dataChannels...)
  147. if err != nil {
  148. if !errors.Is(err, io.EOF) {
  149. r.log.Errorf("Failed to accept data channel: %v", err)
  150. r.onError(err)
  151. }
  152. return
  153. }
  154. for _, ch := range dataChannels {
  155. if ch.StreamIdentifier() == dc.StreamIdentifier() {
  156. continue ACCEPT
  157. }
  158. }
  159. var (
  160. maxRetransmits *uint16
  161. maxPacketLifeTime *uint16
  162. )
  163. val := uint16(dc.Config.ReliabilityParameter)
  164. ordered := true
  165. switch dc.Config.ChannelType {
  166. case datachannel.ChannelTypeReliable:
  167. ordered = true
  168. case datachannel.ChannelTypeReliableUnordered:
  169. ordered = false
  170. case datachannel.ChannelTypePartialReliableRexmit:
  171. ordered = true
  172. maxRetransmits = &val
  173. case datachannel.ChannelTypePartialReliableRexmitUnordered:
  174. ordered = false
  175. maxRetransmits = &val
  176. case datachannel.ChannelTypePartialReliableTimed:
  177. ordered = true
  178. maxPacketLifeTime = &val
  179. case datachannel.ChannelTypePartialReliableTimedUnordered:
  180. ordered = false
  181. maxPacketLifeTime = &val
  182. default:
  183. }
  184. sid := dc.StreamIdentifier()
  185. rtcDC, err := r.api.newDataChannel(&DataChannelParameters{
  186. ID: &sid,
  187. Label: dc.Config.Label,
  188. Protocol: dc.Config.Protocol,
  189. Negotiated: dc.Config.Negotiated,
  190. Ordered: ordered,
  191. MaxPacketLifeTime: maxPacketLifeTime,
  192. MaxRetransmits: maxRetransmits,
  193. }, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
  194. if err != nil {
  195. r.log.Errorf("Failed to accept data channel: %v", err)
  196. r.onError(err)
  197. return
  198. }
  199. <-r.onDataChannel(rtcDC)
  200. rtcDC.handleOpen(dc, true, dc.Config.Negotiated)
  201. r.lock.Lock()
  202. r.dataChannelsOpened++
  203. handler := r.onDataChannelOpenedHandler
  204. r.lock.Unlock()
  205. if handler != nil {
  206. handler(rtcDC)
  207. }
  208. }
  209. }
  210. // OnError sets an event handler which is invoked when
  211. // the SCTP connection error occurs.
  212. func (r *SCTPTransport) OnError(f func(err error)) {
  213. r.lock.Lock()
  214. defer r.lock.Unlock()
  215. r.onErrorHandler = f
  216. }
  217. func (r *SCTPTransport) onError(err error) {
  218. r.lock.RLock()
  219. handler := r.onErrorHandler
  220. r.lock.RUnlock()
  221. if handler != nil {
  222. go handler(err)
  223. }
  224. }
  225. // OnDataChannel sets an event handler which is invoked when a data
  226. // channel message arrives from a remote peer.
  227. func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {
  228. r.lock.Lock()
  229. defer r.lock.Unlock()
  230. r.onDataChannelHandler = f
  231. }
  232. // OnDataChannelOpened sets an event handler which is invoked when a data
  233. // channel is opened
  234. func (r *SCTPTransport) OnDataChannelOpened(f func(*DataChannel)) {
  235. r.lock.Lock()
  236. defer r.lock.Unlock()
  237. r.onDataChannelOpenedHandler = f
  238. }
  239. func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
  240. r.lock.Lock()
  241. r.dataChannels = append(r.dataChannels, dc)
  242. r.dataChannelsAccepted++
  243. handler := r.onDataChannelHandler
  244. r.lock.Unlock()
  245. done = make(chan struct{})
  246. if handler == nil || dc == nil {
  247. close(done)
  248. return
  249. }
  250. // Run this synchronously to allow setup done in onDataChannelFn()
  251. // to complete before datachannel event handlers might be called.
  252. go func() {
  253. handler(dc)
  254. close(done)
  255. }()
  256. return
  257. }
  258. func (r *SCTPTransport) updateMessageSize() {
  259. r.lock.Lock()
  260. defer r.lock.Unlock()
  261. var remoteMaxMessageSize float64 = 65536 // pion/webrtc#758
  262. var canSendSize float64 = 65536 // pion/webrtc#758
  263. r.maxMessageSize = r.calcMessageSize(remoteMaxMessageSize, canSendSize)
  264. }
  265. func (r *SCTPTransport) calcMessageSize(remoteMaxMessageSize, canSendSize float64) float64 {
  266. switch {
  267. case remoteMaxMessageSize == 0 &&
  268. canSendSize == 0:
  269. return math.Inf(1)
  270. case remoteMaxMessageSize == 0:
  271. return canSendSize
  272. case canSendSize == 0:
  273. return remoteMaxMessageSize
  274. case canSendSize > remoteMaxMessageSize:
  275. return remoteMaxMessageSize
  276. default:
  277. return canSendSize
  278. }
  279. }
  280. func (r *SCTPTransport) updateMaxChannels() {
  281. val := sctpMaxChannels
  282. r.maxChannels = &val
  283. }
  284. // MaxChannels is the maximum number of RTCDataChannels that can be open simultaneously.
  285. func (r *SCTPTransport) MaxChannels() uint16 {
  286. r.lock.Lock()
  287. defer r.lock.Unlock()
  288. if r.maxChannels == nil {
  289. return sctpMaxChannels
  290. }
  291. return *r.maxChannels
  292. }
  293. // State returns the current state of the SCTPTransport
  294. func (r *SCTPTransport) State() SCTPTransportState {
  295. r.lock.RLock()
  296. defer r.lock.RUnlock()
  297. return r.state
  298. }
  299. func (r *SCTPTransport) collectStats(collector *statsReportCollector) {
  300. collector.Collecting()
  301. stats := SCTPTransportStats{
  302. Timestamp: statsTimestampFrom(time.Now()),
  303. Type: StatsTypeSCTPTransport,
  304. ID: "sctpTransport",
  305. }
  306. association := r.association()
  307. if association != nil {
  308. stats.BytesSent = association.BytesSent()
  309. stats.BytesReceived = association.BytesReceived()
  310. stats.SmoothedRoundTripTime = association.SRTT() * 0.001 // convert milliseconds to seconds
  311. stats.CongestionWindow = association.CWND()
  312. stats.ReceiverWindow = association.RWND()
  313. stats.MTU = association.MTU()
  314. }
  315. collector.Collect(stats.ID, stats)
  316. }
  317. func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **uint16) error {
  318. var id uint16
  319. if dtlsRole != DTLSRoleClient {
  320. id++
  321. }
  322. max := r.MaxChannels()
  323. r.lock.Lock()
  324. defer r.lock.Unlock()
  325. // Create map of ids so we can compare without double-looping each time.
  326. idsMap := make(map[uint16]struct{}, len(r.dataChannels))
  327. for _, dc := range r.dataChannels {
  328. if dc.ID() == nil {
  329. continue
  330. }
  331. idsMap[*dc.ID()] = struct{}{}
  332. }
  333. for ; id < max-1; id += 2 {
  334. if _, ok := idsMap[id]; ok {
  335. continue
  336. }
  337. *idOut = &id
  338. return nil
  339. }
  340. return &rtcerr.OperationError{Err: ErrMaxDataChannelID}
  341. }
  342. func (r *SCTPTransport) association() *sctp.Association {
  343. if r == nil {
  344. return nil
  345. }
  346. r.lock.RLock()
  347. association := r.sctpAssociation
  348. r.lock.RUnlock()
  349. return association
  350. }