client.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. package http3
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/tls"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "strconv"
  11. "sync"
  12. "time"
  13. "github.com/Psiphon-Labs/quic-go"
  14. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  15. "github.com/Psiphon-Labs/quic-go/internal/qtls"
  16. "github.com/Psiphon-Labs/quic-go/internal/utils"
  17. "github.com/Psiphon-Labs/quic-go/quicvarint"
  18. "github.com/marten-seemann/qpack"
  19. )
  20. // MethodGet0RTT allows a GET request to be sent using 0-RTT.
  21. // Note that 0-RTT data doesn't provide replay protection.
  22. const MethodGet0RTT = "GET_0RTT"
  23. const (
  24. defaultUserAgent = "quic-go HTTP/3"
  25. defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
  26. )
  27. var defaultQuicConfig = &quic.Config{
  28. MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
  29. KeepAlivePeriod: 10 * time.Second,
  30. Versions: []protocol.VersionNumber{protocol.VersionTLS},
  31. }
  32. type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
  33. var dialAddr = quic.DialAddrEarlyContext
  34. type roundTripperOpts struct {
  35. DisableCompression bool
  36. EnableDatagram bool
  37. MaxHeaderBytes int64
  38. AdditionalSettings map[uint64]uint64
  39. StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
  40. UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
  41. }
  42. // client is a HTTP3 client doing requests
  43. type client struct {
  44. tlsConf *tls.Config
  45. config *quic.Config
  46. opts *roundTripperOpts
  47. dialOnce sync.Once
  48. dialer dialFunc
  49. handshakeErr error
  50. requestWriter *requestWriter
  51. decoder *qpack.Decoder
  52. hostname string
  53. // [Psiphon]
  54. // Enable Close to be called concurrently with dial.
  55. connMutex sync.Mutex
  56. closed bool
  57. conn quic.EarlyConnection
  58. logger utils.Logger
  59. }
  60. func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) {
  61. if conf == nil {
  62. conf = defaultQuicConfig.Clone()
  63. } else if len(conf.Versions) == 0 {
  64. conf = conf.Clone()
  65. conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]}
  66. }
  67. if len(conf.Versions) != 1 {
  68. return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
  69. }
  70. if conf.MaxIncomingStreams == 0 {
  71. conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
  72. }
  73. conf.EnableDatagrams = opts.EnableDatagram
  74. logger := utils.DefaultLogger.WithPrefix("h3 client")
  75. if tlsConf == nil {
  76. tlsConf = &tls.Config{}
  77. } else {
  78. tlsConf = tlsConf.Clone()
  79. }
  80. // Replace existing ALPNs by H3
  81. tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}
  82. return &client{
  83. hostname: authorityAddr("https", hostname),
  84. tlsConf: tlsConf,
  85. requestWriter: newRequestWriter(logger),
  86. decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
  87. config: conf,
  88. opts: opts,
  89. dialer: dialer,
  90. logger: logger,
  91. }, nil
  92. }
  93. func (c *client) dial(ctx context.Context) error {
  94. var err error
  95. var conn quic.EarlyConnection
  96. if c.dialer != nil {
  97. conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
  98. } else {
  99. conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
  100. }
  101. if err != nil {
  102. return err
  103. }
  104. // [Psiphon]
  105. c.connMutex.Lock()
  106. if c.closed {
  107. conn.CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
  108. err = errors.New("closed while dialing")
  109. } else {
  110. c.conn = conn
  111. }
  112. c.connMutex.Unlock()
  113. // send the SETTINGs frame, using 0-RTT data, if possible
  114. go func() {
  115. if err := c.setupConn(); err != nil {
  116. c.logger.Debugf("Setting up connection failed: %s", err)
  117. c.conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "")
  118. }
  119. }()
  120. if c.opts.StreamHijacker != nil {
  121. go c.handleBidirectionalStreams()
  122. }
  123. go c.handleUnidirectionalStreams()
  124. return nil
  125. }
  126. func (c *client) setupConn() error {
  127. // open the control stream
  128. str, err := c.conn.OpenUniStream()
  129. if err != nil {
  130. return err
  131. }
  132. buf := &bytes.Buffer{}
  133. quicvarint.Write(buf, streamTypeControlStream)
  134. // send the SETTINGS frame
  135. (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Write(buf)
  136. _, err = str.Write(buf.Bytes())
  137. return err
  138. }
  139. func (c *client) handleBidirectionalStreams() {
  140. for {
  141. str, err := c.conn.AcceptStream(context.Background())
  142. if err != nil {
  143. c.logger.Debugf("accepting bidirectional stream failed: %s", err)
  144. return
  145. }
  146. go func(str quic.Stream) {
  147. _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
  148. return c.opts.StreamHijacker(ft, c.conn, str, e)
  149. })
  150. if err == errHijacked {
  151. return
  152. }
  153. if err != nil {
  154. c.logger.Debugf("error handling stream: %s", err)
  155. }
  156. c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
  157. }(str)
  158. }
  159. }
  160. func (c *client) handleUnidirectionalStreams() {
  161. for {
  162. str, err := c.conn.AcceptUniStream(context.Background())
  163. if err != nil {
  164. c.logger.Debugf("accepting unidirectional stream failed: %s", err)
  165. return
  166. }
  167. go func(str quic.ReceiveStream) {
  168. streamType, err := quicvarint.Read(quicvarint.NewReader(str))
  169. if err != nil {
  170. if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) {
  171. return
  172. }
  173. c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
  174. return
  175. }
  176. // We're only interested in the control stream here.
  177. switch streamType {
  178. case streamTypeControlStream:
  179. case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
  180. // Our QPACK implementation doesn't use the dynamic table yet.
  181. // TODO: check that only one stream of each type is opened.
  182. return
  183. case streamTypePushStream:
  184. // We never increased the Push ID, so we don't expect any push streams.
  185. c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
  186. return
  187. default:
  188. if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) {
  189. return
  190. }
  191. str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
  192. return
  193. }
  194. f, err := parseNextFrame(str, nil)
  195. if err != nil {
  196. c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
  197. return
  198. }
  199. sf, ok := f.(*settingsFrame)
  200. if !ok {
  201. c.conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "")
  202. return
  203. }
  204. if !sf.Datagram {
  205. return
  206. }
  207. // If datagram support was enabled on our side as well as on the server side,
  208. // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
  209. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
  210. if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams {
  211. c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
  212. }
  213. }(str)
  214. }
  215. }
  216. func (c *client) Close() error {
  217. // [Psiphon]
  218. c.connMutex.Lock()
  219. conn := c.conn
  220. c.closed = true
  221. c.connMutex.Unlock()
  222. if conn == nil {
  223. return nil
  224. }
  225. return conn.CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
  226. }
  227. func (c *client) maxHeaderBytes() uint64 {
  228. if c.opts.MaxHeaderBytes <= 0 {
  229. return defaultMaxResponseHeaderBytes
  230. }
  231. return uint64(c.opts.MaxHeaderBytes)
  232. }
  233. // RoundTripOpt executes a request and returns a response
  234. func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
  235. if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
  236. return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
  237. }
  238. c.dialOnce.Do(func() {
  239. c.handshakeErr = c.dial(req.Context())
  240. })
  241. if c.handshakeErr != nil {
  242. return nil, c.handshakeErr
  243. }
  244. // Immediately send out this request, if this is a 0-RTT request.
  245. if req.Method == MethodGet0RTT {
  246. req.Method = http.MethodGet
  247. } else {
  248. // wait for the handshake to complete
  249. select {
  250. case <-c.conn.HandshakeComplete().Done():
  251. case <-req.Context().Done():
  252. return nil, req.Context().Err()
  253. }
  254. }
  255. str, err := c.conn.OpenStreamSync(req.Context())
  256. if err != nil {
  257. return nil, err
  258. }
  259. // Request Cancellation:
  260. // This go routine keeps running even after RoundTripOpt() returns.
  261. // It is shut down when the application is done processing the body.
  262. reqDone := make(chan struct{})
  263. done := make(chan struct{})
  264. go func() {
  265. defer close(done)
  266. select {
  267. case <-req.Context().Done():
  268. str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
  269. str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
  270. case <-reqDone:
  271. }
  272. }()
  273. doneChan := reqDone
  274. if opt.DontCloseRequestStream {
  275. doneChan = nil
  276. }
  277. rsp, rerr := c.doRequest(req, str, opt, doneChan)
  278. if rerr.err != nil { // if any error occurred
  279. close(reqDone)
  280. <-done
  281. if rerr.streamErr != 0 { // if it was a stream error
  282. str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
  283. }
  284. if rerr.connErr != 0 { // if it was a connection error
  285. var reason string
  286. if rerr.err != nil {
  287. reason = rerr.err.Error()
  288. }
  289. c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
  290. }
  291. return nil, rerr.err
  292. }
  293. if opt.DontCloseRequestStream {
  294. close(reqDone)
  295. <-done
  296. }
  297. return rsp, rerr.err
  298. }
  299. func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error {
  300. defer body.Close()
  301. b := make([]byte, bodyCopyBufferSize)
  302. for {
  303. n, rerr := body.Read(b)
  304. if n == 0 {
  305. if rerr == nil {
  306. continue
  307. }
  308. if rerr == io.EOF {
  309. break
  310. }
  311. }
  312. if _, err := str.Write(b[:n]); err != nil {
  313. return err
  314. }
  315. if rerr != nil {
  316. if rerr == io.EOF {
  317. break
  318. }
  319. str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
  320. return rerr
  321. }
  322. }
  323. return nil
  324. }
  325. func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
  326. var requestGzip bool
  327. if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
  328. requestGzip = true
  329. }
  330. if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil {
  331. return nil, newStreamError(errorInternalError, err)
  332. }
  333. if req.Body == nil && !opt.DontCloseRequestStream {
  334. str.Close()
  335. }
  336. hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") })
  337. if req.Body != nil {
  338. // send the request body asynchronously
  339. go func() {
  340. if err := c.sendRequestBody(hstr, req.Body); err != nil {
  341. c.logger.Errorf("Error writing request: %s", err)
  342. }
  343. if !opt.DontCloseRequestStream {
  344. hstr.Close()
  345. }
  346. }()
  347. }
  348. frame, err := parseNextFrame(str, nil)
  349. if err != nil {
  350. return nil, newStreamError(errorFrameError, err)
  351. }
  352. hf, ok := frame.(*headersFrame)
  353. if !ok {
  354. return nil, newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
  355. }
  356. if hf.Length > c.maxHeaderBytes() {
  357. return nil, newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
  358. }
  359. headerBlock := make([]byte, hf.Length)
  360. if _, err := io.ReadFull(str, headerBlock); err != nil {
  361. return nil, newStreamError(errorRequestIncomplete, err)
  362. }
  363. hfs, err := c.decoder.DecodeFull(headerBlock)
  364. if err != nil {
  365. // TODO: use the right error code
  366. return nil, newConnError(errorGeneralProtocolError, err)
  367. }
  368. connState := qtls.ToTLSConnectionState(c.conn.ConnectionState().TLS)
  369. res := &http.Response{
  370. Proto: "HTTP/3.0",
  371. ProtoMajor: 3,
  372. Header: http.Header{},
  373. TLS: &connState,
  374. }
  375. for _, hf := range hfs {
  376. switch hf.Name {
  377. case ":status":
  378. status, err := strconv.Atoi(hf.Value)
  379. if err != nil {
  380. return nil, newStreamError(errorGeneralProtocolError, errors.New("malformed non-numeric status pseudo header"))
  381. }
  382. res.StatusCode = status
  383. res.Status = hf.Value + " " + http.StatusText(status)
  384. default:
  385. res.Header.Add(hf.Name, hf.Value)
  386. }
  387. }
  388. respBody := newResponseBody(hstr, c.conn, reqDone)
  389. // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
  390. _, hasTransferEncoding := res.Header["Transfer-Encoding"]
  391. isInformational := res.StatusCode >= 100 && res.StatusCode < 200
  392. isNoContent := res.StatusCode == 204
  393. isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300
  394. if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect {
  395. res.ContentLength = -1
  396. if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 {
  397. if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
  398. res.ContentLength = clen64
  399. }
  400. }
  401. }
  402. if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
  403. res.Header.Del("Content-Encoding")
  404. res.Header.Del("Content-Length")
  405. res.ContentLength = -1
  406. res.Body = newGzipReader(respBody)
  407. res.Uncompressed = true
  408. } else {
  409. res.Body = respBody
  410. }
  411. return res, requestError{}
  412. }