streams_map.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. package quic
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "sync"
  8. "github.com/Psiphon-Labs/quic-go/internal/flowcontrol"
  9. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  10. "github.com/Psiphon-Labs/quic-go/internal/qerr"
  11. "github.com/Psiphon-Labs/quic-go/internal/wire"
  12. )
  13. type streamError struct {
  14. message string
  15. nums []protocol.StreamNum
  16. }
  17. func (e streamError) Error() string {
  18. return e.message
  19. }
  20. func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error {
  21. strError, ok := err.(streamError)
  22. if !ok {
  23. return err
  24. }
  25. ids := make([]interface{}, len(strError.nums))
  26. for i, num := range strError.nums {
  27. ids[i] = num.StreamID(stype, pers)
  28. }
  29. return fmt.Errorf(strError.Error(), ids...)
  30. }
  // streamOpenErr wraps an error from opening a stream so that it satisfies net.Error.
type streamOpenErr struct{ error }

var _ net.Error = &streamOpenErr{}

// Temporary reports whether the wrapped error is, or wraps, errTooManyOpenStreams.
// errors.Is is used instead of == so that a wrapped sentinel is still recognized.
func (e streamOpenErr) Temporary() bool { return errors.Is(e.error, errTooManyOpenStreams) }

// Timeout always reports false.
func (streamOpenErr) Timeout() bool { return false }

// errTooManyOpenStreams is used internally by the outgoing streams maps.
var errTooManyOpenStreams = errors.New("too many open streams")
  37. type streamsMap struct {
  38. perspective protocol.Perspective
  39. version protocol.VersionNumber
  40. maxIncomingBidiStreams uint64
  41. maxIncomingUniStreams uint64
  42. sender streamSender
  43. newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
  44. mutex sync.Mutex
  45. outgoingBidiStreams *outgoingStreamsMap[streamI]
  46. outgoingUniStreams *outgoingStreamsMap[sendStreamI]
  47. incomingBidiStreams *incomingStreamsMap[streamI]
  48. incomingUniStreams *incomingStreamsMap[receiveStreamI]
  49. reset bool
  50. }
  51. var _ streamManager = &streamsMap{}
  52. func newStreamsMap(
  53. sender streamSender,
  54. newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
  55. maxIncomingBidiStreams uint64,
  56. maxIncomingUniStreams uint64,
  57. perspective protocol.Perspective,
  58. version protocol.VersionNumber,
  59. ) streamManager {
  60. m := &streamsMap{
  61. perspective: perspective,
  62. newFlowController: newFlowController,
  63. maxIncomingBidiStreams: maxIncomingBidiStreams,
  64. maxIncomingUniStreams: maxIncomingUniStreams,
  65. sender: sender,
  66. version: version,
  67. }
  68. m.initMaps()
  69. return m
  70. }
  71. func (m *streamsMap) initMaps() {
  72. m.outgoingBidiStreams = newOutgoingStreamsMap(
  73. protocol.StreamTypeBidi,
  74. func(num protocol.StreamNum) streamI {
  75. id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
  76. return newStream(id, m.sender, m.newFlowController(id), m.version)
  77. },
  78. m.sender.queueControlFrame,
  79. )
  80. m.incomingBidiStreams = newIncomingStreamsMap(
  81. protocol.StreamTypeBidi,
  82. func(num protocol.StreamNum) streamI {
  83. id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite())
  84. return newStream(id, m.sender, m.newFlowController(id), m.version)
  85. },
  86. m.maxIncomingBidiStreams,
  87. m.sender.queueControlFrame,
  88. )
  89. m.outgoingUniStreams = newOutgoingStreamsMap(
  90. protocol.StreamTypeUni,
  91. func(num protocol.StreamNum) sendStreamI {
  92. id := num.StreamID(protocol.StreamTypeUni, m.perspective)
  93. return newSendStream(id, m.sender, m.newFlowController(id), m.version)
  94. },
  95. m.sender.queueControlFrame,
  96. )
  97. m.incomingUniStreams = newIncomingStreamsMap(
  98. protocol.StreamTypeUni,
  99. func(num protocol.StreamNum) receiveStreamI {
  100. id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite())
  101. return newReceiveStream(id, m.sender, m.newFlowController(id), m.version)
  102. },
  103. m.maxIncomingUniStreams,
  104. m.sender.queueControlFrame,
  105. )
  106. }
  107. func (m *streamsMap) OpenStream() (Stream, error) {
  108. m.mutex.Lock()
  109. reset := m.reset
  110. mm := m.outgoingBidiStreams
  111. m.mutex.Unlock()
  112. if reset {
  113. return nil, Err0RTTRejected
  114. }
  115. str, err := mm.OpenStream()
  116. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
  117. }
  118. func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
  119. m.mutex.Lock()
  120. reset := m.reset
  121. mm := m.outgoingBidiStreams
  122. m.mutex.Unlock()
  123. if reset {
  124. return nil, Err0RTTRejected
  125. }
  126. str, err := mm.OpenStreamSync(ctx)
  127. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
  128. }
  129. func (m *streamsMap) OpenUniStream() (SendStream, error) {
  130. m.mutex.Lock()
  131. reset := m.reset
  132. mm := m.outgoingUniStreams
  133. m.mutex.Unlock()
  134. if reset {
  135. return nil, Err0RTTRejected
  136. }
  137. str, err := mm.OpenStream()
  138. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
  139. }
  140. func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
  141. m.mutex.Lock()
  142. reset := m.reset
  143. mm := m.outgoingUniStreams
  144. m.mutex.Unlock()
  145. if reset {
  146. return nil, Err0RTTRejected
  147. }
  148. str, err := mm.OpenStreamSync(ctx)
  149. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
  150. }
  151. func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
  152. m.mutex.Lock()
  153. reset := m.reset
  154. mm := m.incomingBidiStreams
  155. m.mutex.Unlock()
  156. if reset {
  157. return nil, Err0RTTRejected
  158. }
  159. str, err := mm.AcceptStream(ctx)
  160. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
  161. }
  162. func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
  163. m.mutex.Lock()
  164. reset := m.reset
  165. mm := m.incomingUniStreams
  166. m.mutex.Unlock()
  167. if reset {
  168. return nil, Err0RTTRejected
  169. }
  170. str, err := mm.AcceptStream(ctx)
  171. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
  172. }
  173. func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
  174. num := id.StreamNum()
  175. switch id.Type() {
  176. case protocol.StreamTypeUni:
  177. if id.InitiatedBy() == m.perspective {
  178. return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective)
  179. }
  180. return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite())
  181. case protocol.StreamTypeBidi:
  182. if id.InitiatedBy() == m.perspective {
  183. return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective)
  184. }
  185. return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite())
  186. }
  187. panic("")
  188. }
  189. func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
  190. str, err := m.getOrOpenReceiveStream(id)
  191. if err != nil {
  192. return nil, &qerr.TransportError{
  193. ErrorCode: qerr.StreamStateError,
  194. ErrorMessage: err.Error(),
  195. }
  196. }
  197. return str, nil
  198. }
  199. func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
  200. num := id.StreamNum()
  201. switch id.Type() {
  202. case protocol.StreamTypeUni:
  203. if id.InitiatedBy() == m.perspective {
  204. // an outgoing unidirectional stream is a send stream, not a receive stream
  205. return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
  206. }
  207. str, err := m.incomingUniStreams.GetOrOpenStream(num)
  208. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
  209. case protocol.StreamTypeBidi:
  210. var str receiveStreamI
  211. var err error
  212. if id.InitiatedBy() == m.perspective {
  213. str, err = m.outgoingBidiStreams.GetStream(num)
  214. } else {
  215. str, err = m.incomingBidiStreams.GetOrOpenStream(num)
  216. }
  217. return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
  218. }
  219. panic("")
  220. }
  221. func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
  222. str, err := m.getOrOpenSendStream(id)
  223. if err != nil {
  224. return nil, &qerr.TransportError{
  225. ErrorCode: qerr.StreamStateError,
  226. ErrorMessage: err.Error(),
  227. }
  228. }
  229. return str, nil
  230. }
  231. func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
  232. num := id.StreamNum()
  233. switch id.Type() {
  234. case protocol.StreamTypeUni:
  235. if id.InitiatedBy() == m.perspective {
  236. str, err := m.outgoingUniStreams.GetStream(num)
  237. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
  238. }
  239. // an incoming unidirectional stream is a receive stream, not a send stream
  240. return nil, fmt.Errorf("peer attempted to open send stream %d", id)
  241. case protocol.StreamTypeBidi:
  242. var str sendStreamI
  243. var err error
  244. if id.InitiatedBy() == m.perspective {
  245. str, err = m.outgoingBidiStreams.GetStream(num)
  246. } else {
  247. str, err = m.incomingBidiStreams.GetOrOpenStream(num)
  248. }
  249. return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
  250. }
  251. panic("")
  252. }
  253. func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
  254. switch f.Type {
  255. case protocol.StreamTypeUni:
  256. m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
  257. case protocol.StreamTypeBidi:
  258. m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum)
  259. }
  260. }
  261. func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
  262. m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
  263. m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
  264. m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni)
  265. m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
  266. }
  267. func (m *streamsMap) CloseWithError(err error) {
  268. m.outgoingBidiStreams.CloseWithError(err)
  269. m.outgoingUniStreams.CloseWithError(err)
  270. m.incomingBidiStreams.CloseWithError(err)
  271. m.incomingUniStreams.CloseWithError(err)
  272. }
  273. // ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are
  274. // 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error.
  275. // 2. reset to their initial state, such that we can immediately process new incoming stream data.
  276. // Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error,
  277. // until UseResetMaps() has been called.
  278. func (m *streamsMap) ResetFor0RTT() {
  279. m.mutex.Lock()
  280. defer m.mutex.Unlock()
  281. m.reset = true
  282. m.CloseWithError(Err0RTTRejected)
  283. m.initMaps()
  284. }
  285. func (m *streamsMap) UseResetMaps() {
  286. m.mutex.Lock()
  287. m.reset = false
  288. m.mutex.Unlock()
  289. }