streams_map.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. package quic
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "sync"
  8. "github.com/Psiphon-Labs/quic-go/internal/flowcontrol"
  9. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  10. "github.com/Psiphon-Labs/quic-go/internal/qerr"
  11. "github.com/Psiphon-Labs/quic-go/internal/wire"
  12. )
  13. type streamError struct {
  14. message string
  15. nums []protocol.StreamNum
  16. }
  17. func (e streamError) Error() string {
  18. return e.message
  19. }
  20. func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error {
  21. strError, ok := err.(streamError)
  22. if !ok {
  23. return err
  24. }
  25. ids := make([]interface{}, len(strError.nums))
  26. for i, num := range strError.nums {
  27. ids[i] = num.StreamID(stype, pers)
  28. }
  29. return fmt.Errorf(strError.Error(), ids...)
  30. }
  31. type streamOpenErr struct{ error }
  32. var _ net.Error = &streamOpenErr{}
  33. func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams }
  34. func (streamOpenErr) Timeout() bool { return false }
  35. // errTooManyOpenStreams is used internally by the outgoing streams maps.
  36. var errTooManyOpenStreams = errors.New("too many open streams")
  37. type streamsMap struct {
  38. perspective protocol.Perspective
  39. version protocol.VersionNumber
  40. maxIncomingBidiStreams uint64
  41. maxIncomingUniStreams uint64
  42. sender streamSender
  43. newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
  44. mutex sync.Mutex
  45. outgoingBidiStreams *outgoingBidiStreamsMap
  46. outgoingUniStreams *outgoingUniStreamsMap
  47. incomingBidiStreams *incomingBidiStreamsMap
  48. incomingUniStreams *incomingUniStreamsMap
  49. reset bool
  50. }
  51. var _ streamManager = &streamsMap{}
  52. func newStreamsMap(
  53. sender streamSender,
  54. newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
  55. maxIncomingBidiStreams uint64,
  56. maxIncomingUniStreams uint64,
  57. perspective protocol.Perspective,
  58. version protocol.VersionNumber,
  59. ) streamManager {
  60. m := &streamsMap{
  61. perspective: perspective,
  62. newFlowController: newFlowController,
  63. maxIncomingBidiStreams: maxIncomingBidiStreams,
  64. maxIncomingUniStreams: maxIncomingUniStreams,
  65. sender: sender,
  66. version: version,
  67. }
  68. m.initMaps()
  69. return m
  70. }
  71. func (m *streamsMap) initMaps() {
  72. m.outgoingBidiStreams = newOutgoingBidiStreamsMap(
  73. func(num protocol.StreamNum) streamI {
  74. id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
  75. return newStream(id, m.sender, m.newFlowController(id), m.version)
  76. },
  77. m.sender.queueControlFrame,
  78. )
  79. m.incomingBidiStreams = newIncomingBidiStreamsMap(
  80. func(num protocol.StreamNum) streamI {
  81. id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite())
  82. return newStream(id, m.sender, m.newFlowController(id), m.version)
  83. },
  84. m.maxIncomingBidiStreams,
  85. m.sender.queueControlFrame,
  86. )
  87. m.outgoingUniStreams = newOutgoingUniStreamsMap(
  88. func(num protocol.StreamNum) sendStreamI {
  89. id := num.StreamID(protocol.StreamTypeUni, m.perspective)
  90. return newSendStream(id, m.sender, m.newFlowController(id), m.version)
  91. },
  92. m.sender.queueControlFrame,
  93. )
  94. m.incomingUniStreams = newIncomingUniStreamsMap(
  95. func(num protocol.StreamNum) receiveStreamI {
  96. id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite())
  97. return newReceiveStream(id, m.sender, m.newFlowController(id), m.version)
  98. },
  99. m.maxIncomingUniStreams,
  100. m.sender.queueControlFrame,
  101. )
  102. }
  103. func (m *streamsMap) OpenStream() (Stream, error) {
  104. m.mutex.Lock()
  105. reset := m.reset
  106. mm := m.outgoingBidiStreams
  107. m.mutex.Unlock()
  108. if reset {
  109. return nil, Err0RTTRejected
  110. }
  111. str, err := mm.OpenStream()
  112. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
  113. }
  114. func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
  115. m.mutex.Lock()
  116. reset := m.reset
  117. mm := m.outgoingBidiStreams
  118. m.mutex.Unlock()
  119. if reset {
  120. return nil, Err0RTTRejected
  121. }
  122. str, err := mm.OpenStreamSync(ctx)
  123. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
  124. }
  125. func (m *streamsMap) OpenUniStream() (SendStream, error) {
  126. m.mutex.Lock()
  127. reset := m.reset
  128. mm := m.outgoingUniStreams
  129. m.mutex.Unlock()
  130. if reset {
  131. return nil, Err0RTTRejected
  132. }
  133. str, err := mm.OpenStream()
  134. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
  135. }
  136. func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
  137. m.mutex.Lock()
  138. reset := m.reset
  139. mm := m.outgoingUniStreams
  140. m.mutex.Unlock()
  141. if reset {
  142. return nil, Err0RTTRejected
  143. }
  144. str, err := mm.OpenStreamSync(ctx)
  145. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
  146. }
  147. func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
  148. m.mutex.Lock()
  149. reset := m.reset
  150. mm := m.incomingBidiStreams
  151. m.mutex.Unlock()
  152. if reset {
  153. return nil, Err0RTTRejected
  154. }
  155. str, err := mm.AcceptStream(ctx)
  156. return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
  157. }
  158. func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
  159. m.mutex.Lock()
  160. reset := m.reset
  161. mm := m.incomingUniStreams
  162. m.mutex.Unlock()
  163. if reset {
  164. return nil, Err0RTTRejected
  165. }
  166. str, err := mm.AcceptStream(ctx)
  167. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
  168. }
  169. func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
  170. num := id.StreamNum()
  171. switch id.Type() {
  172. case protocol.StreamTypeUni:
  173. if id.InitiatedBy() == m.perspective {
  174. return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective)
  175. }
  176. return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite())
  177. case protocol.StreamTypeBidi:
  178. if id.InitiatedBy() == m.perspective {
  179. return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective)
  180. }
  181. return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite())
  182. }
  183. panic("")
  184. }
  185. func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
  186. str, err := m.getOrOpenReceiveStream(id)
  187. if err != nil {
  188. return nil, &qerr.TransportError{
  189. ErrorCode: qerr.StreamStateError,
  190. ErrorMessage: err.Error(),
  191. }
  192. }
  193. return str, nil
  194. }
  195. func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
  196. num := id.StreamNum()
  197. switch id.Type() {
  198. case protocol.StreamTypeUni:
  199. if id.InitiatedBy() == m.perspective {
  200. // an outgoing unidirectional stream is a send stream, not a receive stream
  201. return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
  202. }
  203. str, err := m.incomingUniStreams.GetOrOpenStream(num)
  204. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
  205. case protocol.StreamTypeBidi:
  206. var str receiveStreamI
  207. var err error
  208. if id.InitiatedBy() == m.perspective {
  209. str, err = m.outgoingBidiStreams.GetStream(num)
  210. } else {
  211. str, err = m.incomingBidiStreams.GetOrOpenStream(num)
  212. }
  213. return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
  214. }
  215. panic("")
  216. }
  217. func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
  218. str, err := m.getOrOpenSendStream(id)
  219. if err != nil {
  220. return nil, &qerr.TransportError{
  221. ErrorCode: qerr.StreamStateError,
  222. ErrorMessage: err.Error(),
  223. }
  224. }
  225. return str, nil
  226. }
  227. func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
  228. num := id.StreamNum()
  229. switch id.Type() {
  230. case protocol.StreamTypeUni:
  231. if id.InitiatedBy() == m.perspective {
  232. str, err := m.outgoingUniStreams.GetStream(num)
  233. return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
  234. }
  235. // an incoming unidirectional stream is a receive stream, not a send stream
  236. return nil, fmt.Errorf("peer attempted to open send stream %d", id)
  237. case protocol.StreamTypeBidi:
  238. var str sendStreamI
  239. var err error
  240. if id.InitiatedBy() == m.perspective {
  241. str, err = m.outgoingBidiStreams.GetStream(num)
  242. } else {
  243. str, err = m.incomingBidiStreams.GetOrOpenStream(num)
  244. }
  245. return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
  246. }
  247. panic("")
  248. }
  249. func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
  250. switch f.Type {
  251. case protocol.StreamTypeUni:
  252. m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
  253. case protocol.StreamTypeBidi:
  254. m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum)
  255. }
  256. }
  257. func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
  258. m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
  259. m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
  260. m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni)
  261. m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
  262. }
  263. func (m *streamsMap) CloseWithError(err error) {
  264. m.outgoingBidiStreams.CloseWithError(err)
  265. m.outgoingUniStreams.CloseWithError(err)
  266. m.incomingBidiStreams.CloseWithError(err)
  267. m.incomingUniStreams.CloseWithError(err)
  268. }
  269. // ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are
  270. // 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error.
  271. // 2. reset to their initial state, such that we can immediately process new incoming stream data.
  272. // Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error,
  273. // until UseResetMaps() has been called.
  274. func (m *streamsMap) ResetFor0RTT() {
  275. m.mutex.Lock()
  276. defer m.mutex.Unlock()
  277. m.reset = true
  278. m.CloseWithError(Err0RTTRejected)
  279. m.initMaps()
  280. }
  281. func (m *streamsMap) UseResetMaps() {
  282. m.mutex.Lock()
  283. m.reset = false
  284. m.mutex.Unlock()
  285. }