client.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. package http3
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/tls"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "strconv"
  11. "sync"
  12. "github.com/Psiphon-Labs/quic-go"
  13. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  14. "github.com/Psiphon-Labs/quic-go/internal/qtls"
  15. "github.com/Psiphon-Labs/quic-go/internal/utils"
  16. "github.com/Psiphon-Labs/quic-go/quicvarint"
  17. "github.com/marten-seemann/qpack"
  18. )
  19. // MethodGet0RTT allows a GET request to be sent using 0-RTT.
  20. // Note that 0-RTT data doesn't provide replay protection.
  21. const MethodGet0RTT = "GET_0RTT"
  22. const (
  23. defaultUserAgent = "quic-go HTTP/3"
  24. defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
  25. )
  26. var defaultQuicConfig = &quic.Config{
  27. MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
  28. KeepAlive: true,
  29. Versions: []protocol.VersionNumber{protocol.VersionTLS},
  30. }
  31. var dialAddr = quic.DialAddrEarly
  32. type roundTripperOpts struct {
  33. DisableCompression bool
  34. EnableDatagram bool
  35. MaxHeaderBytes int64
  36. }
  37. // client is a HTTP3 client doing requests
  38. type client struct {
  39. tlsConf *tls.Config
  40. config *quic.Config
  41. opts *roundTripperOpts
  42. dialOnce sync.Once
  43. dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
  44. handshakeErr error
  45. requestWriter *requestWriter
  46. decoder *qpack.Decoder
  47. hostname string
  48. // [Psiphon]
  49. // Enable Close to be called concurrently with dial.
  50. sessionMutex sync.Mutex
  51. closed bool
  52. session quic.EarlySession
  53. logger utils.Logger
  54. }
  55. func newClient(
  56. hostname string,
  57. tlsConf *tls.Config,
  58. opts *roundTripperOpts,
  59. quicConfig *quic.Config,
  60. dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error),
  61. ) (*client, error) {
  62. if quicConfig == nil {
  63. quicConfig = defaultQuicConfig.Clone()
  64. } else if len(quicConfig.Versions) == 0 {
  65. quicConfig = quicConfig.Clone()
  66. quicConfig.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]}
  67. }
  68. if len(quicConfig.Versions) != 1 {
  69. return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
  70. }
  71. quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
  72. quicConfig.EnableDatagrams = opts.EnableDatagram
  73. logger := utils.DefaultLogger.WithPrefix("h3 client")
  74. if tlsConf == nil {
  75. tlsConf = &tls.Config{}
  76. } else {
  77. tlsConf = tlsConf.Clone()
  78. }
  79. // Replace existing ALPNs by H3
  80. tlsConf.NextProtos = []string{versionToALPN(quicConfig.Versions[0])}
  81. return &client{
  82. hostname: authorityAddr("https", hostname),
  83. tlsConf: tlsConf,
  84. requestWriter: newRequestWriter(logger),
  85. decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
  86. config: quicConfig,
  87. opts: opts,
  88. dialer: dialer,
  89. logger: logger,
  90. }, nil
  91. }
  92. func (c *client) dial() error {
  93. var err error
  94. var session quic.EarlySession
  95. if c.dialer != nil {
  96. session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
  97. } else {
  98. session, err = dialAddr(c.hostname, c.tlsConf, c.config)
  99. }
  100. if err != nil {
  101. return err
  102. }
  103. // [Psiphon]
  104. c.sessionMutex.Lock()
  105. if c.closed {
  106. session.CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
  107. err = errors.New("closed while dialing")
  108. } else {
  109. c.session = session
  110. }
  111. c.sessionMutex.Unlock()
  112. // send the SETTINGs frame, using 0-RTT data, if possible
  113. go func() {
  114. if err := c.setupSession(); err != nil {
  115. c.logger.Debugf("Setting up session failed: %s", err)
  116. c.session.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "")
  117. }
  118. }()
  119. go c.handleUnidirectionalStreams()
  120. return nil
  121. }
  122. func (c *client) setupSession() error {
  123. // open the control stream
  124. str, err := c.session.OpenUniStream()
  125. if err != nil {
  126. return err
  127. }
  128. buf := &bytes.Buffer{}
  129. quicvarint.Write(buf, streamTypeControlStream)
  130. // send the SETTINGS frame
  131. (&settingsFrame{Datagram: c.opts.EnableDatagram}).Write(buf)
  132. _, err = str.Write(buf.Bytes())
  133. return err
  134. }
  135. func (c *client) handleUnidirectionalStreams() {
  136. for {
  137. str, err := c.session.AcceptUniStream(context.Background())
  138. if err != nil {
  139. c.logger.Debugf("accepting unidirectional stream failed: %s", err)
  140. return
  141. }
  142. go func() {
  143. streamType, err := quicvarint.Read(quicvarint.NewReader(str))
  144. if err != nil {
  145. c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
  146. return
  147. }
  148. // We're only interested in the control stream here.
  149. switch streamType {
  150. case streamTypeControlStream:
  151. case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
  152. // Our QPACK implementation doesn't use the dynamic table yet.
  153. // TODO: check that only one stream of each type is opened.
  154. return
  155. case streamTypePushStream:
  156. // We never increased the Push ID, so we don't expect any push streams.
  157. c.session.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
  158. return
  159. default:
  160. str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
  161. return
  162. }
  163. f, err := parseNextFrame(str)
  164. if err != nil {
  165. c.session.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
  166. return
  167. }
  168. sf, ok := f.(*settingsFrame)
  169. if !ok {
  170. c.session.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "")
  171. return
  172. }
  173. if !sf.Datagram {
  174. return
  175. }
  176. // If datagram support was enabled on our side as well as on the server side,
  177. // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
  178. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
  179. if c.opts.EnableDatagram && !c.session.ConnectionState().SupportsDatagrams {
  180. c.session.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
  181. }
  182. }()
  183. }
  184. }
  185. func (c *client) Close() error {
  186. // [Psiphon]
  187. c.sessionMutex.Lock()
  188. session := c.session
  189. c.closed = true
  190. c.sessionMutex.Unlock()
  191. if session == nil {
  192. return nil
  193. }
  194. return session.CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
  195. }
  196. func (c *client) maxHeaderBytes() uint64 {
  197. if c.opts.MaxHeaderBytes <= 0 {
  198. return defaultMaxResponseHeaderBytes
  199. }
  200. return uint64(c.opts.MaxHeaderBytes)
  201. }
  202. // RoundTrip executes a request and returns a response
  203. func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
  204. if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
  205. return nil, fmt.Errorf("http3 client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
  206. }
  207. c.dialOnce.Do(func() {
  208. c.handshakeErr = c.dial()
  209. })
  210. if c.handshakeErr != nil {
  211. return nil, c.handshakeErr
  212. }
  213. // Immediately send out this request, if this is a 0-RTT request.
  214. if req.Method == MethodGet0RTT {
  215. req.Method = http.MethodGet
  216. } else {
  217. // wait for the handshake to complete
  218. select {
  219. case <-c.session.HandshakeComplete().Done():
  220. case <-req.Context().Done():
  221. return nil, req.Context().Err()
  222. }
  223. }
  224. str, err := c.session.OpenStreamSync(req.Context())
  225. if err != nil {
  226. return nil, err
  227. }
  228. // Request Cancellation:
  229. // This go routine keeps running even after RoundTrip() returns.
  230. // It is shut down when the application is done processing the body.
  231. reqDone := make(chan struct{})
  232. go func() {
  233. select {
  234. case <-req.Context().Done():
  235. str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
  236. str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
  237. case <-reqDone:
  238. }
  239. }()
  240. rsp, rerr := c.doRequest(req, str, reqDone)
  241. if rerr.err != nil { // if any error occurred
  242. close(reqDone)
  243. if rerr.streamErr != 0 { // if it was a stream error
  244. str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
  245. }
  246. if rerr.connErr != 0 { // if it was a connection error
  247. var reason string
  248. if rerr.err != nil {
  249. reason = rerr.err.Error()
  250. }
  251. c.session.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
  252. }
  253. }
  254. return rsp, rerr.err
  255. }
  256. func (c *client) doRequest(
  257. req *http.Request,
  258. str quic.Stream,
  259. reqDone chan struct{},
  260. ) (*http.Response, requestError) {
  261. var requestGzip bool
  262. if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
  263. requestGzip = true
  264. }
  265. if err := c.requestWriter.WriteRequest(str, req, requestGzip); err != nil {
  266. return nil, newStreamError(errorInternalError, err)
  267. }
  268. frame, err := parseNextFrame(str)
  269. if err != nil {
  270. return nil, newStreamError(errorFrameError, err)
  271. }
  272. hf, ok := frame.(*headersFrame)
  273. if !ok {
  274. return nil, newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
  275. }
  276. if hf.Length > c.maxHeaderBytes() {
  277. return nil, newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
  278. }
  279. headerBlock := make([]byte, hf.Length)
  280. if _, err := io.ReadFull(str, headerBlock); err != nil {
  281. return nil, newStreamError(errorRequestIncomplete, err)
  282. }
  283. hfs, err := c.decoder.DecodeFull(headerBlock)
  284. if err != nil {
  285. // TODO: use the right error code
  286. return nil, newConnError(errorGeneralProtocolError, err)
  287. }
  288. connState := qtls.ToTLSConnectionState(c.session.ConnectionState().TLS)
  289. res := &http.Response{
  290. Proto: "HTTP/3",
  291. ProtoMajor: 3,
  292. Header: http.Header{},
  293. TLS: &connState,
  294. }
  295. for _, hf := range hfs {
  296. switch hf.Name {
  297. case ":status":
  298. status, err := strconv.Atoi(hf.Value)
  299. if err != nil {
  300. return nil, newStreamError(errorGeneralProtocolError, errors.New("malformed non-numeric status pseudo header"))
  301. }
  302. res.StatusCode = status
  303. res.Status = hf.Value + " " + http.StatusText(status)
  304. default:
  305. res.Header.Add(hf.Name, hf.Value)
  306. }
  307. }
  308. respBody := newResponseBody(str, reqDone, func() {
  309. c.session.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "")
  310. })
  311. // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
  312. _, hasTransferEncoding := res.Header["Transfer-Encoding"]
  313. isInformational := res.StatusCode >= 100 && res.StatusCode < 200
  314. isNoContent := res.StatusCode == 204
  315. isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300
  316. if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect {
  317. res.ContentLength = -1
  318. if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 {
  319. if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
  320. res.ContentLength = clen64
  321. }
  322. }
  323. }
  324. if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
  325. res.Header.Del("Content-Encoding")
  326. res.Header.Del("Content-Length")
  327. res.ContentLength = -1
  328. res.Body = newGzipReader(respBody)
  329. res.Uncompressed = true
  330. } else {
  331. res.Body = respBody
  332. }
  333. return res, requestError{}
  334. }