stream.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package sctp
  4. import (
  5. "errors"
  6. "fmt"
  7. "io"
  8. "os"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. "github.com/pion/logging"
  13. )
  14. const (
  15. // ReliabilityTypeReliable is used for reliable transmission
  16. ReliabilityTypeReliable byte = 0
  17. // ReliabilityTypeRexmit is used for partial reliability by retransmission count
  18. ReliabilityTypeRexmit byte = 1
  19. // ReliabilityTypeTimed is used for partial reliability by retransmission duration
  20. ReliabilityTypeTimed byte = 2
  21. )
  22. // StreamState is an enum for SCTP Stream state field
  23. // This field identifies the state of stream.
  24. type StreamState int
  25. // StreamState enums
  26. const (
  27. StreamStateOpen StreamState = iota // Stream object starts with StreamStateOpen
  28. StreamStateClosing // Outgoing stream is being reset
  29. StreamStateClosed // Stream has been closed
  30. )
  31. func (ss StreamState) String() string {
  32. switch ss {
  33. case StreamStateOpen:
  34. return "open"
  35. case StreamStateClosing:
  36. return "closing"
  37. case StreamStateClosed:
  38. return "closed"
  39. }
  40. return "unknown"
  41. }
  42. // SCTP stream errors
  43. var (
  44. ErrOutboundPacketTooLarge = errors.New("outbound packet larger than maximum message size")
  45. ErrStreamClosed = errors.New("stream closed")
  46. ErrReadDeadlineExceeded = fmt.Errorf("read deadline exceeded: %w", os.ErrDeadlineExceeded)
  47. )
// Stream represents an SCTP stream
//
// Mutable fields are read and written with lock held. The exception is
// defaultPayloadType, which is accessed through sync/atomic.
type Stream struct {
	// association is the Association this stream sends through
	// (WriteSCTP, packetize, Close).
	association *Association
	// lock guards the stream's mutable state.
	lock sync.RWMutex
	// streamIdentifier is stamped on every outgoing DATA chunk by packetize.
	streamIdentifier uint16
	// defaultPayloadType is the PPI used by Write; loaded/stored atomically.
	defaultPayloadType PayloadProtocolIdentifier
	// reassemblyQueue holds inbound DATA chunks until ReadSCTP consumes them.
	reassemblyQueue *reassemblyQueue
	// sequenceNumber is the SSN given to the next message; packetize advances
	// it only for ordered messages.
	sequenceNumber uint16
	// readNotifier wakes readers blocked in ReadSCTP. ReadSCTP calls Wait with
	// lock held, so it presumably wraps &lock — TODO confirm at construction.
	readNotifier *sync.Cond
	// readErr, when non-nil, is returned by ReadSCTP once no data is readable
	// (io.EOF after an inbound reset, ErrReadDeadlineExceeded after a deadline).
	readErr error
	// readTimeoutCancel is closed to stop the pending deadline goroutine
	// started by SetReadDeadline; nil when no deadline timer is armed.
	readTimeoutCancel chan struct{}
	// unordered, reliabilityType and reliabilityValue are set by
	// SetReliabilityParams; packetize copies unordered into each chunk.
	unordered        bool
	reliabilityType  byte
	reliabilityValue uint32
	// bufferedAmount counts bytes accepted by packetize and not yet released
	// via onBufferReleased.
	bufferedAmount uint64
	// bufferedAmountLow is the threshold used to fire onBufferedAmountLow.
	bufferedAmountLow uint64
	// onBufferedAmountLow is called, outside lock, when bufferedAmount drops
	// from above bufferedAmountLow to at or below it.
	onBufferedAmountLow func()
	// state is the stream lifecycle state.
	state StreamState
	// log is the stream logger; name prefixes its messages.
	log  logging.LeveledLogger
	name string
}
  69. // StreamIdentifier returns the Stream identifier associated to the stream.
  70. func (s *Stream) StreamIdentifier() uint16 {
  71. s.lock.RLock()
  72. defer s.lock.RUnlock()
  73. return s.streamIdentifier
  74. }
  75. // SetDefaultPayloadType sets the default payload type used by Write.
  76. func (s *Stream) SetDefaultPayloadType(defaultPayloadType PayloadProtocolIdentifier) {
  77. atomic.StoreUint32((*uint32)(&s.defaultPayloadType), uint32(defaultPayloadType))
  78. }
  79. // SetReliabilityParams sets reliability parameters for this stream.
  80. func (s *Stream) SetReliabilityParams(unordered bool, relType byte, relVal uint32) {
  81. s.lock.Lock()
  82. defer s.lock.Unlock()
  83. s.setReliabilityParams(unordered, relType, relVal)
  84. }
  85. // setReliabilityParams sets reliability parameters for this stream.
  86. // The caller should hold the lock.
  87. func (s *Stream) setReliabilityParams(unordered bool, relType byte, relVal uint32) {
  88. s.log.Debugf("[%s] reliability params: ordered=%v type=%d value=%d",
  89. s.name, !unordered, relType, relVal)
  90. s.unordered = unordered
  91. s.reliabilityType = relType
  92. s.reliabilityValue = relVal
  93. }
  94. // Read reads a packet of len(p) bytes, dropping the Payload Protocol Identifier.
  95. // Returns EOF when the stream is reset or an error if the stream is closed
  96. // otherwise.
  97. func (s *Stream) Read(p []byte) (int, error) {
  98. n, _, err := s.ReadSCTP(p)
  99. return n, err
  100. }
  101. // ReadSCTP reads a packet of len(p) bytes and returns the associated Payload
  102. // Protocol Identifier.
  103. // Returns EOF when the stream is reset or an error if the stream is closed
  104. // otherwise.
  105. func (s *Stream) ReadSCTP(p []byte) (int, PayloadProtocolIdentifier, error) {
  106. s.lock.Lock()
  107. defer s.lock.Unlock()
  108. defer func() {
  109. // close readTimeoutCancel if the current read timeout routine is no longer effective
  110. if s.readTimeoutCancel != nil && s.readErr != nil {
  111. close(s.readTimeoutCancel)
  112. s.readTimeoutCancel = nil
  113. }
  114. }()
  115. for {
  116. n, ppi, err := s.reassemblyQueue.read(p)
  117. if err == nil {
  118. return n, ppi, nil
  119. } else if errors.Is(err, io.ErrShortBuffer) {
  120. return 0, PayloadProtocolIdentifier(0), err
  121. }
  122. err = s.readErr
  123. if err != nil {
  124. return 0, PayloadProtocolIdentifier(0), err
  125. }
  126. s.readNotifier.Wait()
  127. }
  128. }
  129. // SetReadDeadline sets the read deadline in an identical way to net.Conn
  130. func (s *Stream) SetReadDeadline(deadline time.Time) error {
  131. s.lock.Lock()
  132. defer s.lock.Unlock()
  133. if s.readTimeoutCancel != nil {
  134. close(s.readTimeoutCancel)
  135. s.readTimeoutCancel = nil
  136. }
  137. if s.readErr != nil {
  138. if !errors.Is(s.readErr, ErrReadDeadlineExceeded) {
  139. return nil
  140. }
  141. s.readErr = nil
  142. }
  143. if !deadline.IsZero() {
  144. s.readTimeoutCancel = make(chan struct{})
  145. go func(readTimeoutCancel chan struct{}) {
  146. t := time.NewTimer(time.Until(deadline))
  147. select {
  148. case <-readTimeoutCancel:
  149. t.Stop()
  150. return
  151. case <-t.C:
  152. select {
  153. case <-readTimeoutCancel:
  154. return
  155. default:
  156. }
  157. s.lock.Lock()
  158. if s.readErr == nil {
  159. s.readErr = ErrReadDeadlineExceeded
  160. }
  161. s.readTimeoutCancel = nil
  162. s.lock.Unlock()
  163. s.readNotifier.Signal()
  164. }
  165. }(s.readTimeoutCancel)
  166. }
  167. return nil
  168. }
  169. func (s *Stream) handleData(pd *chunkPayloadData) {
  170. s.lock.Lock()
  171. defer s.lock.Unlock()
  172. var readable bool
  173. if s.reassemblyQueue.push(pd) {
  174. readable = s.reassemblyQueue.isReadable()
  175. s.log.Debugf("[%s] reassemblyQueue readable=%v", s.name, readable)
  176. if readable {
  177. s.log.Debugf("[%s] readNotifier.signal()", s.name)
  178. s.readNotifier.Signal()
  179. s.log.Debugf("[%s] readNotifier.signal() done", s.name)
  180. }
  181. }
  182. }
  183. func (s *Stream) handleForwardTSNForOrdered(ssn uint16) {
  184. var readable bool
  185. func() {
  186. s.lock.Lock()
  187. defer s.lock.Unlock()
  188. if s.unordered {
  189. return // unordered chunks are handled by handleForwardUnordered method
  190. }
  191. // Remove all chunks older than or equal to the new TSN from
  192. // the reassemblyQueue.
  193. s.reassemblyQueue.forwardTSNForOrdered(ssn)
  194. readable = s.reassemblyQueue.isReadable()
  195. }()
  196. // Notify the reader asynchronously if there's a data chunk to read.
  197. if readable {
  198. s.readNotifier.Signal()
  199. }
  200. }
  201. func (s *Stream) handleForwardTSNForUnordered(newCumulativeTSN uint32) {
  202. var readable bool
  203. func() {
  204. s.lock.Lock()
  205. defer s.lock.Unlock()
  206. if !s.unordered {
  207. return // ordered chunks are handled by handleForwardTSNOrdered method
  208. }
  209. // Remove all chunks older than or equal to the new TSN from
  210. // the reassemblyQueue.
  211. s.reassemblyQueue.forwardTSNForUnordered(newCumulativeTSN)
  212. readable = s.reassemblyQueue.isReadable()
  213. }()
  214. // Notify the reader asynchronously if there's a data chunk to read.
  215. if readable {
  216. s.readNotifier.Signal()
  217. }
  218. }
  219. // Write writes len(p) bytes from p with the default Payload Protocol Identifier
  220. func (s *Stream) Write(p []byte) (n int, err error) {
  221. ppi := PayloadProtocolIdentifier(atomic.LoadUint32((*uint32)(&s.defaultPayloadType)))
  222. return s.WriteSCTP(p, ppi)
  223. }
  224. // WriteSCTP writes len(p) bytes from p to the DTLS connection
  225. func (s *Stream) WriteSCTP(p []byte, ppi PayloadProtocolIdentifier) (int, error) {
  226. maxMessageSize := s.association.MaxMessageSize()
  227. if len(p) > int(maxMessageSize) {
  228. return 0, fmt.Errorf("%w: %v", ErrOutboundPacketTooLarge, maxMessageSize)
  229. }
  230. if s.State() != StreamStateOpen {
  231. return 0, ErrStreamClosed
  232. }
  233. chunks := s.packetize(p, ppi)
  234. n := len(p)
  235. err := s.association.sendPayloadData(chunks)
  236. if err != nil {
  237. return n, ErrStreamClosed
  238. }
  239. return n, nil
  240. }
  241. func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPayloadData {
  242. s.lock.Lock()
  243. defer s.lock.Unlock()
  244. i := uint32(0)
  245. remaining := uint32(len(raw))
  246. // From draft-ietf-rtcweb-data-protocol-09, section 6:
  247. // All Data Channel Establishment Protocol messages MUST be sent using
  248. // ordered delivery and reliable transmission.
  249. unordered := ppi != PayloadTypeWebRTCDCEP && s.unordered
  250. var chunks []*chunkPayloadData
  251. var head *chunkPayloadData
  252. for remaining != 0 {
  253. fragmentSize := min32(s.association.maxPayloadSize, remaining)
  254. // Copy the userdata since we'll have to store it until acked
  255. // and the caller may re-use the buffer in the mean time
  256. userData := make([]byte, fragmentSize)
  257. copy(userData, raw[i:i+fragmentSize])
  258. chunk := &chunkPayloadData{
  259. streamIdentifier: s.streamIdentifier,
  260. userData: userData,
  261. unordered: unordered,
  262. beginningFragment: i == 0,
  263. endingFragment: remaining-fragmentSize == 0,
  264. immediateSack: false,
  265. payloadType: ppi,
  266. streamSequenceNumber: s.sequenceNumber,
  267. head: head,
  268. }
  269. if head == nil {
  270. head = chunk
  271. }
  272. chunks = append(chunks, chunk)
  273. remaining -= fragmentSize
  274. i += fragmentSize
  275. }
  276. // RFC 4960 Sec 6.6
  277. // Note: When transmitting ordered and unordered data, an endpoint does
  278. // not increment its Stream Sequence Number when transmitting a DATA
  279. // chunk with U flag set to 1.
  280. if !unordered {
  281. s.sequenceNumber++
  282. }
  283. s.bufferedAmount += uint64(len(raw))
  284. s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)
  285. return chunks
  286. }
  287. // Close closes the write-direction of the stream.
  288. // Future calls to Write are not permitted after calling Close.
  289. func (s *Stream) Close() error {
  290. if sid, resetOutbound := func() (uint16, bool) {
  291. s.lock.Lock()
  292. defer s.lock.Unlock()
  293. s.log.Debugf("[%s] Close: state=%s", s.name, s.state.String())
  294. if s.state == StreamStateOpen {
  295. if s.readErr == nil {
  296. s.state = StreamStateClosing
  297. } else {
  298. s.state = StreamStateClosed
  299. }
  300. s.log.Debugf("[%s] state change: open => %s", s.name, s.state.String())
  301. return s.streamIdentifier, true
  302. }
  303. return s.streamIdentifier, false
  304. }(); resetOutbound {
  305. // Reset the outgoing stream
  306. // https://tools.ietf.org/html/rfc6525
  307. return s.association.sendResetRequest(sid)
  308. }
  309. return nil
  310. }
  311. // BufferedAmount returns the number of bytes of data currently queued to be sent over this stream.
  312. func (s *Stream) BufferedAmount() uint64 {
  313. s.lock.RLock()
  314. defer s.lock.RUnlock()
  315. return s.bufferedAmount
  316. }
  317. // BufferedAmountLowThreshold returns the number of bytes of buffered outgoing data that is
  318. // considered "low." Defaults to 0.
  319. func (s *Stream) BufferedAmountLowThreshold() uint64 {
  320. s.lock.RLock()
  321. defer s.lock.RUnlock()
  322. return s.bufferedAmountLow
  323. }
  324. // SetBufferedAmountLowThreshold is used to update the threshold.
  325. // See BufferedAmountLowThreshold().
  326. func (s *Stream) SetBufferedAmountLowThreshold(th uint64) {
  327. s.lock.Lock()
  328. defer s.lock.Unlock()
  329. s.bufferedAmountLow = th
  330. }
  331. // OnBufferedAmountLow sets the callback handler which would be called when the number of
  332. // bytes of outgoing data buffered is lower than the threshold.
  333. func (s *Stream) OnBufferedAmountLow(f func()) {
  334. s.lock.Lock()
  335. defer s.lock.Unlock()
  336. s.onBufferedAmountLow = f
  337. }
  338. // This method is called by association's readLoop (go-)routine to notify this stream
  339. // of the specified amount of outgoing data has been delivered to the peer.
  340. func (s *Stream) onBufferReleased(nBytesReleased int) {
  341. if nBytesReleased <= 0 {
  342. return
  343. }
  344. s.lock.Lock()
  345. fromAmount := s.bufferedAmount
  346. if s.bufferedAmount < uint64(nBytesReleased) {
  347. s.bufferedAmount = 0
  348. s.log.Errorf("[%s] released buffer size %d should be <= %d",
  349. s.name, nBytesReleased, s.bufferedAmount)
  350. } else {
  351. s.bufferedAmount -= uint64(nBytesReleased)
  352. }
  353. s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)
  354. if s.onBufferedAmountLow != nil && fromAmount > s.bufferedAmountLow && s.bufferedAmount <= s.bufferedAmountLow {
  355. f := s.onBufferedAmountLow
  356. s.lock.Unlock()
  357. f()
  358. return
  359. }
  360. s.lock.Unlock()
  361. }
  362. func (s *Stream) getNumBytesInReassemblyQueue() int {
  363. // No lock is required as it reads the size with atomic load function.
  364. return s.reassemblyQueue.getNumBytes()
  365. }
  366. func (s *Stream) onInboundStreamReset() {
  367. s.lock.Lock()
  368. defer s.lock.Unlock()
  369. s.log.Debugf("[%s] onInboundStreamReset: state=%s", s.name, s.state.String())
  370. // No more inbound data to read. Unblock the read with io.EOF.
  371. // This should cause DCEP layer (datachannel package) to call Close() which
  372. // will reset outgoing stream also.
  373. // See RFC 8831 section 6.7:
  374. // if one side decides to close the data channel, it resets the corresponding
  375. // outgoing stream. When the peer sees that an incoming stream was
  376. // reset, it also resets its corresponding outgoing stream. Once this
  377. // is completed, the data channel is closed.
  378. s.readErr = io.EOF
  379. s.readNotifier.Broadcast()
  380. if s.state == StreamStateClosing {
  381. s.log.Debugf("[%s] state change: closing => closed", s.name)
  382. s.state = StreamStateClosed
  383. }
  384. }
  385. // State return the stream state.
  386. func (s *Stream) State() StreamState {
  387. s.lock.RLock()
  388. defer s.lock.RUnlock()
  389. return s.state
  390. }