client.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. package h2quic
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "net/http"
  9. "strings"
  10. "sync"
  11. "github.com/Psiphon-Labs/net/http2"
  12. "github.com/Psiphon-Labs/net/http2/hpack"
  13. "github.com/Psiphon-Labs/net/idna"
  14. quic "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go"
  15. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
  16. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
  17. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/qerr"
  18. )
// roundTripperOpts holds per-client request options.
type roundTripperOpts struct {
	// DisableCompression, when true, stops RoundTrip from transparently
	// requesting gzip-encoded responses.
	DisableCompression bool
}
// dialAddr is the function dial uses to establish the QUIC session when the
// client has no custom dialer. It is a variable rather than a direct call,
// presumably so tests can substitute it — TODO confirm.
var dialAddr = quic.DialAddr
// client is a HTTP2 client doing QUIC requests.
//
// A client is bound to a single authority (hostname). It dials one QUIC
// session on the first RoundTrip, opens a header stream on it, and then sends
// each request on its own data stream.
type client struct {
	// mutex guards responses, which RoundTrip writes and the header stream
	// goroutine (readResponse) reads.
	mutex sync.RWMutex
	// tlsConf and config are passed to the dial function.
	tlsConf *tls.Config
	config  *quic.Config
	opts    *roundTripperOpts
	// hostname is the normalized "host:port" authority (via authorityAddr)
	// that every request handled by this client must target.
	hostname string
	// handshakeErr is the result of dial, recorded once through dialOnce and
	// returned by every RoundTrip when non-nil.
	handshakeErr error
	dialOnce     sync.Once
	// dialer, when non-nil, is used instead of the package-level dialAddr.
	dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
	// [Psiphon]
	// Fix close-while-dialing race condition by synchronizing access to
	// client.session and adding a closed flag to indicate if the client was
	// closed while a dial was in progress.
	sessionMutex sync.Mutex
	closed       bool
	session      quic.Session
	// headerStream carries request/response HEADERS frames for all requests.
	headerStream quic.Stream
	// headerErr is set by handleHeaderStream before it closes headerErrored.
	headerErr     *qerr.QuicError
	headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
	requestWriter *requestWriter
	// responses maps each in-flight request's data stream ID to the channel
	// on which readResponse delivers the parsed response headers.
	responses map[protocol.StreamID]chan *http.Response
	logger    utils.Logger
}
  47. var _ http.RoundTripper = &client{}
  48. var defaultQuicConfig = &quic.Config{
  49. RequestConnectionIDOmission: true,
  50. KeepAlive: true,
  51. }
  52. // newClient creates a new client
  53. func newClient(
  54. hostname string,
  55. tlsConfig *tls.Config,
  56. opts *roundTripperOpts,
  57. quicConfig *quic.Config,
  58. dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
  59. ) *client {
  60. config := defaultQuicConfig
  61. if quicConfig != nil {
  62. config = quicConfig
  63. }
  64. return &client{
  65. hostname: authorityAddr("https", hostname),
  66. responses: make(map[protocol.StreamID]chan *http.Response),
  67. tlsConf: tlsConfig,
  68. config: config,
  69. opts: opts,
  70. headerErrored: make(chan struct{}),
  71. dialer: dialer,
  72. logger: utils.DefaultLogger.WithPrefix("client"),
  73. }
  74. }
  75. // dial dials the connection
  76. func (c *client) dial() error {
  77. var err error
  78. // [Psiphon]
  79. var session quic.Session
  80. if c.dialer != nil {
  81. session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
  82. } else {
  83. session, err = dialAddr(c.hostname, c.tlsConf, c.config)
  84. }
  85. if err != nil {
  86. return err
  87. }
  88. // [Psiphon]
  89. // Only this write and the Close reads of c.session require synchronization.
  90. // After this point, it's safe to concurrently read c.session as it is not
  91. // rewritten.
  92. c.sessionMutex.Lock()
  93. closed := c.closed
  94. if !closed {
  95. c.session = session
  96. }
  97. c.sessionMutex.Unlock()
  98. if closed {
  99. session.Close()
  100. return errors.New("closed while dialing")
  101. }
  102. // [Psiphon]
  103. // once the version has been negotiated, open the header stream
  104. c.headerStream, err = c.session.OpenStream()
  105. if err != nil {
  106. return err
  107. }
  108. c.requestWriter = newRequestWriter(c.headerStream, c.logger)
  109. go c.handleHeaderStream()
  110. return nil
  111. }
  112. func (c *client) handleHeaderStream() {
  113. decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
  114. h2framer := http2.NewFramer(nil, c.headerStream)
  115. var err error
  116. for err == nil {
  117. err = c.readResponse(h2framer, decoder)
  118. }
  119. if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
  120. c.logger.Debugf("Error handling header stream: %s", err)
  121. }
  122. c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
  123. // stop all running request
  124. close(c.headerErrored)
  125. }
  126. func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
  127. frame, err := h2framer.ReadFrame()
  128. if err != nil {
  129. return err
  130. }
  131. hframe, ok := frame.(*http2.HeadersFrame)
  132. if !ok {
  133. return errors.New("not a headers frame")
  134. }
  135. mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
  136. mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
  137. if err != nil {
  138. return fmt.Errorf("cannot read header fields: %s", err.Error())
  139. }
  140. c.mutex.RLock()
  141. responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
  142. c.mutex.RUnlock()
  143. if !ok {
  144. return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
  145. }
  146. rsp, err := responseFromHeaders(mhframe)
  147. if err != nil {
  148. return err
  149. }
  150. responseChan <- rsp
  151. return nil
  152. }
  153. // Roundtrip executes a request and returns a response
  154. func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
  155. // TODO: add port to address, if it doesn't have one
  156. if req.URL.Scheme != "https" {
  157. return nil, errors.New("quic http2: unsupported scheme")
  158. }
  159. if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
  160. return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
  161. }
  162. c.dialOnce.Do(func() {
  163. c.handshakeErr = c.dial()
  164. })
  165. if c.handshakeErr != nil {
  166. return nil, c.handshakeErr
  167. }
  168. hasBody := (req.Body != nil)
  169. responseChan := make(chan *http.Response)
  170. dataStream, err := c.session.OpenStreamSync()
  171. if err != nil {
  172. _ = c.closeWithError(err)
  173. return nil, err
  174. }
  175. c.mutex.Lock()
  176. c.responses[dataStream.StreamID()] = responseChan
  177. c.mutex.Unlock()
  178. var requestedGzip bool
  179. if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
  180. requestedGzip = true
  181. }
  182. // TODO: add support for trailers
  183. endStream := !hasBody
  184. err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
  185. if err != nil {
  186. _ = c.closeWithError(err)
  187. return nil, err
  188. }
  189. resc := make(chan error, 1)
  190. if hasBody {
  191. go func() {
  192. resc <- c.writeRequestBody(dataStream, req.Body)
  193. }()
  194. }
  195. var res *http.Response
  196. var receivedResponse bool
  197. var bodySent bool
  198. if !hasBody {
  199. bodySent = true
  200. }
  201. ctx := req.Context()
  202. for !(bodySent && receivedResponse) {
  203. select {
  204. case res = <-responseChan:
  205. receivedResponse = true
  206. c.mutex.Lock()
  207. delete(c.responses, dataStream.StreamID())
  208. c.mutex.Unlock()
  209. case err := <-resc:
  210. bodySent = true
  211. if err != nil {
  212. return nil, err
  213. }
  214. case <-ctx.Done():
  215. // error code 6 signals that stream was canceled
  216. dataStream.CancelRead(6)
  217. dataStream.CancelWrite(6)
  218. c.mutex.Lock()
  219. delete(c.responses, dataStream.StreamID())
  220. c.mutex.Unlock()
  221. return nil, ctx.Err()
  222. case <-c.headerErrored:
  223. // an error occurred on the header stream
  224. _ = c.closeWithError(c.headerErr)
  225. return nil, c.headerErr
  226. }
  227. }
  228. // TODO: correctly set this variable
  229. var streamEnded bool
  230. isHead := (req.Method == "HEAD")
  231. res = setLength(res, isHead, streamEnded)
  232. if streamEnded || isHead {
  233. res.Body = noBody
  234. } else {
  235. res.Body = dataStream
  236. if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
  237. res.Header.Del("Content-Encoding")
  238. res.Header.Del("Content-Length")
  239. res.ContentLength = -1
  240. res.Body = &gzipReader{body: res.Body}
  241. res.Uncompressed = true
  242. }
  243. }
  244. res.Request = req
  245. return res, nil
  246. }
  247. func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
  248. defer func() {
  249. cerr := body.Close()
  250. if err == nil {
  251. // TODO: what to do with dataStream here? Maybe reset it?
  252. err = cerr
  253. }
  254. }()
  255. _, err = io.Copy(dataStream, body)
  256. if err != nil {
  257. // TODO: what to do with dataStream here? Maybe reset it?
  258. return err
  259. }
  260. return dataStream.Close()
  261. }
  262. func (c *client) closeWithError(e error) error {
  263. // [Psiphon]
  264. c.sessionMutex.Lock()
  265. session := c.session
  266. c.closed = true
  267. c.sessionMutex.Unlock()
  268. // [Psiphon]
  269. if session == nil {
  270. return nil
  271. }
  272. return session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
  273. }
  274. // Close closes the client
  275. func (c *client) Close() error {
  276. // [Psiphon]
  277. c.sessionMutex.Lock()
  278. session := c.session
  279. c.closed = true
  280. c.sessionMutex.Unlock()
  281. // [Psiphon]
  282. if session == nil {
  283. return nil
  284. }
  285. return session.Close()
  286. }
  287. // copied from net/transport.go
  288. // authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
  289. // and returns a host:port. The port 443 is added if needed.
  290. func authorityAddr(scheme string, authority string) (addr string) {
  291. host, port, err := net.SplitHostPort(authority)
  292. if err != nil { // authority didn't have a port
  293. port = "443"
  294. if scheme == "http" {
  295. port = "80"
  296. }
  297. host = authority
  298. }
  299. if a, err := idna.ToASCII(host); err == nil {
  300. host = a
  301. }
  302. // IPv6 address literal, without a port:
  303. if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
  304. return host + ":" + port
  305. }
  306. return net.JoinHostPort(host, port)
  307. }