client.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. package http3
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net"
  9. "net/http"
  10. "strconv"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. "github.com/Psiphon-Labs/quic-go"
  15. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  16. "github.com/Psiphon-Labs/quic-go/internal/qtls"
  17. "github.com/Psiphon-Labs/quic-go/internal/utils"
  18. "github.com/Psiphon-Labs/quic-go/quicvarint"
  19. "github.com/quic-go/qpack"
  20. )
  21. // MethodGet0RTT allows a GET request to be sent using 0-RTT.
  22. // Note that 0-RTT data doesn't provide replay protection.
  23. const MethodGet0RTT = "GET_0RTT"
  24. const (
  25. defaultUserAgent = "quic-go HTTP/3"
  26. defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
  27. )
  28. var defaultQuicConfig = &quic.Config{
  29. MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
  30. KeepAlivePeriod: 10 * time.Second,
  31. }
  32. type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
  33. var dialAddr dialFunc = quic.DialAddrEarly
  // roundTripperOpts holds HTTP/3-layer options. newClient stores a pointer to
// it in every client it creates.
type roundTripperOpts struct {
	// DisableCompression stops doRequest from requesting gzip (and therefore
	// from transparently decompressing gzip responses).
	DisableCompression bool
	// EnableDatagram is copied into the QUIC config's EnableDatagrams and
	// advertised in the client's SETTINGS frame. If the server also enables
	// datagrams, the connection must support QUIC datagrams or it is closed.
	EnableDatagram bool
	// MaxHeaderBytes limits the length of a response HEADERS frame; values
	// <= 0 select defaultMaxResponseHeaderBytes.
	MaxHeaderBytes int64
	// AdditionalSettings are sent as extra entries of the client's SETTINGS
	// frame.
	AdditionalSettings map[uint64]uint64
	// StreamHijacker, if set, is offered server-initiated bidirectional
	// streams via parseNextFrame's frame callback. Streams it takes over are
	// left alone (presumably signalled by parseNextFrame returning
	// errHijacked — confirm); any other stream closes the connection.
	StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
	// UniStreamHijacker, if set, is offered unidirectional streams of an
	// unknown type (with a nil error) and streams whose type could not be read
	// (with that read error). Returning true means it took the stream.
	UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
}
  // client is a HTTP3 client doing requests
// against a single authority (c.hostname) over one lazily dialed QUIC
// connection.
type client struct {
	// tlsConf is a private copy of the caller's TLS config with ServerName
	// defaulted and NextProtos set to the HTTP/3 ALPN for the QUIC version.
	tlsConf *tls.Config
	// config is the QUIC config used when dialing.
	config *quic.Config
	// opts are the HTTP/3-layer options.
	opts *roundTripperOpts

	// dialOnce makes RoundTripOpt call dial only once.
	dialOnce sync.Once
	// dialer is the custom dial function; when nil, dialAddr is used.
	dialer dialFunc
	// handshakeErr is the result of the single dial call, written inside
	// dialOnce.Do and returned by every RoundTripOpt call after a failure.
	handshakeErr error

	// requestWriter writes the request headers onto a stream.
	requestWriter *requestWriter

	// decoder decodes the QPACK-encoded response header blocks.
	decoder *qpack.Decoder

	// hostname is the authority this client serves, as normalized by
	// authorityAddr("https", ...); requests for other authorities are rejected.
	hostname string
	// conn holds the connection once dial has succeeded; nil before that.
	conn atomic.Pointer[quic.EarlyConnection]

	logger utils.Logger
}

// Compile-time check that *client implements roundTripCloser.
var _ roundTripCloser = &client{}
  57. func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
  58. if conf == nil {
  59. conf = defaultQuicConfig.Clone()
  60. }
  61. if len(conf.Versions) == 0 {
  62. conf = conf.Clone()
  63. conf.Versions = []quic.VersionNumber{protocol.SupportedVersions[0]}
  64. }
  65. if len(conf.Versions) != 1 {
  66. return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
  67. }
  68. if conf.MaxIncomingStreams == 0 {
  69. conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
  70. }
  71. conf.EnableDatagrams = opts.EnableDatagram
  72. logger := utils.DefaultLogger.WithPrefix("h3 client")
  73. if tlsConf == nil {
  74. tlsConf = &tls.Config{}
  75. } else {
  76. tlsConf = tlsConf.Clone()
  77. }
  78. if tlsConf.ServerName == "" {
  79. sni, _, err := net.SplitHostPort(hostname)
  80. if err != nil {
  81. // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
  82. sni = hostname
  83. }
  84. tlsConf.ServerName = sni
  85. }
  86. // Replace existing ALPNs by H3
  87. tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}
  88. return &client{
  89. hostname: authorityAddr("https", hostname),
  90. tlsConf: tlsConf,
  91. requestWriter: newRequestWriter(logger),
  92. decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
  93. config: conf,
  94. opts: opts,
  95. dialer: dialer,
  96. logger: logger,
  97. }, nil
  98. }
  99. func (c *client) dial(ctx context.Context) error {
  100. var err error
  101. var conn quic.EarlyConnection
  102. if c.dialer != nil {
  103. conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
  104. } else {
  105. conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
  106. }
  107. if err != nil {
  108. return err
  109. }
  110. c.conn.Store(&conn)
  111. // send the SETTINGs frame, using 0-RTT data, if possible
  112. go func() {
  113. if err := c.setupConn(conn); err != nil {
  114. c.logger.Debugf("Setting up connection failed: %s", err)
  115. conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
  116. }
  117. }()
  118. if c.opts.StreamHijacker != nil {
  119. go c.handleBidirectionalStreams(conn)
  120. }
  121. go c.handleUnidirectionalStreams(conn)
  122. return nil
  123. }
  124. func (c *client) setupConn(conn quic.EarlyConnection) error {
  125. // open the control stream
  126. str, err := conn.OpenUniStream()
  127. if err != nil {
  128. return err
  129. }
  130. b := make([]byte, 0, 64)
  131. b = quicvarint.Append(b, streamTypeControlStream)
  132. // send the SETTINGS frame
  133. b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b)
  134. _, err = str.Write(b)
  135. return err
  136. }
  137. func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) {
  138. for {
  139. str, err := conn.AcceptStream(context.Background())
  140. if err != nil {
  141. c.logger.Debugf("accepting bidirectional stream failed: %s", err)
  142. return
  143. }
  144. go func(str quic.Stream) {
  145. _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
  146. return c.opts.StreamHijacker(ft, conn, str, e)
  147. })
  148. if err == errHijacked {
  149. return
  150. }
  151. if err != nil {
  152. c.logger.Debugf("error handling stream: %s", err)
  153. }
  154. conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
  155. }(str)
  156. }
  157. }
  158. func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) {
  159. for {
  160. str, err := conn.AcceptUniStream(context.Background())
  161. if err != nil {
  162. c.logger.Debugf("accepting unidirectional stream failed: %s", err)
  163. return
  164. }
  165. go func(str quic.ReceiveStream) {
  166. streamType, err := quicvarint.Read(quicvarint.NewReader(str))
  167. if err != nil {
  168. if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) {
  169. return
  170. }
  171. c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
  172. return
  173. }
  174. // We're only interested in the control stream here.
  175. switch streamType {
  176. case streamTypeControlStream:
  177. case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
  178. // Our QPACK implementation doesn't use the dynamic table yet.
  179. // TODO: check that only one stream of each type is opened.
  180. return
  181. case streamTypePushStream:
  182. // We never increased the Push ID, so we don't expect any push streams.
  183. conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
  184. return
  185. default:
  186. if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
  187. return
  188. }
  189. str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
  190. return
  191. }
  192. f, err := parseNextFrame(str, nil)
  193. if err != nil {
  194. conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
  195. return
  196. }
  197. sf, ok := f.(*settingsFrame)
  198. if !ok {
  199. conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
  200. return
  201. }
  202. if !sf.Datagram {
  203. return
  204. }
  205. // If datagram support was enabled on our side as well as on the server side,
  206. // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
  207. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
  208. if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams {
  209. conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
  210. }
  211. }(str)
  212. }
  213. }
  214. func (c *client) Close() error {
  215. conn := c.conn.Load()
  216. if conn == nil {
  217. return nil
  218. }
  219. return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
  220. }
  221. func (c *client) maxHeaderBytes() uint64 {
  222. if c.opts.MaxHeaderBytes <= 0 {
  223. return defaultMaxResponseHeaderBytes
  224. }
  225. return uint64(c.opts.MaxHeaderBytes)
  226. }
  227. // RoundTripOpt executes a request and returns a response
  228. func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
  229. if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
  230. return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
  231. }
  232. c.dialOnce.Do(func() {
  233. c.handshakeErr = c.dial(req.Context())
  234. })
  235. if c.handshakeErr != nil {
  236. return nil, c.handshakeErr
  237. }
  238. // At this point, c.conn is guaranteed to be set.
  239. conn := *c.conn.Load()
  240. // Immediately send out this request, if this is a 0-RTT request.
  241. if req.Method == MethodGet0RTT {
  242. req.Method = http.MethodGet
  243. } else {
  244. // wait for the handshake to complete
  245. select {
  246. case <-conn.HandshakeComplete():
  247. case <-req.Context().Done():
  248. return nil, req.Context().Err()
  249. }
  250. }
  251. str, err := conn.OpenStreamSync(req.Context())
  252. if err != nil {
  253. return nil, err
  254. }
  255. // Request Cancellation:
  256. // This go routine keeps running even after RoundTripOpt() returns.
  257. // It is shut down when the application is done processing the body.
  258. reqDone := make(chan struct{})
  259. done := make(chan struct{})
  260. go func() {
  261. defer close(done)
  262. select {
  263. case <-req.Context().Done():
  264. str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
  265. str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
  266. case <-reqDone:
  267. }
  268. }()
  269. doneChan := reqDone
  270. if opt.DontCloseRequestStream {
  271. doneChan = nil
  272. }
  273. rsp, rerr := c.doRequest(req, conn, str, opt, doneChan)
  274. if rerr.err != nil { // if any error occurred
  275. close(reqDone)
  276. <-done
  277. if rerr.streamErr != 0 { // if it was a stream error
  278. str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
  279. }
  280. if rerr.connErr != 0 { // if it was a connection error
  281. var reason string
  282. if rerr.err != nil {
  283. reason = rerr.err.Error()
  284. }
  285. conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
  286. }
  287. return nil, rerr.err
  288. }
  289. if opt.DontCloseRequestStream {
  290. close(reqDone)
  291. <-done
  292. }
  293. return rsp, rerr.err
  294. }
  295. func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error {
  296. defer body.Close()
  297. b := make([]byte, bodyCopyBufferSize)
  298. for {
  299. n, rerr := body.Read(b)
  300. if n == 0 {
  301. if rerr == nil {
  302. continue
  303. }
  304. if rerr == io.EOF {
  305. break
  306. }
  307. }
  308. if _, err := str.Write(b[:n]); err != nil {
  309. return err
  310. }
  311. if rerr != nil {
  312. if rerr == io.EOF {
  313. break
  314. }
  315. str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
  316. return rerr
  317. }
  318. }
  319. return nil
  320. }
  321. func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
  322. var requestGzip bool
  323. if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
  324. requestGzip = true
  325. }
  326. if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil {
  327. return nil, newStreamError(ErrCodeInternalError, err)
  328. }
  329. if req.Body == nil && !opt.DontCloseRequestStream {
  330. str.Close()
  331. }
  332. hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") })
  333. if req.Body != nil {
  334. // send the request body asynchronously
  335. go func() {
  336. if err := c.sendRequestBody(hstr, req.Body); err != nil {
  337. c.logger.Errorf("Error writing request: %s", err)
  338. }
  339. if !opt.DontCloseRequestStream {
  340. hstr.Close()
  341. }
  342. }()
  343. }
  344. frame, err := parseNextFrame(str, nil)
  345. if err != nil {
  346. return nil, newStreamError(ErrCodeFrameError, err)
  347. }
  348. hf, ok := frame.(*headersFrame)
  349. if !ok {
  350. return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
  351. }
  352. if hf.Length > c.maxHeaderBytes() {
  353. return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
  354. }
  355. headerBlock := make([]byte, hf.Length)
  356. if _, err := io.ReadFull(str, headerBlock); err != nil {
  357. return nil, newStreamError(ErrCodeRequestIncomplete, err)
  358. }
  359. hfs, err := c.decoder.DecodeFull(headerBlock)
  360. if err != nil {
  361. // TODO: use the right error code
  362. return nil, newConnError(ErrCodeGeneralProtocolError, err)
  363. }
  364. connState := qtls.ToTLSConnectionState(conn.ConnectionState().TLS)
  365. res := &http.Response{
  366. Proto: "HTTP/3.0",
  367. ProtoMajor: 3,
  368. Header: http.Header{},
  369. TLS: &connState,
  370. Request: req,
  371. }
  372. for _, hf := range hfs {
  373. switch hf.Name {
  374. case ":status":
  375. status, err := strconv.Atoi(hf.Value)
  376. if err != nil {
  377. return nil, newStreamError(ErrCodeGeneralProtocolError, errors.New("malformed non-numeric status pseudo header"))
  378. }
  379. res.StatusCode = status
  380. res.Status = hf.Value + " " + http.StatusText(status)
  381. default:
  382. res.Header.Add(hf.Name, hf.Value)
  383. }
  384. }
  385. respBody := newResponseBody(hstr, conn, reqDone)
  386. // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
  387. _, hasTransferEncoding := res.Header["Transfer-Encoding"]
  388. isInformational := res.StatusCode >= 100 && res.StatusCode < 200
  389. isNoContent := res.StatusCode == http.StatusNoContent
  390. isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300
  391. if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect {
  392. res.ContentLength = -1
  393. if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 {
  394. if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
  395. res.ContentLength = clen64
  396. }
  397. }
  398. }
  399. if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
  400. res.Header.Del("Content-Encoding")
  401. res.Header.Del("Content-Length")
  402. res.ContentLength = -1
  403. res.Body = newGzipReader(respBody)
  404. res.Uncompressed = true
  405. } else {
  406. res.Body = respBody
  407. }
  408. return res, requestError{}
  409. }
  410. func (c *client) HandshakeComplete() bool {
  411. conn := c.conn.Load()
  412. if conn == nil {
  413. return false
  414. }
  415. select {
  416. case <-(*conn).HandshakeComplete():
  417. return true
  418. default:
  419. return false
  420. }
  421. }