server.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. package quic
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "time"
  9. "github.com/lucas-clemente/quic-go/internal/crypto"
  10. "github.com/lucas-clemente/quic-go/internal/handshake"
  11. "github.com/lucas-clemente/quic-go/internal/protocol"
  12. "github.com/lucas-clemente/quic-go/internal/utils"
  13. "github.com/lucas-clemente/quic-go/internal/wire"
  14. "github.com/lucas-clemente/quic-go/qerr"
  15. )
  16. // packetHandler handles packets
  17. type packetHandler interface {
  18. handlePacket(*receivedPacket)
  19. Close(error) error
  20. }
  21. type packetHandlerManager interface {
  22. Add(protocol.ConnectionID, packetHandler)
  23. Get(protocol.ConnectionID) (packetHandler, bool)
  24. Remove(protocol.ConnectionID)
  25. Close(error)
  26. }
  27. type quicSession interface {
  28. Session
  29. handlePacket(*receivedPacket)
  30. getCryptoStream() cryptoStreamI
  31. GetVersion() protocol.VersionNumber
  32. run() error
  33. closeRemote(error)
  34. }
  35. type sessionRunner interface {
  36. onHandshakeComplete(Session)
  37. removeConnectionID(protocol.ConnectionID)
  38. }
  39. type runner struct {
  40. onHandshakeCompleteImpl func(Session)
  41. removeConnectionIDImpl func(protocol.ConnectionID)
  42. }
  43. func (r *runner) onHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) }
  44. func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) }
  45. var _ sessionRunner = &runner{}
  46. // A Listener of QUIC
  47. type server struct {
  48. tlsConf *tls.Config
  49. config *Config
  50. conn net.PacketConn
  51. supportsTLS bool
  52. serverTLS *serverTLS
  53. certChain crypto.CertChain
  54. scfg *handshake.ServerConfig
  55. sessionHandler packetHandlerManager
  56. serverError error
  57. sessionQueue chan Session
  58. errorChan chan struct{}
  59. sessionRunner sessionRunner
  60. // set as a member, so they can be set in the tests
  61. newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (quicSession, error)
  62. logger utils.Logger
  63. }
  64. var _ Listener = &server{}
  65. // ListenAddr creates a QUIC server listening on a given address.
  66. // The tls.Config must not be nil, the quic.Config may be nil.
  67. func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
  68. udpAddr, err := net.ResolveUDPAddr("udp", addr)
  69. if err != nil {
  70. return nil, err
  71. }
  72. conn, err := net.ListenUDP("udp", udpAddr)
  73. if err != nil {
  74. return nil, err
  75. }
  76. return Listen(conn, tlsConf, config)
  77. }
  78. // Listen listens for QUIC connections on a given net.PacketConn.
  79. // The tls.Config must not be nil, the quic.Config may be nil.
  80. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
  81. certChain := crypto.NewCertChain(tlsConf)
  82. kex, err := crypto.NewCurve25519KEX()
  83. if err != nil {
  84. return nil, err
  85. }
  86. scfg, err := handshake.NewServerConfig(kex, certChain)
  87. if err != nil {
  88. return nil, err
  89. }
  90. config = populateServerConfig(config)
  91. var supportsTLS bool
  92. for _, v := range config.Versions {
  93. if !protocol.IsValidVersion(v) {
  94. return nil, fmt.Errorf("%s is not a valid QUIC version", v)
  95. }
  96. // check if any of the supported versions supports TLS
  97. if v.UsesTLS() {
  98. supportsTLS = true
  99. break
  100. }
  101. }
  102. s := &server{
  103. conn: conn,
  104. tlsConf: tlsConf,
  105. config: config,
  106. certChain: certChain,
  107. scfg: scfg,
  108. newSession: newSession,
  109. sessionHandler: newPacketHandlerMap(),
  110. sessionQueue: make(chan Session, 5),
  111. errorChan: make(chan struct{}),
  112. supportsTLS: supportsTLS,
  113. logger: utils.DefaultLogger.WithPrefix("server"),
  114. }
  115. s.setup()
  116. if supportsTLS {
  117. if err := s.setupTLS(); err != nil {
  118. return nil, err
  119. }
  120. }
  121. go s.serve()
  122. s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
  123. return s, nil
  124. }
  125. func (s *server) setup() {
  126. s.sessionRunner = &runner{
  127. onHandshakeCompleteImpl: func(sess Session) { s.sessionQueue <- sess },
  128. removeConnectionIDImpl: s.sessionHandler.Remove,
  129. }
  130. }
  131. func (s *server) setupTLS() error {
  132. cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie, s.logger)
  133. if err != nil {
  134. return err
  135. }
  136. serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, cookieHandler, s.tlsConf, s.logger)
  137. if err != nil {
  138. return err
  139. }
  140. s.serverTLS = serverTLS
  141. // handle TLS connection establishment statelessly
  142. go func() {
  143. for {
  144. select {
  145. case <-s.errorChan:
  146. return
  147. case tlsSession := <-sessionChan:
  148. // The connection ID is a randomly chosen 8 byte value.
  149. // It is safe to assume that it doesn't collide with other randomly chosen values.
  150. s.sessionHandler.Add(tlsSession.connID, tlsSession.sess)
  151. }
  152. }
  153. }()
  154. return nil
  155. }
  156. var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
  157. if cookie == nil {
  158. return false
  159. }
  160. if time.Now().After(cookie.SentTime.Add(protocol.CookieExpiryTime)) {
  161. return false
  162. }
  163. var sourceAddr string
  164. if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
  165. sourceAddr = udpAddr.IP.String()
  166. } else {
  167. sourceAddr = clientAddr.String()
  168. }
  169. return sourceAddr == cookie.RemoteAddr
  170. }
  171. // populateServerConfig populates fields in the quic.Config with their default values, if none are set
  172. // it may be called with nil
  173. func populateServerConfig(config *Config) *Config {
  174. if config == nil {
  175. config = &Config{}
  176. }
  177. versions := config.Versions
  178. if len(versions) == 0 {
  179. versions = protocol.SupportedVersions
  180. }
  181. vsa := defaultAcceptCookie
  182. if config.AcceptCookie != nil {
  183. vsa = config.AcceptCookie
  184. }
  185. handshakeTimeout := protocol.DefaultHandshakeTimeout
  186. if config.HandshakeTimeout != 0 {
  187. handshakeTimeout = config.HandshakeTimeout
  188. }
  189. idleTimeout := protocol.DefaultIdleTimeout
  190. if config.IdleTimeout != 0 {
  191. idleTimeout = config.IdleTimeout
  192. }
  193. maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
  194. if maxReceiveStreamFlowControlWindow == 0 {
  195. maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowServer
  196. }
  197. maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
  198. if maxReceiveConnectionFlowControlWindow == 0 {
  199. maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowServer
  200. }
  201. maxIncomingStreams := config.MaxIncomingStreams
  202. if maxIncomingStreams == 0 {
  203. maxIncomingStreams = protocol.DefaultMaxIncomingStreams
  204. } else if maxIncomingStreams < 0 {
  205. maxIncomingStreams = 0
  206. }
  207. maxIncomingUniStreams := config.MaxIncomingUniStreams
  208. if maxIncomingUniStreams == 0 {
  209. maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
  210. } else if maxIncomingUniStreams < 0 {
  211. maxIncomingUniStreams = 0
  212. }
  213. return &Config{
  214. Versions: versions,
  215. HandshakeTimeout: handshakeTimeout,
  216. IdleTimeout: idleTimeout,
  217. AcceptCookie: vsa,
  218. KeepAlive: config.KeepAlive,
  219. MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
  220. MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
  221. MaxIncomingStreams: maxIncomingStreams,
  222. MaxIncomingUniStreams: maxIncomingUniStreams,
  223. }
  224. }
  225. // serve listens on an existing PacketConn
  226. func (s *server) serve() {
  227. for {
  228. data := *getPacketBuffer()
  229. data = data[:protocol.MaxReceivePacketSize]
  230. // The packet size should not exceed protocol.MaxReceivePacketSize bytes
  231. // If it does, we only read a truncated packet, which will then end up undecryptable
  232. n, remoteAddr, err := s.conn.ReadFrom(data)
  233. if err != nil {
  234. s.serverError = err
  235. close(s.errorChan)
  236. _ = s.Close()
  237. return
  238. }
  239. data = data[:n]
  240. if err := s.handlePacket(remoteAddr, data); err != nil {
  241. s.logger.Errorf("error handling packet: %s", err.Error())
  242. }
  243. }
  244. }
  245. // Accept returns newly openend sessions
  246. func (s *server) Accept() (Session, error) {
  247. var sess Session
  248. select {
  249. case sess = <-s.sessionQueue:
  250. return sess, nil
  251. case <-s.errorChan:
  252. return nil, s.serverError
  253. }
  254. }
  255. // Close the server
  256. func (s *server) Close() error {
  257. s.sessionHandler.Close(nil)
  258. err := s.conn.Close()
  259. <-s.errorChan // wait for serve() to return
  260. return err
  261. }
  262. // Addr returns the server's network address
  263. func (s *server) Addr() net.Addr {
  264. return s.conn.LocalAddr()
  265. }
  266. func (s *server) handlePacket(remoteAddr net.Addr, packet []byte) error {
  267. rcvTime := time.Now()
  268. r := bytes.NewReader(packet)
  269. hdr, err := wire.ParseHeaderSentByClient(r)
  270. if err != nil {
  271. return qerr.Error(qerr.InvalidPacketHeader, err.Error())
  272. }
  273. hdr.Raw = packet[:len(packet)-r.Len()]
  274. packetData := packet[len(packet)-r.Len():]
  275. if hdr.IsPublicHeader {
  276. return s.handleGQUICPacket(hdr, packetData, remoteAddr, rcvTime)
  277. }
  278. return s.handleIETFQUICPacket(hdr, packetData, remoteAddr, rcvTime)
  279. }
  280. func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
  281. if hdr.IsLongHeader {
  282. if !s.supportsTLS {
  283. return errors.New("Received an IETF QUIC Long Header")
  284. }
  285. if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
  286. return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
  287. }
  288. packetData = packetData[:int(hdr.PayloadLen)]
  289. // TODO(#1312): implement parsing of compound packets
  290. switch hdr.Type {
  291. case protocol.PacketTypeInitial:
  292. go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData)
  293. return nil
  294. case protocol.PacketTypeHandshake:
  295. // nothing to do here. Packet will be passed to the session.
  296. default:
  297. // Note that this also drops 0-RTT packets.
  298. return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
  299. }
  300. }
  301. session, sessionKnown := s.sessionHandler.Get(hdr.DestConnectionID)
  302. if sessionKnown && session == nil {
  303. // Late packet for closed session
  304. return nil
  305. }
  306. if !sessionKnown {
  307. s.logger.Debugf("Received %s packet for unknown connection %s.", hdr.Type, hdr.DestConnectionID)
  308. return nil
  309. }
  310. session.handlePacket(&receivedPacket{
  311. remoteAddr: remoteAddr,
  312. header: hdr,
  313. data: packetData,
  314. rcvTime: rcvTime,
  315. })
  316. return nil
  317. }
  318. func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
  319. // ignore all Public Reset packets
  320. if hdr.ResetFlag {
  321. s.logger.Infof("Received unexpected Public Reset for connection %s.", hdr.DestConnectionID)
  322. return nil
  323. }
  324. session, sessionKnown := s.sessionHandler.Get(hdr.DestConnectionID)
  325. if sessionKnown && session == nil {
  326. // Late packet for closed session
  327. return nil
  328. }
  329. // If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset
  330. // This should only happen after a server restart, when we still receive packets for connections that we lost the state for.
  331. if !sessionKnown && !hdr.VersionFlag {
  332. _, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), remoteAddr)
  333. return err
  334. }
  335. // a session is only created once the client sent a supported version
  336. // if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated
  337. // it is safe to drop it
  338. if sessionKnown && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
  339. return nil
  340. }
  341. // send a Version Negotiation Packet if the client is speaking a different protocol version
  342. // since the client send a Public Header (only gQUIC has a Version Flag), we need to send a gQUIC Version Negotiation Packet
  343. if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
  344. // drop packets that are too small to be valid first packets
  345. if len(packetData) < protocol.MinClientHelloSize {
  346. return errors.New("dropping small packet with unknown version")
  347. }
  348. s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version)
  349. _, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions), remoteAddr)
  350. return err
  351. }
  352. if !sessionKnown {
  353. // This is (potentially) a Client Hello.
  354. // Make sure it has the minimum required size before spending any more ressources on it.
  355. if len(packetData) < protocol.MinClientHelloSize {
  356. return errors.New("dropping small packet for unknown connection")
  357. }
  358. version := hdr.Version
  359. if !protocol.IsSupportedVersion(s.config.Versions, version) {
  360. return errors.New("Server BUG: negotiated version not supported")
  361. }
  362. s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, version, remoteAddr)
  363. sess, err := s.newSession(
  364. &conn{pconn: s.conn, currentAddr: remoteAddr},
  365. s.sessionRunner,
  366. version,
  367. hdr.DestConnectionID,
  368. s.scfg,
  369. s.tlsConf,
  370. s.config,
  371. s.logger,
  372. )
  373. if err != nil {
  374. return err
  375. }
  376. s.sessionHandler.Add(hdr.DestConnectionID, sess)
  377. go sess.run()
  378. session = sess
  379. }
  380. session.handlePacket(&receivedPacket{
  381. remoteAddr: remoteAddr,
  382. header: hdr,
  383. data: packetData,
  384. rcvTime: rcvTime,
  385. })
  386. return nil
  387. }