server.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. package http3
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/tls"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net"
  10. "net/http"
  11. "runtime"
  12. "sync"
  13. "sync/atomic"
  14. "time"
  15. "github.com/Psiphon-Labs/quic-go"
  16. "github.com/Psiphon-Labs/quic-go/internal/utils"
  17. "github.com/marten-seemann/qpack"
  18. // [Psiphon]
  19. // Remove testing dependency.
  20. //"github.com/onsi/ginkgo"
  21. )
  22. // allows mocking of quic.Listen and quic.ListenAddr
  23. var (
  24. quicListen = quic.Listen
  25. quicListenAddr = quic.ListenAddr
  26. )
  27. const nextProtoH3 = "h3-24"
  28. type requestError struct {
  29. err error
  30. streamErr errorCode
  31. connErr errorCode
  32. }
  33. func newStreamError(code errorCode, err error) requestError {
  34. return requestError{err: err, streamErr: code}
  35. }
  36. func newConnError(code errorCode, err error) requestError {
  37. return requestError{err: err, connErr: code}
  38. }
  39. // Server is a HTTP2 server listening for QUIC connections.
  40. type Server struct {
  41. *http.Server
  42. // By providing a quic.Config, it is possible to set parameters of the QUIC connection.
  43. // If nil, it uses reasonable default values.
  44. QuicConfig *quic.Config
  45. port uint32 // used atomically
  46. mutex sync.Mutex
  47. listeners map[*quic.Listener]struct{}
  48. closed utils.AtomicBool
  49. logger utils.Logger
  50. }
  51. // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
  52. func (s *Server) ListenAndServe() error {
  53. if s.Server == nil {
  54. return errors.New("use of http3.Server without http.Server")
  55. }
  56. return s.serveImpl(s.TLSConfig, nil)
  57. }
  58. // ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
  59. func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
  60. var err error
  61. certs := make([]tls.Certificate, 1)
  62. certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
  63. if err != nil {
  64. return err
  65. }
  66. // We currently only use the cert-related stuff from tls.Config,
  67. // so we don't need to make a full copy.
  68. config := &tls.Config{
  69. Certificates: certs,
  70. }
  71. return s.serveImpl(config, nil)
  72. }
  73. // Serve an existing UDP connection.
  74. // It is possible to reuse the same connection for outgoing connections.
  75. // Closing the server does not close the packet conn.
  76. func (s *Server) Serve(conn net.PacketConn) error {
  77. return s.serveImpl(s.TLSConfig, conn)
  78. }
  79. func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error {
  80. if s.closed.Get() {
  81. return http.ErrServerClosed
  82. }
  83. if s.Server == nil {
  84. return errors.New("use of http3.Server without http.Server")
  85. }
  86. s.logger = utils.DefaultLogger.WithPrefix("server")
  87. if tlsConf == nil {
  88. tlsConf = &tls.Config{}
  89. } else {
  90. tlsConf = tlsConf.Clone()
  91. }
  92. // Replace existing ALPNs by H3
  93. tlsConf.NextProtos = []string{nextProtoH3}
  94. if tlsConf.GetConfigForClient != nil {
  95. getConfigForClient := tlsConf.GetConfigForClient
  96. tlsConf.GetConfigForClient = func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
  97. conf, err := getConfigForClient(ch)
  98. if err != nil || conf == nil {
  99. return conf, err
  100. }
  101. conf = conf.Clone()
  102. conf.NextProtos = []string{nextProtoH3}
  103. return conf, nil
  104. }
  105. }
  106. var ln quic.Listener
  107. var err error
  108. if conn == nil {
  109. ln, err = quicListenAddr(s.Addr, tlsConf, s.QuicConfig)
  110. } else {
  111. ln, err = quicListen(conn, tlsConf, s.QuicConfig)
  112. }
  113. if err != nil {
  114. return err
  115. }
  116. s.addListener(&ln)
  117. defer s.removeListener(&ln)
  118. for {
  119. sess, err := ln.Accept(context.Background())
  120. if err != nil {
  121. return err
  122. }
  123. go s.handleConn(sess)
  124. }
  125. }
  126. // We store a pointer to interface in the map set. This is safe because we only
  127. // call trackListener via Serve and can track+defer untrack the same pointer to
  128. // local variable there. We never need to compare a Listener from another caller.
  129. func (s *Server) addListener(l *quic.Listener) {
  130. s.mutex.Lock()
  131. if s.listeners == nil {
  132. s.listeners = make(map[*quic.Listener]struct{})
  133. }
  134. s.listeners[l] = struct{}{}
  135. s.mutex.Unlock()
  136. }
  137. func (s *Server) removeListener(l *quic.Listener) {
  138. s.mutex.Lock()
  139. delete(s.listeners, l)
  140. s.mutex.Unlock()
  141. }
  142. func (s *Server) handleConn(sess quic.Session) {
  143. // TODO: accept control streams
  144. decoder := qpack.NewDecoder(nil)
  145. // send a SETTINGS frame
  146. str, err := sess.OpenUniStream()
  147. if err != nil {
  148. s.logger.Debugf("Opening the control stream failed.")
  149. return
  150. }
  151. buf := bytes.NewBuffer([]byte{0})
  152. (&settingsFrame{}).Write(buf)
  153. str.Write(buf.Bytes())
  154. for {
  155. str, err := sess.AcceptStream(context.Background())
  156. if err != nil {
  157. s.logger.Debugf("Accepting stream failed: %s", err)
  158. return
  159. }
  160. go func() {
  161. // [Psiphon]
  162. //defer ginkgo.GinkgoRecover()
  163. rerr := s.handleRequest(str, decoder, func() {
  164. sess.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "")
  165. })
  166. if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
  167. s.logger.Debugf("Handling request failed: %s", err)
  168. if rerr.streamErr != 0 {
  169. str.CancelWrite(quic.ErrorCode(rerr.streamErr))
  170. }
  171. if rerr.connErr != 0 {
  172. var reason string
  173. if rerr.err != nil {
  174. reason = rerr.err.Error()
  175. }
  176. sess.CloseWithError(quic.ErrorCode(rerr.connErr), reason)
  177. }
  178. return
  179. }
  180. str.Close()
  181. }()
  182. }
  183. }
  184. func (s *Server) maxHeaderBytes() uint64 {
  185. if s.Server.MaxHeaderBytes <= 0 {
  186. return http.DefaultMaxHeaderBytes
  187. }
  188. return uint64(s.Server.MaxHeaderBytes)
  189. }
  190. func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
  191. frame, err := parseNextFrame(str)
  192. if err != nil {
  193. return newStreamError(errorRequestIncomplete, err)
  194. }
  195. hf, ok := frame.(*headersFrame)
  196. if !ok {
  197. return newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
  198. }
  199. if hf.Length > s.maxHeaderBytes() {
  200. return newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes()))
  201. }
  202. headerBlock := make([]byte, hf.Length)
  203. if _, err := io.ReadFull(str, headerBlock); err != nil {
  204. return newStreamError(errorRequestIncomplete, err)
  205. }
  206. hfs, err := decoder.DecodeFull(headerBlock)
  207. if err != nil {
  208. // TODO: use the right error code
  209. return newConnError(errorGeneralProtocolError, err)
  210. }
  211. req, err := requestFromHeaders(hfs)
  212. if err != nil {
  213. // TODO: use the right error code
  214. return newStreamError(errorGeneralProtocolError, err)
  215. }
  216. req.Body = newRequestBody(str, onFrameError)
  217. if s.logger.Debug() {
  218. s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
  219. } else {
  220. s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
  221. }
  222. req = req.WithContext(str.Context())
  223. responseWriter := newResponseWriter(str, s.logger)
  224. handler := s.Handler
  225. if handler == nil {
  226. handler = http.DefaultServeMux
  227. }
  228. var panicked, readEOF bool
  229. func() {
  230. defer func() {
  231. if p := recover(); p != nil {
  232. // Copied from net/http/server.go
  233. const size = 64 << 10
  234. buf := make([]byte, size)
  235. buf = buf[:runtime.Stack(buf, false)]
  236. s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
  237. panicked = true
  238. }
  239. }()
  240. handler.ServeHTTP(responseWriter, req)
  241. // read the eof
  242. if _, err = str.Read([]byte{0}); err == io.EOF {
  243. readEOF = true
  244. }
  245. }()
  246. if panicked {
  247. responseWriter.WriteHeader(500)
  248. } else {
  249. responseWriter.WriteHeader(200)
  250. }
  251. if !readEOF {
  252. str.CancelRead(quic.ErrorCode(errorEarlyResponse))
  253. }
  254. return requestError{}
  255. }
  256. // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
  257. // Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
  258. func (s *Server) Close() error {
  259. s.closed.Set(true)
  260. s.mutex.Lock()
  261. defer s.mutex.Unlock()
  262. var err error
  263. for ln := range s.listeners {
  264. if cerr := (*ln).Close(); cerr != nil && err == nil {
  265. err = cerr
  266. }
  267. }
  268. return err
  269. }
  270. // CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
  271. // CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
  272. func (s *Server) CloseGracefully(timeout time.Duration) error {
  273. // TODO: implement
  274. return nil
  275. }
  276. // SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
  277. // The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
  278. // Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
  279. func (s *Server) SetQuicHeaders(hdr http.Header) error {
  280. port := atomic.LoadUint32(&s.port)
  281. if port == 0 {
  282. // Extract port from s.Server.Addr
  283. _, portStr, err := net.SplitHostPort(s.Server.Addr)
  284. if err != nil {
  285. return err
  286. }
  287. portInt, err := net.LookupPort("tcp", portStr)
  288. if err != nil {
  289. return err
  290. }
  291. port = uint32(portInt)
  292. atomic.StoreUint32(&s.port, port)
  293. }
  294. hdr.Add("Alt-Svc", fmt.Sprintf(`%s=":%d"; ma=2592000`, nextProtoH3, port))
  295. return nil
  296. }
  297. // ListenAndServeQUIC listens on the UDP network address addr and calls the
  298. // handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is
  299. // used when handler is nil.
  300. func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
  301. server := &Server{
  302. Server: &http.Server{
  303. Addr: addr,
  304. Handler: handler,
  305. },
  306. }
  307. return server.ListenAndServeTLS(certFile, keyFile)
  308. }
  309. // ListenAndServe listens on the given network address for both, TLS and QUIC
  310. // connetions in parallel. It returns if one of the two returns an error.
  311. // http.DefaultServeMux is used when handler is nil.
  312. // The correct Alt-Svc headers for QUIC are set.
  313. func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
  314. // Load certs
  315. var err error
  316. certs := make([]tls.Certificate, 1)
  317. certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
  318. if err != nil {
  319. return err
  320. }
  321. // We currently only use the cert-related stuff from tls.Config,
  322. // so we don't need to make a full copy.
  323. config := &tls.Config{
  324. Certificates: certs,
  325. }
  326. // Open the listeners
  327. udpAddr, err := net.ResolveUDPAddr("udp", addr)
  328. if err != nil {
  329. return err
  330. }
  331. udpConn, err := net.ListenUDP("udp", udpAddr)
  332. if err != nil {
  333. return err
  334. }
  335. defer udpConn.Close()
  336. tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
  337. if err != nil {
  338. return err
  339. }
  340. tcpConn, err := net.ListenTCP("tcp", tcpAddr)
  341. if err != nil {
  342. return err
  343. }
  344. defer tcpConn.Close()
  345. tlsConn := tls.NewListener(tcpConn, config)
  346. defer tlsConn.Close()
  347. // Start the servers
  348. httpServer := &http.Server{
  349. Addr: addr,
  350. TLSConfig: config,
  351. }
  352. quicServer := &Server{
  353. Server: httpServer,
  354. }
  355. if handler == nil {
  356. handler = http.DefaultServeMux
  357. }
  358. httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  359. quicServer.SetQuicHeaders(w.Header())
  360. handler.ServeHTTP(w, r)
  361. })
  362. hErr := make(chan error)
  363. qErr := make(chan error)
  364. go func() {
  365. hErr <- httpServer.Serve(tlsConn)
  366. }()
  367. go func() {
  368. qErr <- quicServer.Serve(udpConn)
  369. }()
  370. select {
  371. case err := <-hErr:
  372. quicServer.Close()
  373. return err
  374. case err := <-qErr:
  375. // Cannot close the HTTP server or wait for requests to complete properly :/
  376. return err
  377. }
  378. }