sctptransport.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. //go:build !js
  4. // +build !js
  5. package webrtc
  6. import (
  7. "errors"
  8. "io"
  9. "math"
  10. "sync"
  11. "time"
  12. "github.com/pion/datachannel"
  13. "github.com/pion/logging"
  14. "github.com/pion/sctp"
  15. "github.com/pion/webrtc/v3/pkg/rtcerr"
  16. )
  17. const sctpMaxChannels = uint16(65535)
  18. // SCTPTransport provides details about the SCTP transport.
  19. type SCTPTransport struct {
  20. lock sync.RWMutex
  21. dtlsTransport *DTLSTransport
  22. // State represents the current state of the SCTP transport.
  23. state SCTPTransportState
  24. // SCTPTransportState doesn't have an enum to distinguish between New/Connecting
  25. // so we need a dedicated field
  26. isStarted bool
  27. // MaxMessageSize represents the maximum size of data that can be passed to
  28. // DataChannel's send() method.
  29. maxMessageSize float64
  30. // MaxChannels represents the maximum amount of DataChannel's that can
  31. // be used simultaneously.
  32. maxChannels *uint16
  33. // OnStateChange func()
  34. onErrorHandler func(error)
  35. sctpAssociation *sctp.Association
  36. onDataChannelHandler func(*DataChannel)
  37. onDataChannelOpenedHandler func(*DataChannel)
  38. // DataChannels
  39. dataChannels []*DataChannel
  40. dataChannelsOpened uint32
  41. dataChannelsRequested uint32
  42. dataChannelsAccepted uint32
  43. api *API
  44. log logging.LeveledLogger
  45. }
  46. // NewSCTPTransport creates a new SCTPTransport.
  47. // This constructor is part of the ORTC API. It is not
  48. // meant to be used together with the basic WebRTC API.
  49. func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
  50. res := &SCTPTransport{
  51. dtlsTransport: dtls,
  52. state: SCTPTransportStateConnecting,
  53. api: api,
  54. log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
  55. }
  56. res.updateMessageSize()
  57. res.updateMaxChannels()
  58. return res
  59. }
  60. // Transport returns the DTLSTransport instance the SCTPTransport is sending over.
  61. func (r *SCTPTransport) Transport() *DTLSTransport {
  62. r.lock.RLock()
  63. defer r.lock.RUnlock()
  64. return r.dtlsTransport
  65. }
  66. // GetCapabilities returns the SCTPCapabilities of the SCTPTransport.
  67. func (r *SCTPTransport) GetCapabilities() SCTPCapabilities {
  68. return SCTPCapabilities{
  69. MaxMessageSize: 0,
  70. }
  71. }
  72. // Start the SCTPTransport. Since both local and remote parties must mutually
  73. // create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish
  74. // a connection over SCTP.
  75. func (r *SCTPTransport) Start(SCTPCapabilities) error {
  76. if r.isStarted {
  77. return nil
  78. }
  79. r.isStarted = true
  80. dtlsTransport := r.Transport()
  81. if dtlsTransport == nil || dtlsTransport.conn == nil {
  82. return errSCTPTransportDTLS
  83. }
  84. sctpAssociation, err := sctp.Client(sctp.Config{
  85. NetConn: dtlsTransport.conn,
  86. MaxReceiveBufferSize: r.api.settingEngine.sctp.maxReceiveBufferSize,
  87. EnableZeroChecksum: r.api.settingEngine.sctp.enableZeroChecksum,
  88. LoggerFactory: r.api.settingEngine.LoggerFactory,
  89. })
  90. if err != nil {
  91. return err
  92. }
  93. r.lock.Lock()
  94. r.sctpAssociation = sctpAssociation
  95. r.state = SCTPTransportStateConnected
  96. dataChannels := append([]*DataChannel{}, r.dataChannels...)
  97. r.lock.Unlock()
  98. var openedDCCount uint32
  99. for _, d := range dataChannels {
  100. if d.ReadyState() == DataChannelStateConnecting {
  101. err := d.open(r)
  102. if err != nil {
  103. r.log.Warnf("failed to open data channel: %s", err)
  104. continue
  105. }
  106. openedDCCount++
  107. }
  108. }
  109. r.lock.Lock()
  110. r.dataChannelsOpened += openedDCCount
  111. r.lock.Unlock()
  112. go r.acceptDataChannels(sctpAssociation)
  113. return nil
  114. }
  115. // Stop stops the SCTPTransport
  116. func (r *SCTPTransport) Stop() error {
  117. r.lock.Lock()
  118. defer r.lock.Unlock()
  119. if r.sctpAssociation == nil {
  120. return nil
  121. }
  122. err := r.sctpAssociation.Close()
  123. if err != nil {
  124. return err
  125. }
  126. r.sctpAssociation = nil
  127. r.state = SCTPTransportStateClosed
  128. return nil
  129. }
  130. func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
  131. r.lock.RLock()
  132. dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels))
  133. for _, dc := range r.dataChannels {
  134. dc.mu.Lock()
  135. isNil := dc.dataChannel == nil
  136. dc.mu.Unlock()
  137. if isNil {
  138. continue
  139. }
  140. dataChannels = append(dataChannels, dc.dataChannel)
  141. }
  142. r.lock.RUnlock()
  143. ACCEPT:
  144. for {
  145. dc, err := datachannel.Accept(a, &datachannel.Config{
  146. LoggerFactory: r.api.settingEngine.LoggerFactory,
  147. }, dataChannels...)
  148. if err != nil {
  149. if !errors.Is(err, io.EOF) {
  150. r.log.Errorf("Failed to accept data channel: %v", err)
  151. r.onError(err)
  152. }
  153. return
  154. }
  155. for _, ch := range dataChannels {
  156. if ch.StreamIdentifier() == dc.StreamIdentifier() {
  157. continue ACCEPT
  158. }
  159. }
  160. var (
  161. maxRetransmits *uint16
  162. maxPacketLifeTime *uint16
  163. )
  164. val := uint16(dc.Config.ReliabilityParameter)
  165. ordered := true
  166. switch dc.Config.ChannelType {
  167. case datachannel.ChannelTypeReliable:
  168. ordered = true
  169. case datachannel.ChannelTypeReliableUnordered:
  170. ordered = false
  171. case datachannel.ChannelTypePartialReliableRexmit:
  172. ordered = true
  173. maxRetransmits = &val
  174. case datachannel.ChannelTypePartialReliableRexmitUnordered:
  175. ordered = false
  176. maxRetransmits = &val
  177. case datachannel.ChannelTypePartialReliableTimed:
  178. ordered = true
  179. maxPacketLifeTime = &val
  180. case datachannel.ChannelTypePartialReliableTimedUnordered:
  181. ordered = false
  182. maxPacketLifeTime = &val
  183. default:
  184. }
  185. sid := dc.StreamIdentifier()
  186. rtcDC, err := r.api.newDataChannel(&DataChannelParameters{
  187. ID: &sid,
  188. Label: dc.Config.Label,
  189. Protocol: dc.Config.Protocol,
  190. Negotiated: dc.Config.Negotiated,
  191. Ordered: ordered,
  192. MaxPacketLifeTime: maxPacketLifeTime,
  193. MaxRetransmits: maxRetransmits,
  194. }, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
  195. if err != nil {
  196. r.log.Errorf("Failed to accept data channel: %v", err)
  197. r.onError(err)
  198. return
  199. }
  200. <-r.onDataChannel(rtcDC)
  201. rtcDC.handleOpen(dc, true, dc.Config.Negotiated)
  202. r.lock.Lock()
  203. r.dataChannelsOpened++
  204. handler := r.onDataChannelOpenedHandler
  205. r.lock.Unlock()
  206. if handler != nil {
  207. handler(rtcDC)
  208. }
  209. }
  210. }
  211. // OnError sets an event handler which is invoked when
  212. // the SCTP connection error occurs.
  213. func (r *SCTPTransport) OnError(f func(err error)) {
  214. r.lock.Lock()
  215. defer r.lock.Unlock()
  216. r.onErrorHandler = f
  217. }
  218. func (r *SCTPTransport) onError(err error) {
  219. r.lock.RLock()
  220. handler := r.onErrorHandler
  221. r.lock.RUnlock()
  222. if handler != nil {
  223. go handler(err)
  224. }
  225. }
  226. // OnDataChannel sets an event handler which is invoked when a data
  227. // channel message arrives from a remote peer.
  228. func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {
  229. r.lock.Lock()
  230. defer r.lock.Unlock()
  231. r.onDataChannelHandler = f
  232. }
  233. // OnDataChannelOpened sets an event handler which is invoked when a data
  234. // channel is opened
  235. func (r *SCTPTransport) OnDataChannelOpened(f func(*DataChannel)) {
  236. r.lock.Lock()
  237. defer r.lock.Unlock()
  238. r.onDataChannelOpenedHandler = f
  239. }
  240. func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
  241. r.lock.Lock()
  242. r.dataChannels = append(r.dataChannels, dc)
  243. r.dataChannelsAccepted++
  244. handler := r.onDataChannelHandler
  245. r.lock.Unlock()
  246. done = make(chan struct{})
  247. if handler == nil || dc == nil {
  248. close(done)
  249. return
  250. }
  251. // Run this synchronously to allow setup done in onDataChannelFn()
  252. // to complete before datachannel event handlers might be called.
  253. go func() {
  254. handler(dc)
  255. close(done)
  256. }()
  257. return
  258. }
  259. func (r *SCTPTransport) updateMessageSize() {
  260. r.lock.Lock()
  261. defer r.lock.Unlock()
  262. var remoteMaxMessageSize float64 = 65536 // pion/webrtc#758
  263. var canSendSize float64 = 65536 // pion/webrtc#758
  264. r.maxMessageSize = r.calcMessageSize(remoteMaxMessageSize, canSendSize)
  265. }
  266. func (r *SCTPTransport) calcMessageSize(remoteMaxMessageSize, canSendSize float64) float64 {
  267. switch {
  268. case remoteMaxMessageSize == 0 &&
  269. canSendSize == 0:
  270. return math.Inf(1)
  271. case remoteMaxMessageSize == 0:
  272. return canSendSize
  273. case canSendSize == 0:
  274. return remoteMaxMessageSize
  275. case canSendSize > remoteMaxMessageSize:
  276. return remoteMaxMessageSize
  277. default:
  278. return canSendSize
  279. }
  280. }
  281. func (r *SCTPTransport) updateMaxChannels() {
  282. val := sctpMaxChannels
  283. r.maxChannels = &val
  284. }
  285. // MaxChannels is the maximum number of RTCDataChannels that can be open simultaneously.
  286. func (r *SCTPTransport) MaxChannels() uint16 {
  287. r.lock.Lock()
  288. defer r.lock.Unlock()
  289. if r.maxChannels == nil {
  290. return sctpMaxChannels
  291. }
  292. return *r.maxChannels
  293. }
  294. // State returns the current state of the SCTPTransport
  295. func (r *SCTPTransport) State() SCTPTransportState {
  296. r.lock.RLock()
  297. defer r.lock.RUnlock()
  298. return r.state
  299. }
  300. func (r *SCTPTransport) collectStats(collector *statsReportCollector) {
  301. collector.Collecting()
  302. stats := SCTPTransportStats{
  303. Timestamp: statsTimestampFrom(time.Now()),
  304. Type: StatsTypeSCTPTransport,
  305. ID: "sctpTransport",
  306. }
  307. association := r.association()
  308. if association != nil {
  309. stats.BytesSent = association.BytesSent()
  310. stats.BytesReceived = association.BytesReceived()
  311. stats.SmoothedRoundTripTime = association.SRTT() * 0.001 // convert milliseconds to seconds
  312. stats.CongestionWindow = association.CWND()
  313. stats.ReceiverWindow = association.RWND()
  314. stats.MTU = association.MTU()
  315. }
  316. collector.Collect(stats.ID, stats)
  317. }
  318. func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **uint16) error {
  319. var id uint16
  320. if dtlsRole != DTLSRoleClient {
  321. id++
  322. }
  323. max := r.MaxChannels()
  324. r.lock.Lock()
  325. defer r.lock.Unlock()
  326. // Create map of ids so we can compare without double-looping each time.
  327. idsMap := make(map[uint16]struct{}, len(r.dataChannels))
  328. for _, dc := range r.dataChannels {
  329. if dc.ID() == nil {
  330. continue
  331. }
  332. idsMap[*dc.ID()] = struct{}{}
  333. }
  334. for ; id < max-1; id += 2 {
  335. if _, ok := idsMap[id]; ok {
  336. continue
  337. }
  338. *idOut = &id
  339. return nil
  340. }
  341. return &rtcerr.OperationError{Err: ErrMaxDataChannelID}
  342. }
  343. func (r *SCTPTransport) association() *sctp.Association {
  344. if r == nil {
  345. return nil
  346. }
  347. r.lock.RLock()
  348. association := r.sctpAssociation
  349. r.lock.RUnlock()
  350. return association
  351. }