stream.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package sctp
  4. import (
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math"
  9. "os"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "github.com/pion/logging"
  14. )
  // Reliability types accepted by SetReliabilityParams.
  const (
  	// ReliabilityTypeReliable selects fully reliable transmission.
  	ReliabilityTypeReliable byte = iota
  	// ReliabilityTypeRexmit selects partial reliability bounded by a
  	// retransmission count.
  	ReliabilityTypeRexmit
  	// ReliabilityTypeTimed selects partial reliability bounded by a
  	// retransmission duration.
  	ReliabilityTypeTimed
  )
  // StreamState enumerates the lifecycle states of an SCTP stream.
  type StreamState int

  // Values a StreamState can take.
  const (
  	// StreamStateOpen is the state every Stream object starts in.
  	StreamStateOpen StreamState = 0
  	// StreamStateClosing means the outgoing stream is being reset.
  	StreamStateClosing StreamState = 1
  	// StreamStateClosed means the stream has been closed.
  	StreamStateClosed StreamState = 2
  )
  32. func (ss StreamState) String() string {
  33. switch ss {
  34. case StreamStateOpen:
  35. return "open"
  36. case StreamStateClosing:
  37. return "closing"
  38. case StreamStateClosed:
  39. return "closed"
  40. }
  41. return "unknown"
  42. }
  // Errors reported by Stream operations.
  var (
  	// ErrOutboundPacketTooLarge is returned (wrapped) by WriteSCTP when the
  	// payload is larger than the association's maximum message size.
  	ErrOutboundPacketTooLarge = errors.New("outbound packet larger than maximum message size")

  	// ErrStreamClosed is returned by WriteSCTP when the stream is not open
  	// or when handing the data to the association fails.
  	ErrStreamClosed = errors.New("stream closed")

  	// ErrReadDeadlineExceeded is the read error set once a deadline armed
  	// with SetReadDeadline expires. It wraps os.ErrDeadlineExceeded.
  	ErrReadDeadlineExceeded = fmt.Errorf("read deadline exceeded: %w", os.ErrDeadlineExceeded)
  )
  49. // Stream represents an SCTP stream
  50. type Stream struct {
  51. association *Association
  52. lock sync.RWMutex
  53. streamIdentifier uint16
  54. defaultPayloadType PayloadProtocolIdentifier
  55. reassemblyQueue *reassemblyQueue
  56. sequenceNumber uint16
  57. readNotifier *sync.Cond
  58. readErr error
  59. readTimeoutCancel chan struct{}
  60. unordered bool
  61. reliabilityType byte
  62. reliabilityValue uint32
  63. bufferedAmount uint64
  64. bufferedAmountLow uint64
  65. onBufferedAmountLow func()
  66. state StreamState
  67. log logging.LeveledLogger
  68. name string
  69. }
  70. // StreamIdentifier returns the Stream identifier associated to the stream.
  71. func (s *Stream) StreamIdentifier() uint16 {
  72. s.lock.RLock()
  73. defer s.lock.RUnlock()
  74. return s.streamIdentifier
  75. }
  76. // SetDefaultPayloadType sets the default payload type used by Write.
  77. func (s *Stream) SetDefaultPayloadType(defaultPayloadType PayloadProtocolIdentifier) {
  78. atomic.StoreUint32((*uint32)(&s.defaultPayloadType), uint32(defaultPayloadType))
  79. }
  80. // SetReliabilityParams sets reliability parameters for this stream.
  81. func (s *Stream) SetReliabilityParams(unordered bool, relType byte, relVal uint32) {
  82. s.lock.Lock()
  83. defer s.lock.Unlock()
  84. s.setReliabilityParams(unordered, relType, relVal)
  85. }
  86. // setReliabilityParams sets reliability parameters for this stream.
  87. // The caller should hold the lock.
  88. func (s *Stream) setReliabilityParams(unordered bool, relType byte, relVal uint32) {
  89. s.log.Debugf("[%s] reliability params: ordered=%v type=%d value=%d",
  90. s.name, !unordered, relType, relVal)
  91. s.unordered = unordered
  92. s.reliabilityType = relType
  93. s.reliabilityValue = relVal
  94. }
  95. // Read reads a packet of len(p) bytes, dropping the Payload Protocol Identifier.
  96. // Returns EOF when the stream is reset or an error if the stream is closed
  97. // otherwise.
  98. func (s *Stream) Read(p []byte) (int, error) {
  99. n, _, err := s.ReadSCTP(p)
  100. return n, err
  101. }
  102. // ReadSCTP reads a packet of len(p) bytes and returns the associated Payload
  103. // Protocol Identifier.
  104. // Returns EOF when the stream is reset or an error if the stream is closed
  105. // otherwise.
  106. func (s *Stream) ReadSCTP(p []byte) (int, PayloadProtocolIdentifier, error) {
  107. s.lock.Lock()
  108. defer s.lock.Unlock()
  109. defer func() {
  110. // close readTimeoutCancel if the current read timeout routine is no longer effective
  111. if s.readTimeoutCancel != nil && s.readErr != nil {
  112. close(s.readTimeoutCancel)
  113. s.readTimeoutCancel = nil
  114. }
  115. }()
  116. for {
  117. n, ppi, err := s.reassemblyQueue.read(p)
  118. if err == nil {
  119. return n, ppi, nil
  120. } else if errors.Is(err, io.ErrShortBuffer) {
  121. return 0, PayloadProtocolIdentifier(0), err
  122. }
  123. err = s.readErr
  124. if err != nil {
  125. return 0, PayloadProtocolIdentifier(0), err
  126. }
  127. s.readNotifier.Wait()
  128. }
  129. }
  130. // SetReadDeadline sets the read deadline in an identical way to net.Conn
  131. func (s *Stream) SetReadDeadline(deadline time.Time) error {
  132. s.lock.Lock()
  133. defer s.lock.Unlock()
  134. if s.readTimeoutCancel != nil {
  135. close(s.readTimeoutCancel)
  136. s.readTimeoutCancel = nil
  137. }
  138. if s.readErr != nil {
  139. if !errors.Is(s.readErr, ErrReadDeadlineExceeded) {
  140. return nil
  141. }
  142. s.readErr = nil
  143. }
  144. if !deadline.IsZero() {
  145. s.readTimeoutCancel = make(chan struct{})
  146. go func(readTimeoutCancel chan struct{}) {
  147. t := time.NewTimer(time.Until(deadline))
  148. select {
  149. case <-readTimeoutCancel:
  150. t.Stop()
  151. return
  152. case <-t.C:
  153. s.lock.Lock()
  154. if s.readErr == nil {
  155. s.readErr = ErrReadDeadlineExceeded
  156. }
  157. s.readTimeoutCancel = nil
  158. s.lock.Unlock()
  159. s.readNotifier.Signal()
  160. }
  161. }(s.readTimeoutCancel)
  162. }
  163. return nil
  164. }
  165. func (s *Stream) handleData(pd *chunkPayloadData) {
  166. s.lock.Lock()
  167. defer s.lock.Unlock()
  168. var readable bool
  169. if s.reassemblyQueue.push(pd) {
  170. readable = s.reassemblyQueue.isReadable()
  171. s.log.Debugf("[%s] reassemblyQueue readable=%v", s.name, readable)
  172. if readable {
  173. s.log.Debugf("[%s] readNotifier.signal()", s.name)
  174. s.readNotifier.Signal()
  175. s.log.Debugf("[%s] readNotifier.signal() done", s.name)
  176. }
  177. }
  178. }
  179. func (s *Stream) handleForwardTSNForOrdered(ssn uint16) {
  180. var readable bool
  181. func() {
  182. s.lock.Lock()
  183. defer s.lock.Unlock()
  184. if s.unordered {
  185. return // unordered chunks are handled by handleForwardUnordered method
  186. }
  187. // Remove all chunks older than or equal to the new TSN from
  188. // the reassemblyQueue.
  189. s.reassemblyQueue.forwardTSNForOrdered(ssn)
  190. readable = s.reassemblyQueue.isReadable()
  191. }()
  192. // Notify the reader asynchronously if there's a data chunk to read.
  193. if readable {
  194. s.readNotifier.Signal()
  195. }
  196. }
  197. func (s *Stream) handleForwardTSNForUnordered(newCumulativeTSN uint32) {
  198. var readable bool
  199. func() {
  200. s.lock.Lock()
  201. defer s.lock.Unlock()
  202. if !s.unordered {
  203. return // ordered chunks are handled by handleForwardTSNOrdered method
  204. }
  205. // Remove all chunks older than or equal to the new TSN from
  206. // the reassemblyQueue.
  207. s.reassemblyQueue.forwardTSNForUnordered(newCumulativeTSN)
  208. readable = s.reassemblyQueue.isReadable()
  209. }()
  210. // Notify the reader asynchronously if there's a data chunk to read.
  211. if readable {
  212. s.readNotifier.Signal()
  213. }
  214. }
  215. // Write writes len(p) bytes from p with the default Payload Protocol Identifier
  216. func (s *Stream) Write(p []byte) (n int, err error) {
  217. ppi := PayloadProtocolIdentifier(atomic.LoadUint32((*uint32)(&s.defaultPayloadType)))
  218. return s.WriteSCTP(p, ppi)
  219. }
  220. // WriteSCTP writes len(p) bytes from p to the DTLS connection
  221. func (s *Stream) WriteSCTP(p []byte, ppi PayloadProtocolIdentifier) (int, error) {
  222. maxMessageSize := s.association.MaxMessageSize()
  223. if len(p) > int(maxMessageSize) {
  224. return 0, fmt.Errorf("%w: %v", ErrOutboundPacketTooLarge, math.MaxUint16)
  225. }
  226. if s.State() != StreamStateOpen {
  227. return 0, ErrStreamClosed
  228. }
  229. chunks := s.packetize(p, ppi)
  230. n := len(p)
  231. err := s.association.sendPayloadData(chunks)
  232. if err != nil {
  233. return n, ErrStreamClosed
  234. }
  235. return n, nil
  236. }
  237. func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPayloadData {
  238. s.lock.Lock()
  239. defer s.lock.Unlock()
  240. i := uint32(0)
  241. remaining := uint32(len(raw))
  242. // From draft-ietf-rtcweb-data-protocol-09, section 6:
  243. // All Data Channel Establishment Protocol messages MUST be sent using
  244. // ordered delivery and reliable transmission.
  245. unordered := ppi != PayloadTypeWebRTCDCEP && s.unordered
  246. var chunks []*chunkPayloadData
  247. var head *chunkPayloadData
  248. for remaining != 0 {
  249. fragmentSize := min32(s.association.maxPayloadSize, remaining)
  250. // Copy the userdata since we'll have to store it until acked
  251. // and the caller may re-use the buffer in the mean time
  252. userData := make([]byte, fragmentSize)
  253. copy(userData, raw[i:i+fragmentSize])
  254. chunk := &chunkPayloadData{
  255. streamIdentifier: s.streamIdentifier,
  256. userData: userData,
  257. unordered: unordered,
  258. beginningFragment: i == 0,
  259. endingFragment: remaining-fragmentSize == 0,
  260. immediateSack: false,
  261. payloadType: ppi,
  262. streamSequenceNumber: s.sequenceNumber,
  263. head: head,
  264. }
  265. if head == nil {
  266. head = chunk
  267. }
  268. chunks = append(chunks, chunk)
  269. remaining -= fragmentSize
  270. i += fragmentSize
  271. }
  272. // RFC 4960 Sec 6.6
  273. // Note: When transmitting ordered and unordered data, an endpoint does
  274. // not increment its Stream Sequence Number when transmitting a DATA
  275. // chunk with U flag set to 1.
  276. if !unordered {
  277. s.sequenceNumber++
  278. }
  279. s.bufferedAmount += uint64(len(raw))
  280. s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)
  281. return chunks
  282. }
  283. // Close closes the write-direction of the stream.
  284. // Future calls to Write are not permitted after calling Close.
  285. func (s *Stream) Close() error {
  286. if sid, resetOutbound := func() (uint16, bool) {
  287. s.lock.Lock()
  288. defer s.lock.Unlock()
  289. s.log.Debugf("[%s] Close: state=%s", s.name, s.state.String())
  290. if s.state == StreamStateOpen {
  291. if s.readErr == nil {
  292. s.state = StreamStateClosing
  293. } else {
  294. s.state = StreamStateClosed
  295. }
  296. s.log.Debugf("[%s] state change: open => %s", s.name, s.state.String())
  297. return s.streamIdentifier, true
  298. }
  299. return s.streamIdentifier, false
  300. }(); resetOutbound {
  301. // Reset the outgoing stream
  302. // https://tools.ietf.org/html/rfc6525
  303. return s.association.sendResetRequest(sid)
  304. }
  305. return nil
  306. }
  307. // BufferedAmount returns the number of bytes of data currently queued to be sent over this stream.
  308. func (s *Stream) BufferedAmount() uint64 {
  309. s.lock.RLock()
  310. defer s.lock.RUnlock()
  311. return s.bufferedAmount
  312. }
  313. // BufferedAmountLowThreshold returns the number of bytes of buffered outgoing data that is
  314. // considered "low." Defaults to 0.
  315. func (s *Stream) BufferedAmountLowThreshold() uint64 {
  316. s.lock.RLock()
  317. defer s.lock.RUnlock()
  318. return s.bufferedAmountLow
  319. }
  320. // SetBufferedAmountLowThreshold is used to update the threshold.
  321. // See BufferedAmountLowThreshold().
  322. func (s *Stream) SetBufferedAmountLowThreshold(th uint64) {
  323. s.lock.Lock()
  324. defer s.lock.Unlock()
  325. s.bufferedAmountLow = th
  326. }
  327. // OnBufferedAmountLow sets the callback handler which would be called when the number of
  328. // bytes of outgoing data buffered is lower than the threshold.
  329. func (s *Stream) OnBufferedAmountLow(f func()) {
  330. s.lock.Lock()
  331. defer s.lock.Unlock()
  332. s.onBufferedAmountLow = f
  333. }
  334. // This method is called by association's readLoop (go-)routine to notify this stream
  335. // of the specified amount of outgoing data has been delivered to the peer.
  336. func (s *Stream) onBufferReleased(nBytesReleased int) {
  337. if nBytesReleased <= 0 {
  338. return
  339. }
  340. s.lock.Lock()
  341. fromAmount := s.bufferedAmount
  342. if s.bufferedAmount < uint64(nBytesReleased) {
  343. s.bufferedAmount = 0
  344. s.log.Errorf("[%s] released buffer size %d should be <= %d",
  345. s.name, nBytesReleased, s.bufferedAmount)
  346. } else {
  347. s.bufferedAmount -= uint64(nBytesReleased)
  348. }
  349. s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)
  350. if s.onBufferedAmountLow != nil && fromAmount > s.bufferedAmountLow && s.bufferedAmount <= s.bufferedAmountLow {
  351. f := s.onBufferedAmountLow
  352. s.lock.Unlock()
  353. f()
  354. return
  355. }
  356. s.lock.Unlock()
  357. }
  358. func (s *Stream) getNumBytesInReassemblyQueue() int {
  359. // No lock is required as it reads the size with atomic load function.
  360. return s.reassemblyQueue.getNumBytes()
  361. }
  362. func (s *Stream) onInboundStreamReset() {
  363. s.lock.Lock()
  364. defer s.lock.Unlock()
  365. s.log.Debugf("[%s] onInboundStreamReset: state=%s", s.name, s.state.String())
  366. // No more inbound data to read. Unblock the read with io.EOF.
  367. // This should cause DCEP layer (datachannel package) to call Close() which
  368. // will reset outgoing stream also.
  369. // See RFC 8831 section 6.7:
  370. // if one side decides to close the data channel, it resets the corresponding
  371. // outgoing stream. When the peer sees that an incoming stream was
  372. // reset, it also resets its corresponding outgoing stream. Once this
  373. // is completed, the data channel is closed.
  374. s.readErr = io.EOF
  375. s.readNotifier.Broadcast()
  376. if s.state == StreamStateClosing {
  377. s.log.Debugf("[%s] state change: closing => closed", s.name)
  378. s.state = StreamStateClosed
  379. }
  380. }
  381. // State return the stream state.
  382. func (s *Stream) State() StreamState {
  383. s.lock.RLock()
  384. defer s.lock.RUnlock()
  385. return s.state
  386. }