roundtrip.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package http3
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "sync"
  10. quic "github.com/Psiphon-Labs/quic-go"
  11. "golang.org/x/net/http/httpguts"
  12. )
  13. type roundTripCloser interface {
  14. http.RoundTripper
  15. io.Closer
  16. }
  17. // RoundTripper implements the http.RoundTripper interface
  18. type RoundTripper struct {
  19. mutex sync.Mutex
  20. // DisableCompression, if true, prevents the Transport from
  21. // requesting compression with an "Accept-Encoding: gzip"
  22. // request header when the Request contains no existing
  23. // Accept-Encoding value. If the Transport requests gzip on
  24. // its own and gets a gzipped response, it's transparently
  25. // decoded in the Response.Body. However, if the user
  26. // explicitly requested gzip it is not automatically
  27. // uncompressed.
  28. DisableCompression bool
  29. // TLSClientConfig specifies the TLS configuration to use with
  30. // tls.Client. If nil, the default configuration is used.
  31. TLSClientConfig *tls.Config
  32. // QuicConfig is the quic.Config used for dialing new connections.
  33. // If nil, reasonable default values will be used.
  34. QuicConfig *quic.Config
  35. // Enable support for HTTP/3 datagrams.
  36. // If set to true, QuicConfig.EnableDatagram will be set.
  37. // See https://www.ietf.org/archive/id/draft-schinazi-masque-h3-datagram-02.html.
  38. EnableDatagrams bool
  39. // Dial specifies an optional dial function for creating QUIC
  40. // connections for requests.
  41. // If Dial is nil, quic.DialAddrEarly will be used.
  42. Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
  43. // MaxResponseHeaderBytes specifies a limit on how many response bytes are
  44. // allowed in the server's response header.
  45. // Zero means to use a default limit.
  46. MaxResponseHeaderBytes int64
  47. clients map[string]roundTripCloser
  48. }
  // RoundTripOpt holds the per-request options accepted by
// RoundTripper.RoundTripOpt. The zero value selects the defaults used by
// RoundTripper.RoundTrip.
type RoundTripOpt struct {
	// OnlyCachedConn controls whether the RoundTripper may create a new QUIC
	// connection. If true and no cached connection is available,
	// RoundTripper.RoundTripOpt returns ErrNoCachedConn.
	OnlyCachedConn bool
	// SkipSchemeCheck, if true, disables the check that the request scheme
	// is "https". This allows the use of other schemes, e.g.
	// masque://target.example.com:443/.
	SkipSchemeCheck bool
}
  58. var _ roundTripCloser = &RoundTripper{}
  // ErrNoCachedConn is returned by RoundTripper.RoundTripOpt when
// RoundTripOpt.OnlyCachedConn is set and no cached connection for the
// request's authority is available.
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
  // RoundTripOpt is like RoundTrip, but takes options.
//
// Before any connection is looked up or created, the request must have a
// non-nil URL with a non-empty Host, a non-nil Header, the "https" scheme
// (unless opt.SkipSchemeCheck is set), valid header field names and values,
// and a valid method token (an empty Method is accepted).
//
// As the http.RoundTripper contract requires, the request body is closed on
// every error this method produces itself, with one deliberate exception:
// when opt.OnlyCachedConn is set and no cached connection exists,
// ErrNoCachedConn is returned with the body untouched, so a caller can fall
// back to another transport with the same request.
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
	if err := checkRoundTripRequest(req, opt); err != nil {
		closeRequestBody(req)
		return nil, err
	}

	// The client cache key is always derived with the "https" scheme,
	// regardless of req.URL.Scheme.
	hostname := authorityAddr("https", hostnameFromRequest(req))
	cl, err := r.getClient(hostname, opt.OnlyCachedConn)
	if err != nil {
		if !errors.Is(err, ErrNoCachedConn) {
			closeRequestBody(req)
		}
		return nil, err
	}
	return cl.RoundTrip(req)
}

// checkRoundTripRequest returns the first reason req cannot be sent by
// RoundTripOpt, or nil if it is acceptable. It never touches req.Body;
// closing it on failure is the caller's job.
func checkRoundTripRequest(req *http.Request, opt RoundTripOpt) error {
	if req.URL == nil {
		return errors.New("http3: nil Request.URL")
	}
	if req.URL.Host == "" {
		return errors.New("http3: no Host in request URL")
	}
	if req.Header == nil {
		return errors.New("http3: nil Request.Header")
	}
	if req.URL.Scheme != "https" && !opt.SkipSchemeCheck {
		return fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
	}
	// Header fields are checked for every scheme: SkipSchemeCheck relaxes
	// only the scheme and must not let malformed fields onto the wire.
	for k, vv := range req.Header {
		if !httpguts.ValidHeaderFieldName(k) {
			return fmt.Errorf("http3: invalid http header field name %q", k)
		}
		for _, v := range vv {
			if !httpguts.ValidHeaderFieldValue(v) {
				return fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
			}
		}
	}
	if req.Method != "" && !validMethod(req.Method) {
		return fmt.Errorf("http3: invalid method %q", req.Method)
	}
	return nil
}
  101. // RoundTrip does a round trip.
  102. func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  103. return r.RoundTripOpt(req, RoundTripOpt{})
  104. }
  105. func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
  106. r.mutex.Lock()
  107. defer r.mutex.Unlock()
  108. if r.clients == nil {
  109. r.clients = make(map[string]roundTripCloser)
  110. }
  111. client, ok := r.clients[hostname]
  112. if !ok {
  113. if onlyCached {
  114. return nil, ErrNoCachedConn
  115. }
  116. var err error
  117. client, err = newClient(
  118. hostname,
  119. r.TLSClientConfig,
  120. &roundTripperOpts{
  121. EnableDatagram: r.EnableDatagrams,
  122. DisableCompression: r.DisableCompression,
  123. MaxHeaderBytes: r.MaxResponseHeaderBytes,
  124. },
  125. r.QuicConfig,
  126. r.Dial,
  127. )
  128. if err != nil {
  129. return nil, err
  130. }
  131. r.clients[hostname] = client
  132. }
  133. return client, nil
  134. }
  135. // Close closes the QUIC connections that this RoundTripper has used
  136. func (r *RoundTripper) Close() error {
  137. r.mutex.Lock()
  138. defer r.mutex.Unlock()
  139. for _, client := range r.clients {
  140. if err := client.Close(); err != nil {
  141. return err
  142. }
  143. }
  144. r.clients = nil
  145. return nil
  146. }
  // closeRequestBody closes req.Body if the request has one. In this file it
// is called only on paths that are already returning an error, so the
// error from Close is intentionally discarded.
func closeRequestBody(req *http.Request) {
	body := req.Body
	if body == nil {
		return
	}
	_ = body.Close()
}
  // validMethod reports whether method is a syntactically valid HTTP method:
// a non-empty token, as in the grammar below.
//
//	Method           = "OPTIONS"   ; Section 9.2
//	                 | "GET"       ; Section 9.3
//	                 | "HEAD"      ; Section 9.4
//	                 | "POST"      ; Section 9.5
//	                 | "PUT"       ; Section 9.6
//	                 | "DELETE"    ; Section 9.7
//	                 | "TRACE"     ; Section 9.8
//	                 | "CONNECT"   ; Section 9.9
//	                 | extension-method
//	extension-method = token
//	token            = 1*<any CHAR except CTLs or separators>
func validMethod(method string) bool {
	if method == "" {
		return false
	}
	// A single non-token rune anywhere disqualifies the method.
	return strings.IndexFunc(method, isNotToken) < 0
}
  168. // copied from net/http/http.go
  169. func isNotToken(r rune) bool {
  170. return !httpguts.IsTokenRune(r)
  171. }