| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281 |
- package http3
- import (
- "bytes"
- "context"
- "crypto/tls"
- "errors"
- "fmt"
- "io"
- "net/http"
- "strconv"
- "sync"
- "github.com/Psiphon-Labs/quic-go"
- "github.com/Psiphon-Labs/quic-go/internal/utils"
- "github.com/marten-seemann/qpack"
- )
// Package-level defaults for the HTTP/3 client.
const (
	// defaultUserAgent is the stock User-Agent string.
	// NOTE(review): it is not referenced in this part of the file; presumably
	// the request writer falls back to it — confirm.
	defaultUserAgent = "quic-go HTTP/3"

	// defaultMaxResponseHeaderBytes is the response header size limit applied
	// when roundTripperOpts.MaxHeaderBytes is not positive (10 MiB).
	defaultMaxResponseHeaderBytes = 10 << 20
)
- var defaultQuicConfig = &quic.Config{
- MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
- KeepAlive: true,
- }
- var dialAddr = quic.DialAddr
// roundTripperOpts holds the per-client options consulted while performing
// requests.
type roundTripperOpts struct {
	// DisableCompression, when true, stops doRequest from requesting a gzip
	// response on the caller's behalf.
	DisableCompression bool
	// MaxHeaderBytes limits the length of the response HEADERS frame.
	// Values <= 0 select defaultMaxResponseHeaderBytes.
	MaxHeaderBytes int64
}
// client is an HTTP/3 client doing requests against a single authority
// (hostname). The QUIC session is dialed once, on the first RoundTrip.
type client struct {
	// tlsConf is the TLS configuration handed to the dialer; newClient sets
	// its NextProtos to the H3 ALPN.
	tlsConf *tls.Config
	// config is the QUIC configuration handed to the dialer.
	config *quic.Config
	// opts carries compression and header-size options.
	opts *roundTripperOpts
	// dialOnce makes sure dial runs only once across all RoundTrip calls.
	dialOnce sync.Once
	// dialer, when non-nil, is used instead of dialAddr to open the session.
	dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
	// handshakeErr stores the result of the one-time dial; a non-nil value is
	// returned by every RoundTrip.
	handshakeErr error
	// requestWriter serializes requests onto a stream.
	requestWriter *requestWriter
	// decoder decodes the QPACK-encoded response header block.
	decoder *qpack.Decoder
	// hostname is the authority address this client serves, as produced by
	// authorityAddr("https", ...).
	hostname string
	// [Psiphon]
	// setSession is held while dial stores session and while Close reads it.
	setSession sync.Mutex
	// session is the QUIC session; it stays nil until dial has run.
	session quic.Session
	logger  utils.Logger
}
- func newClient(
- hostname string,
- tlsConf *tls.Config,
- opts *roundTripperOpts,
- quicConfig *quic.Config,
- dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
- ) *client {
- if tlsConf == nil {
- tlsConf = &tls.Config{}
- } else {
- tlsConf = tlsConf.Clone()
- }
- // Replace existing ALPNs by H3
- tlsConf.NextProtos = []string{nextProtoH3}
- if quicConfig == nil {
- quicConfig = defaultQuicConfig
- }
- // [Psiphon]
- // Prevent race condition the results from concurrent RoundTrippers using defaultQuicConfig
- if quicConfig.MaxIncomingStreams != -1 {
- quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
- }
- logger := utils.DefaultLogger.WithPrefix("h3 client")
- return &client{
- hostname: authorityAddr("https", hostname),
- tlsConf: tlsConf,
- requestWriter: newRequestWriter(logger),
- decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
- config: quicConfig,
- opts: opts,
- dialer: dialer,
- logger: logger,
- }
- }
- func (c *client) dial() error {
- var err error
- var session quic.Session
- if c.dialer != nil {
- session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
- } else {
- session, err = dialAddr(c.hostname, c.tlsConf, c.config)
- }
- // [Psiphon]
- c.setSession.Lock()
- c.session = session
- c.setSession.Unlock()
- if err != nil {
- return err
- }
- go func() {
- if err := c.setupSession(); err != nil {
- c.logger.Debugf("Setting up session failed: %s", err)
- c.session.CloseWithError(quic.ErrorCode(errorInternalError), "")
- }
- }()
- return nil
- }
- func (c *client) setupSession() error {
- // open the control stream
- str, err := c.session.OpenUniStream()
- if err != nil {
- return err
- }
- buf := &bytes.Buffer{}
- // write the type byte
- buf.Write([]byte{0x0})
- // send the SETTINGS frame
- (&settingsFrame{}).Write(buf)
- if _, err := str.Write(buf.Bytes()); err != nil {
- return err
- }
- return nil
- }
- func (c *client) Close() error {
- // [Psiphon]
- // Prevent panic when c.session is nil
- c.setSession.Lock()
- session := c.session
- c.setSession.Unlock()
- if session == nil {
- return nil
- }
- return c.session.Close()
- }
- func (c *client) maxHeaderBytes() uint64 {
- if c.opts.MaxHeaderBytes <= 0 {
- return defaultMaxResponseHeaderBytes
- }
- return uint64(c.opts.MaxHeaderBytes)
- }
- // RoundTrip executes a request and returns a response
- func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
- if req.URL.Scheme != "https" {
- return nil, errors.New("http3: unsupported scheme")
- }
- if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
- return nil, fmt.Errorf("http3 client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
- }
- c.dialOnce.Do(func() {
- c.handshakeErr = c.dial()
- })
- if c.handshakeErr != nil {
- return nil, c.handshakeErr
- }
- str, err := c.session.OpenStreamSync(context.Background())
- if err != nil {
- return nil, err
- }
- // Request Cancellation:
- // This go routine keeps running even after RoundTrip() returns.
- // It is shut down when the application is done processing the body.
- reqDone := make(chan struct{})
- go func() {
- select {
- case <-req.Context().Done():
- str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
- str.CancelRead(quic.ErrorCode(errorRequestCanceled))
- case <-reqDone:
- }
- }()
- rsp, rerr := c.doRequest(req, str, reqDone)
- if rerr.err != nil { // if any error occurred
- close(reqDone)
- if rerr.streamErr != 0 { // if it was a stream error
- str.CancelWrite(quic.ErrorCode(rerr.streamErr))
- }
- if rerr.connErr != 0 { // if it was a connection error
- var reason string
- if rerr.err != nil {
- reason = rerr.err.Error()
- }
- c.session.CloseWithError(quic.ErrorCode(rerr.connErr), reason)
- }
- }
- return rsp, rerr.err
- }
- func (c *client) doRequest(
- req *http.Request,
- str quic.Stream,
- reqDone chan struct{},
- ) (*http.Response, requestError) {
- var requestGzip bool
- if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
- requestGzip = true
- }
- if err := c.requestWriter.WriteRequest(str, req, requestGzip); err != nil {
- return nil, newStreamError(errorInternalError, err)
- }
- frame, err := parseNextFrame(str)
- if err != nil {
- return nil, newStreamError(errorFrameError, err)
- }
- hf, ok := frame.(*headersFrame)
- if !ok {
- return nil, newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
- }
- if hf.Length > c.maxHeaderBytes() {
- return nil, newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
- }
- headerBlock := make([]byte, hf.Length)
- if _, err := io.ReadFull(str, headerBlock); err != nil {
- return nil, newStreamError(errorRequestIncomplete, err)
- }
- hfs, err := c.decoder.DecodeFull(headerBlock)
- if err != nil {
- // TODO: use the right error code
- return nil, newConnError(errorGeneralProtocolError, err)
- }
- res := &http.Response{
- Proto: "HTTP/3",
- ProtoMajor: 3,
- Header: http.Header{},
- }
- for _, hf := range hfs {
- switch hf.Name {
- case ":status":
- status, err := strconv.Atoi(hf.Value)
- if err != nil {
- return nil, newStreamError(errorGeneralProtocolError, errors.New("malformed non-numeric status pseudo header"))
- }
- res.StatusCode = status
- res.Status = hf.Value + " " + http.StatusText(status)
- default:
- res.Header.Add(hf.Name, hf.Value)
- }
- }
- respBody := newResponseBody(str, reqDone, func() {
- c.session.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "")
- })
- if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
- res.Header.Del("Content-Encoding")
- res.Header.Del("Content-Length")
- res.ContentLength = -1
- res.Body = newGzipReader(respBody)
- res.Uncompressed = true
- } else {
- res.Body = respBody
- }
- return res, requestError{}
- }
|