client.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. package http3
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "net/http"
  9. "strconv"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. tls "github.com/Psiphon-Labs/psiphon-tls"
  14. "github.com/Psiphon-Labs/quic-go"
  15. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  16. "github.com/Psiphon-Labs/quic-go/internal/utils"
  17. "github.com/Psiphon-Labs/quic-go/quicvarint"
  18. "github.com/quic-go/qpack"
  19. )
  20. // MethodGet0RTT allows a GET request to be sent using 0-RTT.
  21. // Note that 0-RTT data doesn't provide replay protection.
  22. const MethodGet0RTT = "GET_0RTT"
  23. const (
  24. defaultUserAgent = "quic-go HTTP/3"
  25. defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
  26. )
  27. var defaultQuicConfig = &quic.Config{
  28. MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
  29. KeepAlivePeriod: 10 * time.Second,
  30. }
  31. type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
  32. var dialAddr dialFunc = quic.DialAddrEarly
  33. type roundTripperOpts struct {
  34. DisableCompression bool
  35. EnableDatagram bool
  36. MaxHeaderBytes int64
  37. AdditionalSettings map[uint64]uint64
  38. StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
  39. UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
  40. }
  41. // client is a HTTP3 client doing requests
  42. type client struct {
  43. tlsConf *tls.Config
  44. config *quic.Config
  45. opts *roundTripperOpts
  46. dialOnce sync.Once
  47. dialer dialFunc
  48. handshakeErr error
  49. requestWriter *requestWriter
  50. decoder *qpack.Decoder
  51. hostname string
  52. conn atomic.Pointer[quic.EarlyConnection]
  53. logger utils.Logger
  54. }
  55. var _ roundTripCloser = &client{}
  56. func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
  57. if conf == nil {
  58. conf = defaultQuicConfig.Clone()
  59. }
  60. if len(conf.Versions) == 0 {
  61. conf = conf.Clone()
  62. conf.Versions = []quic.VersionNumber{protocol.SupportedVersions[0]}
  63. }
  64. if len(conf.Versions) != 1 {
  65. return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
  66. }
  67. if conf.MaxIncomingStreams == 0 {
  68. conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
  69. }
  70. conf.EnableDatagrams = opts.EnableDatagram
  71. logger := utils.DefaultLogger.WithPrefix("h3 client")
  72. if tlsConf == nil {
  73. tlsConf = &tls.Config{}
  74. } else {
  75. tlsConf = tlsConf.Clone()
  76. }
  77. if tlsConf.ServerName == "" {
  78. sni, _, err := net.SplitHostPort(hostname)
  79. if err != nil {
  80. // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
  81. sni = hostname
  82. }
  83. tlsConf.ServerName = sni
  84. }
  85. // Replace existing ALPNs by H3
  86. tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}
  87. return &client{
  88. hostname: authorityAddr("https", hostname),
  89. tlsConf: tlsConf,
  90. requestWriter: newRequestWriter(logger),
  91. decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
  92. config: conf,
  93. opts: opts,
  94. dialer: dialer,
  95. logger: logger,
  96. }, nil
  97. }
  98. func (c *client) dial(ctx context.Context) error {
  99. var err error
  100. var conn quic.EarlyConnection
  101. if c.dialer != nil {
  102. conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
  103. } else {
  104. conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
  105. }
  106. if err != nil {
  107. return err
  108. }
  109. c.conn.Store(&conn)
  110. // send the SETTINGs frame, using 0-RTT data, if possible
  111. go func() {
  112. if err := c.setupConn(conn); err != nil {
  113. c.logger.Debugf("Setting up connection failed: %s", err)
  114. conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
  115. }
  116. }()
  117. if c.opts.StreamHijacker != nil {
  118. go c.handleBidirectionalStreams(conn)
  119. }
  120. go c.handleUnidirectionalStreams(conn)
  121. return nil
  122. }
  123. func (c *client) setupConn(conn quic.EarlyConnection) error {
  124. // open the control stream
  125. str, err := conn.OpenUniStream()
  126. if err != nil {
  127. return err
  128. }
  129. b := make([]byte, 0, 64)
  130. b = quicvarint.Append(b, streamTypeControlStream)
  131. // send the SETTINGS frame
  132. b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b)
  133. _, err = str.Write(b)
  134. return err
  135. }
  136. func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) {
  137. for {
  138. str, err := conn.AcceptStream(context.Background())
  139. if err != nil {
  140. c.logger.Debugf("accepting bidirectional stream failed: %s", err)
  141. return
  142. }
  143. go func(str quic.Stream) {
  144. _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
  145. return c.opts.StreamHijacker(ft, conn, str, e)
  146. })
  147. if err == errHijacked {
  148. return
  149. }
  150. if err != nil {
  151. c.logger.Debugf("error handling stream: %s", err)
  152. }
  153. conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
  154. }(str)
  155. }
  156. }
  157. func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) {
  158. for {
  159. str, err := conn.AcceptUniStream(context.Background())
  160. if err != nil {
  161. c.logger.Debugf("accepting unidirectional stream failed: %s", err)
  162. return
  163. }
  164. go func(str quic.ReceiveStream) {
  165. streamType, err := quicvarint.Read(quicvarint.NewReader(str))
  166. if err != nil {
  167. if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) {
  168. return
  169. }
  170. c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
  171. return
  172. }
  173. // We're only interested in the control stream here.
  174. switch streamType {
  175. case streamTypeControlStream:
  176. case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
  177. // Our QPACK implementation doesn't use the dynamic table yet.
  178. // TODO: check that only one stream of each type is opened.
  179. return
  180. case streamTypePushStream:
  181. // We never increased the Push ID, so we don't expect any push streams.
  182. conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
  183. return
  184. default:
  185. if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
  186. return
  187. }
  188. str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
  189. return
  190. }
  191. f, err := parseNextFrame(str, nil)
  192. if err != nil {
  193. conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
  194. return
  195. }
  196. sf, ok := f.(*settingsFrame)
  197. if !ok {
  198. conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
  199. return
  200. }
  201. if !sf.Datagram {
  202. return
  203. }
  204. // If datagram support was enabled on our side as well as on the server side,
  205. // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
  206. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
  207. if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams {
  208. conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
  209. }
  210. }(str)
  211. }
  212. }
  213. func (c *client) Close() error {
  214. conn := c.conn.Load()
  215. if conn == nil {
  216. return nil
  217. }
  218. return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
  219. }
  220. func (c *client) maxHeaderBytes() uint64 {
  221. if c.opts.MaxHeaderBytes <= 0 {
  222. return defaultMaxResponseHeaderBytes
  223. }
  224. return uint64(c.opts.MaxHeaderBytes)
  225. }
  226. // RoundTripOpt executes a request and returns a response
  227. func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
  228. rsp, err := c.roundTripOpt(req, opt)
  229. if err != nil && req.Context().Err() != nil {
  230. // if the context was canceled, return the context cancellation error
  231. err = req.Context().Err()
  232. }
  233. return rsp, err
  234. }
  235. func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
  236. if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
  237. return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
  238. }
  239. c.dialOnce.Do(func() {
  240. c.handshakeErr = c.dial(req.Context())
  241. })
  242. if c.handshakeErr != nil {
  243. return nil, c.handshakeErr
  244. }
  245. // At this point, c.conn is guaranteed to be set.
  246. conn := *c.conn.Load()
  247. // Immediately send out this request, if this is a 0-RTT request.
  248. if req.Method == MethodGet0RTT {
  249. req.Method = http.MethodGet
  250. } else {
  251. // wait for the handshake to complete
  252. select {
  253. case <-conn.HandshakeComplete():
  254. case <-req.Context().Done():
  255. return nil, req.Context().Err()
  256. }
  257. }
  258. str, err := conn.OpenStreamSync(req.Context())
  259. if err != nil {
  260. return nil, err
  261. }
  262. // Request Cancellation:
  263. // This go routine keeps running even after RoundTripOpt() returns.
  264. // It is shut down when the application is done processing the body.
  265. reqDone := make(chan struct{})
  266. done := make(chan struct{})
  267. go func() {
  268. defer close(done)
  269. select {
  270. case <-req.Context().Done():
  271. str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
  272. str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
  273. case <-reqDone:
  274. }
  275. }()
  276. doneChan := reqDone
  277. if opt.DontCloseRequestStream {
  278. doneChan = nil
  279. }
  280. rsp, rerr := c.doRequest(req, conn, str, opt, doneChan)
  281. if rerr.err != nil { // if any error occurred
  282. close(reqDone)
  283. <-done
  284. if rerr.streamErr != 0 { // if it was a stream error
  285. str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
  286. }
  287. if rerr.connErr != 0 { // if it was a connection error
  288. var reason string
  289. if rerr.err != nil {
  290. reason = rerr.err.Error()
  291. }
  292. conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
  293. }
  294. return nil, maybeReplaceError(rerr.err)
  295. }
  296. if opt.DontCloseRequestStream {
  297. close(reqDone)
  298. <-done
  299. }
  300. return rsp, maybeReplaceError(rerr.err)
  301. }
  302. // cancelingReader reads from the io.Reader.
  303. // It cancels writing on the stream if any error other than io.EOF occurs.
  304. type cancelingReader struct {
  305. r io.Reader
  306. str Stream
  307. }
  308. func (r *cancelingReader) Read(b []byte) (int, error) {
  309. n, err := r.r.Read(b)
  310. if err != nil && err != io.EOF {
  311. r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
  312. }
  313. return n, err
  314. }
  315. func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
  316. defer body.Close()
  317. buf := make([]byte, bodyCopyBufferSize)
  318. sr := &cancelingReader{str: str, r: body}
  319. if contentLength == -1 {
  320. _, err := io.CopyBuffer(str, sr, buf)
  321. return err
  322. }
  323. // make sure we don't send more bytes than the content length
  324. n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
  325. if err != nil {
  326. return err
  327. }
  328. var extra int64
  329. extra, err = io.CopyBuffer(io.Discard, sr, buf)
  330. n += extra
  331. if n > contentLength {
  332. str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
  333. return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
  334. }
  335. return err
  336. }
  337. func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
  338. var requestGzip bool
  339. if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
  340. requestGzip = true
  341. }
  342. if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil {
  343. return nil, newStreamError(ErrCodeInternalError, err)
  344. }
  345. if req.Body == nil && !opt.DontCloseRequestStream {
  346. str.Close()
  347. }
  348. hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") })
  349. if req.Body != nil {
  350. // send the request body asynchronously
  351. go func() {
  352. contentLength := int64(-1)
  353. // According to the documentation for http.Request.ContentLength,
  354. // a value of 0 with a non-nil Body is also treated as unknown content length.
  355. if req.ContentLength > 0 {
  356. contentLength = req.ContentLength
  357. }
  358. if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil {
  359. c.logger.Errorf("Error writing request: %s", err)
  360. }
  361. if !opt.DontCloseRequestStream {
  362. hstr.Close()
  363. }
  364. }()
  365. }
  366. frame, err := parseNextFrame(str, nil)
  367. if err != nil {
  368. return nil, newStreamError(ErrCodeFrameError, err)
  369. }
  370. hf, ok := frame.(*headersFrame)
  371. if !ok {
  372. return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
  373. }
  374. if hf.Length > c.maxHeaderBytes() {
  375. return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
  376. }
  377. headerBlock := make([]byte, hf.Length)
  378. if _, err := io.ReadFull(str, headerBlock); err != nil {
  379. return nil, newStreamError(ErrCodeRequestIncomplete, err)
  380. }
  381. hfs, err := c.decoder.DecodeFull(headerBlock)
  382. if err != nil {
  383. // TODO: use the right error code
  384. return nil, newConnError(ErrCodeGeneralProtocolError, err)
  385. }
  386. res, err := responseFromHeaders(hfs)
  387. if err != nil {
  388. return nil, newStreamError(ErrCodeMessageError, err)
  389. }
  390. connState := conn.ConnectionState().TLS
  391. // [Psiphon]
  392. res.TLS = tls.UnsafeFromConnectionState(&connState)
  393. res.Request = req
  394. // Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
  395. // See section 4.1.2 of RFC 9114.
  396. var httpStr Stream
  397. if _, ok := res.Header["Content-Length"]; ok && res.ContentLength >= 0 {
  398. httpStr = newLengthLimitedStream(hstr, res.ContentLength)
  399. } else {
  400. httpStr = hstr
  401. }
  402. respBody := newResponseBody(httpStr, conn, reqDone)
  403. // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
  404. _, hasTransferEncoding := res.Header["Transfer-Encoding"]
  405. isInformational := res.StatusCode >= 100 && res.StatusCode < 200
  406. isNoContent := res.StatusCode == http.StatusNoContent
  407. isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300
  408. if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect {
  409. res.ContentLength = -1
  410. if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 {
  411. if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
  412. res.ContentLength = clen64
  413. }
  414. }
  415. }
  416. if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
  417. res.Header.Del("Content-Encoding")
  418. res.Header.Del("Content-Length")
  419. res.ContentLength = -1
  420. res.Body = newGzipReader(respBody)
  421. res.Uncompressed = true
  422. } else {
  423. res.Body = respBody
  424. }
  425. return res, requestError{}
  426. }
  427. func (c *client) HandshakeComplete() bool {
  428. conn := c.conn.Load()
  429. if conn == nil {
  430. return false
  431. }
  432. select {
  433. case <-(*conn).HandshakeComplete():
  434. return true
  435. default:
  436. return false
  437. }
  438. }