client.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. package http3
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/tls"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "strconv"
  11. "sync"
  12. "github.com/Psiphon-Labs/quic-go"
  13. "github.com/Psiphon-Labs/quic-go/internal/utils"
  14. "github.com/marten-seemann/qpack"
  15. )
  16. const defaultUserAgent = "quic-go HTTP/3"
  17. const defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
  18. var defaultQuicConfig = &quic.Config{
  19. MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
  20. KeepAlive: true,
  21. }
  22. var dialAddr = quic.DialAddr
  23. type roundTripperOpts struct {
  24. DisableCompression bool
  25. MaxHeaderBytes int64
  26. }
  27. // client is a HTTP3 client doing requests
  28. type client struct {
  29. tlsConf *tls.Config
  30. config *quic.Config
  31. opts *roundTripperOpts
  32. dialOnce sync.Once
  33. dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
  34. handshakeErr error
  35. requestWriter *requestWriter
  36. decoder *qpack.Decoder
  37. hostname string
  38. // [Psiphon]
  39. setSession sync.Mutex
  40. session quic.Session
  41. logger utils.Logger
  42. }
  43. func newClient(
  44. hostname string,
  45. tlsConf *tls.Config,
  46. opts *roundTripperOpts,
  47. quicConfig *quic.Config,
  48. dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
  49. ) *client {
  50. if tlsConf == nil {
  51. tlsConf = &tls.Config{}
  52. } else {
  53. tlsConf = tlsConf.Clone()
  54. }
  55. // Replace existing ALPNs by H3
  56. tlsConf.NextProtos = []string{nextProtoH3}
  57. if quicConfig == nil {
  58. quicConfig = defaultQuicConfig
  59. }
  60. // [Psiphon]
  61. // Prevent race condition the results from concurrent RoundTrippers using defaultQuicConfig
  62. if quicConfig.MaxIncomingStreams != -1 {
  63. quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
  64. }
  65. logger := utils.DefaultLogger.WithPrefix("h3 client")
  66. return &client{
  67. hostname: authorityAddr("https", hostname),
  68. tlsConf: tlsConf,
  69. requestWriter: newRequestWriter(logger),
  70. decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
  71. config: quicConfig,
  72. opts: opts,
  73. dialer: dialer,
  74. logger: logger,
  75. }
  76. }
  77. func (c *client) dial() error {
  78. var err error
  79. var session quic.Session
  80. if c.dialer != nil {
  81. session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
  82. } else {
  83. session, err = dialAddr(c.hostname, c.tlsConf, c.config)
  84. }
  85. // [Psiphon]
  86. c.setSession.Lock()
  87. c.session = session
  88. c.setSession.Unlock()
  89. if err != nil {
  90. return err
  91. }
  92. go func() {
  93. if err := c.setupSession(); err != nil {
  94. c.logger.Debugf("Setting up session failed: %s", err)
  95. c.session.CloseWithError(quic.ErrorCode(errorInternalError), "")
  96. }
  97. }()
  98. return nil
  99. }
  100. func (c *client) setupSession() error {
  101. // open the control stream
  102. str, err := c.session.OpenUniStream()
  103. if err != nil {
  104. return err
  105. }
  106. buf := &bytes.Buffer{}
  107. // write the type byte
  108. buf.Write([]byte{0x0})
  109. // send the SETTINGS frame
  110. (&settingsFrame{}).Write(buf)
  111. if _, err := str.Write(buf.Bytes()); err != nil {
  112. return err
  113. }
  114. return nil
  115. }
  116. func (c *client) Close() error {
  117. // [Psiphon]
  118. // Prevent panic when c.session is nil
  119. c.setSession.Lock()
  120. session := c.session
  121. c.setSession.Unlock()
  122. if session == nil {
  123. return nil
  124. }
  125. return c.session.Close()
  126. }
  127. func (c *client) maxHeaderBytes() uint64 {
  128. if c.opts.MaxHeaderBytes <= 0 {
  129. return defaultMaxResponseHeaderBytes
  130. }
  131. return uint64(c.opts.MaxHeaderBytes)
  132. }
  133. // RoundTrip executes a request and returns a response
  134. func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
  135. if req.URL.Scheme != "https" {
  136. return nil, errors.New("http3: unsupported scheme")
  137. }
  138. if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
  139. return nil, fmt.Errorf("http3 client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
  140. }
  141. c.dialOnce.Do(func() {
  142. c.handshakeErr = c.dial()
  143. })
  144. if c.handshakeErr != nil {
  145. return nil, c.handshakeErr
  146. }
  147. str, err := c.session.OpenStreamSync(context.Background())
  148. if err != nil {
  149. return nil, err
  150. }
  151. // Request Cancellation:
  152. // This go routine keeps running even after RoundTrip() returns.
  153. // It is shut down when the application is done processing the body.
  154. reqDone := make(chan struct{})
  155. go func() {
  156. select {
  157. case <-req.Context().Done():
  158. str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
  159. str.CancelRead(quic.ErrorCode(errorRequestCanceled))
  160. case <-reqDone:
  161. }
  162. }()
  163. rsp, rerr := c.doRequest(req, str, reqDone)
  164. if rerr.err != nil { // if any error occurred
  165. close(reqDone)
  166. if rerr.streamErr != 0 { // if it was a stream error
  167. str.CancelWrite(quic.ErrorCode(rerr.streamErr))
  168. }
  169. if rerr.connErr != 0 { // if it was a connection error
  170. var reason string
  171. if rerr.err != nil {
  172. reason = rerr.err.Error()
  173. }
  174. c.session.CloseWithError(quic.ErrorCode(rerr.connErr), reason)
  175. }
  176. }
  177. return rsp, rerr.err
  178. }
  179. func (c *client) doRequest(
  180. req *http.Request,
  181. str quic.Stream,
  182. reqDone chan struct{},
  183. ) (*http.Response, requestError) {
  184. var requestGzip bool
  185. if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
  186. requestGzip = true
  187. }
  188. if err := c.requestWriter.WriteRequest(str, req, requestGzip); err != nil {
  189. return nil, newStreamError(errorInternalError, err)
  190. }
  191. frame, err := parseNextFrame(str)
  192. if err != nil {
  193. return nil, newStreamError(errorFrameError, err)
  194. }
  195. hf, ok := frame.(*headersFrame)
  196. if !ok {
  197. return nil, newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
  198. }
  199. if hf.Length > c.maxHeaderBytes() {
  200. return nil, newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
  201. }
  202. headerBlock := make([]byte, hf.Length)
  203. if _, err := io.ReadFull(str, headerBlock); err != nil {
  204. return nil, newStreamError(errorRequestIncomplete, err)
  205. }
  206. hfs, err := c.decoder.DecodeFull(headerBlock)
  207. if err != nil {
  208. // TODO: use the right error code
  209. return nil, newConnError(errorGeneralProtocolError, err)
  210. }
  211. res := &http.Response{
  212. Proto: "HTTP/3",
  213. ProtoMajor: 3,
  214. Header: http.Header{},
  215. }
  216. for _, hf := range hfs {
  217. switch hf.Name {
  218. case ":status":
  219. status, err := strconv.Atoi(hf.Value)
  220. if err != nil {
  221. return nil, newStreamError(errorGeneralProtocolError, errors.New("malformed non-numeric status pseudo header"))
  222. }
  223. res.StatusCode = status
  224. res.Status = hf.Value + " " + http.StatusText(status)
  225. default:
  226. res.Header.Add(hf.Name, hf.Value)
  227. }
  228. }
  229. respBody := newResponseBody(str, reqDone, func() {
  230. c.session.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "")
  231. })
  232. if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
  233. res.Header.Del("Content-Encoding")
  234. res.Header.Del("Content-Length")
  235. res.ContentLength = -1
  236. res.Body = newGzipReader(respBody)
  237. res.Uncompressed = true
  238. } else {
  239. res.Body = respBody
  240. }
  241. return res, requestError{}
  242. }