| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427 |
- package http3
- import (
- "bytes"
- "context"
- "crypto/tls"
- "errors"
- "fmt"
- "io"
- "net"
- "net/http"
- "runtime"
- "sync"
- "sync/atomic"
- "time"
- "github.com/Psiphon-Labs/quic-go"
- "github.com/Psiphon-Labs/quic-go/internal/utils"
- "github.com/marten-seemann/qpack"
- // [Psiphon]
- // Remove testing dependency.
- //"github.com/onsi/ginkgo"
- )
- // allows mocking of quic.Listen and quic.ListenAddr
- var (
- quicListen = quic.Listen
- quicListenAddr = quic.ListenAddr
- )
// nextProtoH3 is the ALPN protocol identifier this package forces on every
// TLS configuration it serves with; it is also the protocol name advertised
// in the Alt-Svc header written by SetQuicHeaders.
const (
	nextProtoH3 = "h3-24"
)
- type requestError struct {
- err error
- streamErr errorCode
- connErr errorCode
- }
- func newStreamError(code errorCode, err error) requestError {
- return requestError{err: err, streamErr: code}
- }
- func newConnError(code errorCode, err error) requestError {
- return requestError{err: err, connErr: code}
- }
- // Server is a HTTP2 server listening for QUIC connections.
- type Server struct {
- *http.Server
- // By providing a quic.Config, it is possible to set parameters of the QUIC connection.
- // If nil, it uses reasonable default values.
- QuicConfig *quic.Config
- port uint32 // used atomically
- mutex sync.Mutex
- listeners map[*quic.Listener]struct{}
- closed utils.AtomicBool
- logger utils.Logger
- }
- // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
- func (s *Server) ListenAndServe() error {
- if s.Server == nil {
- return errors.New("use of http3.Server without http.Server")
- }
- return s.serveImpl(s.TLSConfig, nil)
- }
- // ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
- func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
- var err error
- certs := make([]tls.Certificate, 1)
- certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
- if err != nil {
- return err
- }
- // We currently only use the cert-related stuff from tls.Config,
- // so we don't need to make a full copy.
- config := &tls.Config{
- Certificates: certs,
- }
- return s.serveImpl(config, nil)
- }
- // Serve an existing UDP connection.
- // It is possible to reuse the same connection for outgoing connections.
- // Closing the server does not close the packet conn.
- func (s *Server) Serve(conn net.PacketConn) error {
- return s.serveImpl(s.TLSConfig, conn)
- }
- func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error {
- if s.closed.Get() {
- return http.ErrServerClosed
- }
- if s.Server == nil {
- return errors.New("use of http3.Server without http.Server")
- }
- s.logger = utils.DefaultLogger.WithPrefix("server")
- if tlsConf == nil {
- tlsConf = &tls.Config{}
- } else {
- tlsConf = tlsConf.Clone()
- }
- // Replace existing ALPNs by H3
- tlsConf.NextProtos = []string{nextProtoH3}
- if tlsConf.GetConfigForClient != nil {
- getConfigForClient := tlsConf.GetConfigForClient
- tlsConf.GetConfigForClient = func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
- conf, err := getConfigForClient(ch)
- if err != nil || conf == nil {
- return conf, err
- }
- conf = conf.Clone()
- conf.NextProtos = []string{nextProtoH3}
- return conf, nil
- }
- }
- var ln quic.Listener
- var err error
- if conn == nil {
- ln, err = quicListenAddr(s.Addr, tlsConf, s.QuicConfig)
- } else {
- ln, err = quicListen(conn, tlsConf, s.QuicConfig)
- }
- if err != nil {
- return err
- }
- s.addListener(&ln)
- defer s.removeListener(&ln)
- for {
- sess, err := ln.Accept(context.Background())
- if err != nil {
- return err
- }
- go s.handleConn(sess)
- }
- }
- // We store a pointer to interface in the map set. This is safe because we only
- // call trackListener via Serve and can track+defer untrack the same pointer to
- // local variable there. We never need to compare a Listener from another caller.
- func (s *Server) addListener(l *quic.Listener) {
- s.mutex.Lock()
- if s.listeners == nil {
- s.listeners = make(map[*quic.Listener]struct{})
- }
- s.listeners[l] = struct{}{}
- s.mutex.Unlock()
- }
- func (s *Server) removeListener(l *quic.Listener) {
- s.mutex.Lock()
- delete(s.listeners, l)
- s.mutex.Unlock()
- }
- func (s *Server) handleConn(sess quic.Session) {
- // TODO: accept control streams
- decoder := qpack.NewDecoder(nil)
- // send a SETTINGS frame
- str, err := sess.OpenUniStream()
- if err != nil {
- s.logger.Debugf("Opening the control stream failed.")
- return
- }
- buf := bytes.NewBuffer([]byte{0})
- (&settingsFrame{}).Write(buf)
- str.Write(buf.Bytes())
- for {
- str, err := sess.AcceptStream(context.Background())
- if err != nil {
- s.logger.Debugf("Accepting stream failed: %s", err)
- return
- }
- go func() {
- // [Psiphon]
- //defer ginkgo.GinkgoRecover()
- rerr := s.handleRequest(str, decoder, func() {
- sess.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "")
- })
- if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
- s.logger.Debugf("Handling request failed: %s", err)
- if rerr.streamErr != 0 {
- str.CancelWrite(quic.ErrorCode(rerr.streamErr))
- }
- if rerr.connErr != 0 {
- var reason string
- if rerr.err != nil {
- reason = rerr.err.Error()
- }
- sess.CloseWithError(quic.ErrorCode(rerr.connErr), reason)
- }
- return
- }
- str.Close()
- }()
- }
- }
- func (s *Server) maxHeaderBytes() uint64 {
- if s.Server.MaxHeaderBytes <= 0 {
- return http.DefaultMaxHeaderBytes
- }
- return uint64(s.Server.MaxHeaderBytes)
- }
- func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
- frame, err := parseNextFrame(str)
- if err != nil {
- return newStreamError(errorRequestIncomplete, err)
- }
- hf, ok := frame.(*headersFrame)
- if !ok {
- return newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
- }
- if hf.Length > s.maxHeaderBytes() {
- return newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes()))
- }
- headerBlock := make([]byte, hf.Length)
- if _, err := io.ReadFull(str, headerBlock); err != nil {
- return newStreamError(errorRequestIncomplete, err)
- }
- hfs, err := decoder.DecodeFull(headerBlock)
- if err != nil {
- // TODO: use the right error code
- return newConnError(errorGeneralProtocolError, err)
- }
- req, err := requestFromHeaders(hfs)
- if err != nil {
- // TODO: use the right error code
- return newStreamError(errorGeneralProtocolError, err)
- }
- req.Body = newRequestBody(str, onFrameError)
- if s.logger.Debug() {
- s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
- } else {
- s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
- }
- req = req.WithContext(str.Context())
- responseWriter := newResponseWriter(str, s.logger)
- handler := s.Handler
- if handler == nil {
- handler = http.DefaultServeMux
- }
- var panicked, readEOF bool
- func() {
- defer func() {
- if p := recover(); p != nil {
- // Copied from net/http/server.go
- const size = 64 << 10
- buf := make([]byte, size)
- buf = buf[:runtime.Stack(buf, false)]
- s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
- panicked = true
- }
- }()
- handler.ServeHTTP(responseWriter, req)
- // read the eof
- if _, err = str.Read([]byte{0}); err == io.EOF {
- readEOF = true
- }
- }()
- if panicked {
- responseWriter.WriteHeader(500)
- } else {
- responseWriter.WriteHeader(200)
- }
- if !readEOF {
- str.CancelRead(quic.ErrorCode(errorEarlyResponse))
- }
- return requestError{}
- }
- // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
- // Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
- func (s *Server) Close() error {
- s.closed.Set(true)
- s.mutex.Lock()
- defer s.mutex.Unlock()
- var err error
- for ln := range s.listeners {
- if cerr := (*ln).Close(); cerr != nil && err == nil {
- err = cerr
- }
- }
- return err
- }
- // CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
- // CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
- func (s *Server) CloseGracefully(timeout time.Duration) error {
- // TODO: implement
- return nil
- }
- // SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
- // The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
- // Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
- func (s *Server) SetQuicHeaders(hdr http.Header) error {
- port := atomic.LoadUint32(&s.port)
- if port == 0 {
- // Extract port from s.Server.Addr
- _, portStr, err := net.SplitHostPort(s.Server.Addr)
- if err != nil {
- return err
- }
- portInt, err := net.LookupPort("tcp", portStr)
- if err != nil {
- return err
- }
- port = uint32(portInt)
- atomic.StoreUint32(&s.port, port)
- }
- hdr.Add("Alt-Svc", fmt.Sprintf(`%s=":%d"; ma=2592000`, nextProtoH3, port))
- return nil
- }
- // ListenAndServeQUIC listens on the UDP network address addr and calls the
- // handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is
- // used when handler is nil.
- func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
- server := &Server{
- Server: &http.Server{
- Addr: addr,
- Handler: handler,
- },
- }
- return server.ListenAndServeTLS(certFile, keyFile)
- }
- // ListenAndServe listens on the given network address for both, TLS and QUIC
- // connetions in parallel. It returns if one of the two returns an error.
- // http.DefaultServeMux is used when handler is nil.
- // The correct Alt-Svc headers for QUIC are set.
- func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
- // Load certs
- var err error
- certs := make([]tls.Certificate, 1)
- certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
- if err != nil {
- return err
- }
- // We currently only use the cert-related stuff from tls.Config,
- // so we don't need to make a full copy.
- config := &tls.Config{
- Certificates: certs,
- }
- // Open the listeners
- udpAddr, err := net.ResolveUDPAddr("udp", addr)
- if err != nil {
- return err
- }
- udpConn, err := net.ListenUDP("udp", udpAddr)
- if err != nil {
- return err
- }
- defer udpConn.Close()
- tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
- if err != nil {
- return err
- }
- tcpConn, err := net.ListenTCP("tcp", tcpAddr)
- if err != nil {
- return err
- }
- defer tcpConn.Close()
- tlsConn := tls.NewListener(tcpConn, config)
- defer tlsConn.Close()
- // Start the servers
- httpServer := &http.Server{
- Addr: addr,
- TLSConfig: config,
- }
- quicServer := &Server{
- Server: httpServer,
- }
- if handler == nil {
- handler = http.DefaultServeMux
- }
- httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- quicServer.SetQuicHeaders(w.Header())
- handler.ServeHTTP(w, r)
- })
- hErr := make(chan error)
- qErr := make(chan error)
- go func() {
- hErr <- httpServer.Serve(tlsConn)
- }()
- go func() {
- qErr <- quicServer.Serve(udpConn)
- }()
- select {
- case err := <-hErr:
- quicServer.Close()
- return err
- case err := <-qErr:
- // Cannot close the HTTP server or wait for requests to complete properly :/
- return err
- }
- }
|