client.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. package http3
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "net/http"
  9. "net/http/httptrace"
  10. "net/textproto"
  11. "time"
  12. "github.com/Psiphon-Labs/quic-go"
  13. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  14. "github.com/Psiphon-Labs/quic-go/quicvarint"
  15. "github.com/quic-go/qpack"
  16. tls "github.com/Psiphon-Labs/psiphon-tls"
  17. )
  // Pseudo request methods that opt into sending the request as 0-RTT data.
  // 0-RTT offers no replay protection, so use them for idempotent requests only.
  const (
  	// MethodGet0RTT allows a GET request to be sent using 0-RTT.
  	// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests.
  	MethodGet0RTT = "GET_0RTT"

  	// MethodHead0RTT allows a HEAD request to be sent using 0-RTT.
  	// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests.
  	MethodHead0RTT = "HEAD_0RTT"
  )

  const (
  	// defaultUserAgent is the package's default User-Agent string
  	// (its use site is outside this view — presumably applied when a request sets none).
  	defaultUserAgent = "quic-go HTTP/3"

  	// defaultMaxResponseHeaderBytes caps the size of a response header when no
  	// positive limit is configured: 10 MiB.
  	defaultMaxResponseHeaderBytes = 10 << 20
  )
  30. var defaultQuicConfig = &quic.Config{
  31. MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
  32. KeepAlivePeriod: 10 * time.Second,
  33. }
  34. // ClientConn is an HTTP/3 client doing requests to a single remote server.
  35. type ClientConn struct {
  36. connection
  37. // Enable support for HTTP/3 datagrams (RFC 9297).
  38. // If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting enableDatagrams.
  39. enableDatagrams bool
  40. // Additional HTTP/3 settings.
  41. // It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
  42. additionalSettings map[uint64]uint64
  43. // maxResponseHeaderBytes specifies a limit on how many response bytes are
  44. // allowed in the server's response header.
  45. maxResponseHeaderBytes uint64
  46. // disableCompression, if true, prevents the Transport from requesting compression with an
  47. // "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
  48. // If the Transport requests gzip on its own and gets a gzipped response, it's transparently
  49. // decoded in the Response.Body.
  50. // However, if the user explicitly requested gzip it is not automatically uncompressed.
  51. disableCompression bool
  52. logger *slog.Logger
  53. requestWriter *requestWriter
  54. decoder *qpack.Decoder
  55. }
  56. var _ http.RoundTripper = &ClientConn{}
  57. // Deprecated: SingleDestinationRoundTripper was renamed to ClientConn.
  58. // It can be obtained by calling NewClientConn on a Transport.
  59. type SingleDestinationRoundTripper = ClientConn
  60. func newClientConn(
  61. conn quic.Connection,
  62. enableDatagrams bool,
  63. additionalSettings map[uint64]uint64,
  64. streamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error),
  65. uniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool),
  66. maxResponseHeaderBytes int64,
  67. disableCompression bool,
  68. logger *slog.Logger,
  69. ) *ClientConn {
  70. c := &ClientConn{
  71. enableDatagrams: enableDatagrams,
  72. additionalSettings: additionalSettings,
  73. disableCompression: disableCompression,
  74. logger: logger,
  75. }
  76. if maxResponseHeaderBytes <= 0 {
  77. c.maxResponseHeaderBytes = defaultMaxResponseHeaderBytes
  78. } else {
  79. c.maxResponseHeaderBytes = uint64(maxResponseHeaderBytes)
  80. }
  81. c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {})
  82. c.requestWriter = newRequestWriter()
  83. c.connection = *newConnection(
  84. conn.Context(),
  85. conn,
  86. c.enableDatagrams,
  87. protocol.PerspectiveClient,
  88. c.logger,
  89. 0,
  90. )
  91. // send the SETTINGs frame, using 0-RTT data, if possible
  92. go func() {
  93. if err := c.setupConn(); err != nil {
  94. if c.logger != nil {
  95. c.logger.Debug("Setting up connection failed", "error", err)
  96. }
  97. c.connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
  98. }
  99. }()
  100. if streamHijacker != nil {
  101. go c.handleBidirectionalStreams(streamHijacker)
  102. }
  103. go c.connection.handleUnidirectionalStreams(uniStreamHijacker)
  104. return c
  105. }
  106. // OpenRequestStream opens a new request stream on the HTTP/3 connection.
  107. func (c *ClientConn) OpenRequestStream(ctx context.Context) (RequestStream, error) {
  108. return c.connection.openRequestStream(ctx, c.requestWriter, nil, c.disableCompression, c.maxResponseHeaderBytes)
  109. }
  110. func (c *ClientConn) setupConn() error {
  111. // open the control stream
  112. str, err := c.connection.OpenUniStream()
  113. if err != nil {
  114. return err
  115. }
  116. b := make([]byte, 0, 64)
  117. b = quicvarint.Append(b, streamTypeControlStream)
  118. // send the SETTINGS frame
  119. b = (&settingsFrame{Datagram: c.enableDatagrams, Other: c.additionalSettings}).Append(b)
  120. _, err = str.Write(b)
  121. return err
  122. }
  123. func (c *ClientConn) handleBidirectionalStreams(streamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)) {
  124. for {
  125. str, err := c.connection.AcceptStream(context.Background())
  126. if err != nil {
  127. if c.logger != nil {
  128. c.logger.Debug("accepting bidirectional stream failed", "error", err)
  129. }
  130. return
  131. }
  132. fp := &frameParser{
  133. r: str,
  134. conn: &c.connection,
  135. unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) {
  136. id := c.connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
  137. return streamHijacker(ft, id, str, e)
  138. },
  139. }
  140. go func() {
  141. if _, err := fp.ParseNext(); err == errHijacked {
  142. return
  143. }
  144. if err != nil {
  145. if c.logger != nil {
  146. c.logger.Debug("error handling stream", "error", err)
  147. }
  148. }
  149. c.connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
  150. }()
  151. }
  152. }
  153. // RoundTrip executes a request and returns a response
  154. func (c *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
  155. rsp, err := c.roundTrip(req)
  156. if err != nil && req.Context().Err() != nil {
  157. // if the context was canceled, return the context cancellation error
  158. err = req.Context().Err()
  159. }
  160. return rsp, err
  161. }
  162. func (c *ClientConn) roundTrip(req *http.Request) (*http.Response, error) {
  163. // Immediately send out this request, if this is a 0-RTT request.
  164. switch req.Method {
  165. case MethodGet0RTT:
  166. // don't modify the original request
  167. reqCopy := *req
  168. req = &reqCopy
  169. req.Method = http.MethodGet
  170. case MethodHead0RTT:
  171. // don't modify the original request
  172. reqCopy := *req
  173. req = &reqCopy
  174. req.Method = http.MethodHead
  175. default:
  176. // wait for the handshake to complete
  177. earlyConn, ok := c.Connection.(quic.EarlyConnection)
  178. if ok {
  179. select {
  180. case <-earlyConn.HandshakeComplete():
  181. case <-req.Context().Done():
  182. return nil, req.Context().Err()
  183. }
  184. }
  185. }
  186. // It is only possible to send an Extended CONNECT request once the SETTINGS were received.
  187. // See section 3 of RFC 8441.
  188. if isExtendedConnectRequest(req) {
  189. connCtx := c.Connection.Context()
  190. // wait for the server's SETTINGS frame to arrive
  191. select {
  192. case <-c.connection.ReceivedSettings():
  193. case <-connCtx.Done():
  194. return nil, context.Cause(connCtx)
  195. }
  196. if !c.connection.Settings().EnableExtendedConnect {
  197. return nil, errors.New("http3: server didn't enable Extended CONNECT")
  198. }
  199. }
  200. reqDone := make(chan struct{})
  201. str, err := c.connection.openRequestStream(
  202. req.Context(),
  203. c.requestWriter,
  204. reqDone,
  205. c.disableCompression,
  206. c.maxResponseHeaderBytes,
  207. )
  208. if err != nil {
  209. return nil, err
  210. }
  211. // Request Cancellation:
  212. // This go routine keeps running even after RoundTripOpt() returns.
  213. // It is shut down when the application is done processing the body.
  214. done := make(chan struct{})
  215. go func() {
  216. defer close(done)
  217. select {
  218. case <-req.Context().Done():
  219. str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
  220. str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
  221. case <-reqDone:
  222. }
  223. }()
  224. rsp, err := c.doRequest(req, str)
  225. if err != nil { // if any error occurred
  226. close(reqDone)
  227. <-done
  228. return nil, maybeReplaceError(err)
  229. }
  230. return rsp, maybeReplaceError(err)
  231. }
  232. // cancelingReader reads from the io.Reader.
  233. // It cancels writing on the stream if any error other than io.EOF occurs.
  234. type cancelingReader struct {
  235. r io.Reader
  236. str Stream
  237. }
  238. func (r *cancelingReader) Read(b []byte) (int, error) {
  239. n, err := r.r.Read(b)
  240. if err != nil && err != io.EOF {
  241. r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
  242. }
  243. return n, err
  244. }
  245. func (c *ClientConn) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
  246. defer body.Close()
  247. buf := make([]byte, bodyCopyBufferSize)
  248. sr := &cancelingReader{str: str, r: body}
  249. if contentLength == -1 {
  250. _, err := io.CopyBuffer(str, sr, buf)
  251. return err
  252. }
  253. // make sure we don't send more bytes than the content length
  254. n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
  255. if err != nil {
  256. return err
  257. }
  258. var extra int64
  259. extra, err = io.CopyBuffer(io.Discard, sr, buf)
  260. n += extra
  261. if n > contentLength {
  262. str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
  263. return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
  264. }
  265. return err
  266. }
  267. func (c *ClientConn) doRequest(req *http.Request, str *requestStream) (*http.Response, error) {
  268. trace := httptrace.ContextClientTrace(req.Context())
  269. if err := str.SendRequestHeader(req); err != nil {
  270. traceWroteRequest(trace, err)
  271. return nil, err
  272. }
  273. if req.Body == nil {
  274. traceWroteRequest(trace, nil)
  275. str.Close()
  276. } else {
  277. // send the request body asynchronously
  278. go func() {
  279. contentLength := int64(-1)
  280. // According to the documentation for http.Request.ContentLength,
  281. // a value of 0 with a non-nil Body is also treated as unknown content length.
  282. if req.ContentLength > 0 {
  283. contentLength = req.ContentLength
  284. }
  285. err := c.sendRequestBody(str, req.Body, contentLength)
  286. traceWroteRequest(trace, err)
  287. if err != nil {
  288. if c.logger != nil {
  289. c.logger.Debug("error writing request", "error", err)
  290. }
  291. }
  292. str.Close()
  293. }()
  294. }
  295. // copy from net/http: support 1xx responses
  296. num1xx := 0 // number of informational 1xx headers received
  297. const max1xxResponses = 5 // arbitrary bound on number of informational responses
  298. var res *http.Response
  299. for {
  300. var err error
  301. res, err = str.ReadResponse()
  302. if err != nil {
  303. return nil, err
  304. }
  305. resCode := res.StatusCode
  306. is1xx := 100 <= resCode && resCode <= 199
  307. // treat 101 as a terminal status, see https://github.com/golang/go/issues/26161
  308. is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
  309. if is1xxNonTerminal {
  310. num1xx++
  311. if num1xx > max1xxResponses {
  312. return nil, errors.New("http: too many 1xx informational responses")
  313. }
  314. traceGot1xxResponse(trace, resCode, textproto.MIMEHeader(res.Header))
  315. if resCode == 100 {
  316. traceGot100Continue(trace)
  317. }
  318. continue
  319. }
  320. break
  321. }
  322. connState := c.connection.ConnectionState().TLS
  323. // [Psiphon]
  324. res.TLS = tls.UnsafeFromConnectionState(&connState)
  325. res.Request = req
  326. return res, nil
  327. }