streams_map.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. package quic
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "sync"
  8. "github.com/Psiphon-Labs/quic-go/internal/flowcontrol"
  9. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  10. "github.com/Psiphon-Labs/quic-go/internal/qerr"
  11. "github.com/Psiphon-Labs/quic-go/internal/wire"
  12. )
  13. type streamError struct {
  14. message string
  15. nums []protocol.StreamNum
  16. }
  17. func (e streamError) Error() string {
  18. return e.message
  19. }
  20. func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error {
  21. strError, ok := err.(streamError)
  22. if !ok {
  23. return err
  24. }
  25. ids := make([]interface{}, len(strError.nums))
  26. for i, num := range strError.nums {
  27. ids[i] = num.StreamID(stype, pers)
  28. }
  29. return fmt.Errorf(strError.Error(), ids...)
  30. }
  31. type streamOpenErr struct{ error }
  32. var _ net.Error = &streamOpenErr{}
  33. func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams }
  34. func (streamOpenErr) Timeout() bool { return false }
  35. // errTooManyOpenStreams is used internally by the outgoing streams maps.
  36. var errTooManyOpenStreams = errors.New("too many open streams")
  37. type streamsMap struct {
  38. perspective protocol.Perspective
  39. maxIncomingBidiStreams uint64
  40. maxIncomingUniStreams uint64
  41. sender streamSender
  42. newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
  43. mutex sync.Mutex
  44. outgoingBidiStreams *outgoingStreamsMap[streamI]
  45. outgoingUniStreams *outgoingStreamsMap[sendStreamI]
  46. incomingBidiStreams *incomingStreamsMap[streamI]
  47. incomingUniStreams *incomingStreamsMap[receiveStreamI]
  48. reset bool
  49. }
  50. var _ streamManager = &streamsMap{}
  51. func newStreamsMap(
  52. sender streamSender,
  53. newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
  54. maxIncomingBidiStreams uint64,
  55. maxIncomingUniStreams uint64,
  56. perspective protocol.Perspective,
  57. ) streamManager {
  58. m := &streamsMap{
  59. perspective: perspective,
  60. newFlowController: newFlowController,
  61. maxIncomingBidiStreams: maxIncomingBidiStreams,
  62. maxIncomingUniStreams: maxIncomingUniStreams,
  63. sender: sender,
  64. }
  65. m.initMaps()
  66. return m
  67. }
  68. func (m *streamsMap) initMaps() {
  69. m.outgoingBidiStreams = newOutgoingStreamsMap(
  70. protocol.StreamTypeBidi,
  71. func(num protocol.StreamNum) streamI {
  72. id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
  73. return newStream(id, m.sender, m.newFlowController(id))
  74. },
  75. m.sender.queueControlFrame,
  76. )
  77. m.incomingBidiStreams = newIncomingStreamsMap(
  78. protocol.StreamTypeBidi,
  79. func(num protocol.StreamNum) streamI {
  80. id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite())
  81. return newStream(id, m.sender, m.newFlowController(id))
  82. },
  83. m.maxIncomingBidiStreams,
  84. m.sender.queueControlFrame,
  85. )
  86. m.outgoingUniStreams = newOutgoingStreamsMap(
  87. protocol.StreamTypeUni,
  88. func(num protocol.StreamNum) sendStreamI {
  89. id := num.StreamID(protocol.StreamTypeUni, m.perspective)
  90. return newSendStream(id, m.sender, m.newFlowController(id))
  91. },
  92. m.sender.queueControlFrame,
  93. )
  94. m.incomingUniStreams = newIncomingStreamsMap(
  95. protocol.StreamTypeUni,
  96. func(num protocol.StreamNum) receiveStreamI {
  97. id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite())
  98. return newReceiveStream(id, m.sender, m.newFlowController(id))
  99. },
  100. m.maxIncomingUniStreams,
  101. m.sender.queueControlFrame,
  102. )
  103. }
  104. func (m *streamsMap) OpenStream() (Stream, error) {
  105. m.mutex.Lock()
  106. reset := m.reset
  107. mm := m.outgoingBidiStreams
  108. m.mutex.Unlock()
  109. if reset {
  110. return nil, Err0RTTRejected
  111. }
  112. str, err := mm.OpenStream()
  113. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
  114. }
  115. func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
  116. m.mutex.Lock()
  117. reset := m.reset
  118. mm := m.outgoingBidiStreams
  119. m.mutex.Unlock()
  120. if reset {
  121. return nil, Err0RTTRejected
  122. }
  123. str, err := mm.OpenStreamSync(ctx)
  124. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
  125. }
  126. func (m *streamsMap) OpenUniStream() (SendStream, error) {
  127. m.mutex.Lock()
  128. reset := m.reset
  129. mm := m.outgoingUniStreams
  130. m.mutex.Unlock()
  131. if reset {
  132. return nil, Err0RTTRejected
  133. }
  134. str, err := mm.OpenStream()
  135. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
  136. }
  137. func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
  138. m.mutex.Lock()
  139. reset := m.reset
  140. mm := m.outgoingUniStreams
  141. m.mutex.Unlock()
  142. if reset {
  143. return nil, Err0RTTRejected
  144. }
  145. str, err := mm.OpenStreamSync(ctx)
  146. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
  147. }
  148. func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
  149. m.mutex.Lock()
  150. reset := m.reset
  151. mm := m.incomingBidiStreams
  152. m.mutex.Unlock()
  153. if reset {
  154. return nil, Err0RTTRejected
  155. }
  156. str, err := mm.AcceptStream(ctx)
  157. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
  158. }
  159. func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
  160. m.mutex.Lock()
  161. reset := m.reset
  162. mm := m.incomingUniStreams
  163. m.mutex.Unlock()
  164. if reset {
  165. return nil, Err0RTTRejected
  166. }
  167. str, err := mm.AcceptStream(ctx)
  168. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
  169. }
  170. func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
  171. num := id.StreamNum()
  172. switch id.Type() {
  173. case protocol.StreamTypeUni:
  174. if id.InitiatedBy() == m.perspective {
  175. return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective)
  176. }
  177. return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite())
  178. case protocol.StreamTypeBidi:
  179. if id.InitiatedBy() == m.perspective {
  180. return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective)
  181. }
  182. return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite())
  183. }
  184. panic("")
  185. }
  186. func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
  187. str, err := m.getOrOpenReceiveStream(id)
  188. if err != nil {
  189. return nil, &qerr.TransportError{
  190. ErrorCode: qerr.StreamStateError,
  191. ErrorMessage: err.Error(),
  192. }
  193. }
  194. return str, nil
  195. }
  196. func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
  197. num := id.StreamNum()
  198. switch id.Type() {
  199. case protocol.StreamTypeUni:
  200. if id.InitiatedBy() == m.perspective {
  201. // an outgoing unidirectional stream is a send stream, not a receive stream
  202. return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
  203. }
  204. str, err := m.incomingUniStreams.GetOrOpenStream(num)
  205. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
  206. case protocol.StreamTypeBidi:
  207. var str receiveStreamI
  208. var err error
  209. if id.InitiatedBy() == m.perspective {
  210. str, err = m.outgoingBidiStreams.GetStream(num)
  211. } else {
  212. str, err = m.incomingBidiStreams.GetOrOpenStream(num)
  213. }
  214. return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
  215. }
  216. panic("")
  217. }
  218. func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
  219. str, err := m.getOrOpenSendStream(id)
  220. if err != nil {
  221. return nil, &qerr.TransportError{
  222. ErrorCode: qerr.StreamStateError,
  223. ErrorMessage: err.Error(),
  224. }
  225. }
  226. return str, nil
  227. }
  228. func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
  229. num := id.StreamNum()
  230. switch id.Type() {
  231. case protocol.StreamTypeUni:
  232. if id.InitiatedBy() == m.perspective {
  233. str, err := m.outgoingUniStreams.GetStream(num)
  234. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
  235. }
  236. // an incoming unidirectional stream is a receive stream, not a send stream
  237. return nil, fmt.Errorf("peer attempted to open send stream %d", id)
  238. case protocol.StreamTypeBidi:
  239. var str sendStreamI
  240. var err error
  241. if id.InitiatedBy() == m.perspective {
  242. str, err = m.outgoingBidiStreams.GetStream(num)
  243. } else {
  244. str, err = m.incomingBidiStreams.GetOrOpenStream(num)
  245. }
  246. return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
  247. }
  248. panic("")
  249. }
  250. func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
  251. switch f.Type {
  252. case protocol.StreamTypeUni:
  253. m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
  254. case protocol.StreamTypeBidi:
  255. m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum)
  256. }
  257. }
  258. func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
  259. m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
  260. m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
  261. m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni)
  262. m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
  263. }
  264. func (m *streamsMap) CloseWithError(err error) {
  265. m.outgoingBidiStreams.CloseWithError(err)
  266. m.outgoingUniStreams.CloseWithError(err)
  267. m.incomingBidiStreams.CloseWithError(err)
  268. m.incomingUniStreams.CloseWithError(err)
  269. }
  270. // ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are
  271. // 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error.
  272. // 2. reset to their initial state, such that we can immediately process new incoming stream data.
  273. // Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error,
  274. // until UseResetMaps() has been called.
  275. func (m *streamsMap) ResetFor0RTT() {
  276. m.mutex.Lock()
  277. defer m.mutex.Unlock()
  278. m.reset = true
  279. m.CloseWithError(Err0RTTRejected)
  280. m.initMaps()
  281. }
  282. func (m *streamsMap) UseResetMaps() {
  283. m.mutex.Lock()
  284. m.reset = false
  285. m.mutex.Unlock()
  286. }