server.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. package quic
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sync"
  9. "time"
  10. "github.com/lucas-clemente/quic-go/internal/crypto"
  11. "github.com/lucas-clemente/quic-go/internal/handshake"
  12. "github.com/lucas-clemente/quic-go/internal/protocol"
  13. "github.com/lucas-clemente/quic-go/internal/utils"
  14. "github.com/lucas-clemente/quic-go/internal/wire"
  15. )
  16. // packetHandler handles packets
  17. type packetHandler interface {
  18. handlePacket(*receivedPacket)
  19. io.Closer
  20. destroy(error)
  21. GetVersion() protocol.VersionNumber
  22. GetPerspective() protocol.Perspective
  23. }
  24. type unknownPacketHandler interface {
  25. handlePacket(*receivedPacket)
  26. closeWithError(error) error
  27. }
  28. type packetHandlerManager interface {
  29. Add(protocol.ConnectionID, packetHandler)
  30. SetServer(unknownPacketHandler)
  31. Remove(protocol.ConnectionID)
  32. CloseServer()
  33. }
  34. type quicSession interface {
  35. Session
  36. handlePacket(*receivedPacket)
  37. GetVersion() protocol.VersionNumber
  38. run() error
  39. destroy(error)
  40. closeRemote(error)
  41. }
  42. type sessionRunner interface {
  43. onHandshakeComplete(Session)
  44. removeConnectionID(protocol.ConnectionID)
  45. }
  46. type runner struct {
  47. onHandshakeCompleteImpl func(Session)
  48. removeConnectionIDImpl func(protocol.ConnectionID)
  49. }
  50. func (r *runner) onHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) }
  51. func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) }
  52. var _ sessionRunner = &runner{}
  53. // A Listener of QUIC
  54. type server struct {
  55. mutex sync.Mutex
  56. tlsConf *tls.Config
  57. config *Config
  58. conn net.PacketConn
  59. // If the server is started with ListenAddr, we create a packet conn.
  60. // If it is started with Listen, we take a packet conn as a parameter.
  61. createdPacketConn bool
  62. supportsTLS bool
  63. serverTLS *serverTLS
  64. certChain crypto.CertChain
  65. scfg *handshake.ServerConfig
  66. sessionHandler packetHandlerManager
  67. serverError error
  68. errorChan chan struct{}
  69. closed bool
  70. sessionQueue chan Session
  71. sessionRunner sessionRunner
  72. // set as a member, so they can be set in the tests
  73. newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (quicSession, error)
  74. logger utils.Logger
  75. }
  76. var _ Listener = &server{}
  77. var _ unknownPacketHandler = &server{}
  78. // ListenAddr creates a QUIC server listening on a given address.
  79. // The tls.Config must not be nil, the quic.Config may be nil.
  80. func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
  81. udpAddr, err := net.ResolveUDPAddr("udp", addr)
  82. if err != nil {
  83. return nil, err
  84. }
  85. conn, err := net.ListenUDP("udp", udpAddr)
  86. if err != nil {
  87. return nil, err
  88. }
  89. serv, err := listen(conn, tlsConf, config)
  90. if err != nil {
  91. return nil, err
  92. }
  93. serv.createdPacketConn = true
  94. return serv, nil
  95. }
  96. // Listen listens for QUIC connections on a given net.PacketConn.
  97. // The tls.Config must not be nil, the quic.Config may be nil.
  98. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
  99. return listen(conn, tlsConf, config)
  100. }
  101. func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) {
  102. if tlsConf == nil || (len(tlsConf.Certificates) == 0 && tlsConf.GetCertificate == nil) {
  103. return nil, errors.New("quic: neither Certificates nor GetCertificate set in tls.Config")
  104. }
  105. certChain := crypto.NewCertChain(tlsConf)
  106. kex, err := crypto.NewCurve25519KEX()
  107. if err != nil {
  108. return nil, err
  109. }
  110. scfg, err := handshake.NewServerConfig(kex, certChain)
  111. if err != nil {
  112. return nil, err
  113. }
  114. config = populateServerConfig(config)
  115. var supportsTLS bool
  116. for _, v := range config.Versions {
  117. if !protocol.IsValidVersion(v) {
  118. return nil, fmt.Errorf("%s is not a valid QUIC version", v)
  119. }
  120. // check if any of the supported versions supports TLS
  121. if v.UsesTLS() {
  122. supportsTLS = true
  123. break
  124. }
  125. }
  126. sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength)
  127. if err != nil {
  128. return nil, err
  129. }
  130. s := &server{
  131. conn: conn,
  132. tlsConf: tlsConf,
  133. config: config,
  134. certChain: certChain,
  135. scfg: scfg,
  136. newSession: newSession,
  137. sessionHandler: sessionHandler,
  138. sessionQueue: make(chan Session, 5),
  139. errorChan: make(chan struct{}),
  140. supportsTLS: supportsTLS,
  141. logger: utils.DefaultLogger.WithPrefix("server"),
  142. }
  143. s.setup()
  144. if supportsTLS {
  145. if err := s.setupTLS(); err != nil {
  146. return nil, err
  147. }
  148. }
  149. sessionHandler.SetServer(s)
  150. s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
  151. return s, nil
  152. }
  153. func (s *server) setup() {
  154. s.sessionRunner = &runner{
  155. onHandshakeCompleteImpl: func(sess Session) { s.sessionQueue <- sess },
  156. removeConnectionIDImpl: s.sessionHandler.Remove,
  157. }
  158. }
  159. func (s *server) setupTLS() error {
  160. serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, s.tlsConf, s.logger)
  161. if err != nil {
  162. return err
  163. }
  164. s.serverTLS = serverTLS
  165. // handle TLS connection establishment statelessly
  166. go func() {
  167. for {
  168. select {
  169. case <-s.errorChan:
  170. return
  171. case tlsSession := <-sessionChan:
  172. // The connection ID is a randomly chosen value.
  173. // It is safe to assume that it doesn't collide with other randomly chosen values.
  174. serverSession := newServerSession(tlsSession.sess, s.config, s.logger)
  175. s.sessionHandler.Add(tlsSession.connID, serverSession)
  176. }
  177. }
  178. }()
  179. return nil
  180. }
  181. var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
  182. if cookie == nil {
  183. return false
  184. }
  185. if time.Now().After(cookie.SentTime.Add(protocol.CookieExpiryTime)) {
  186. return false
  187. }
  188. var sourceAddr string
  189. if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
  190. sourceAddr = udpAddr.IP.String()
  191. } else {
  192. sourceAddr = clientAddr.String()
  193. }
  194. return sourceAddr == cookie.RemoteAddr
  195. }
  196. // populateServerConfig populates fields in the quic.Config with their default values, if none are set
  197. // it may be called with nil
  198. func populateServerConfig(config *Config) *Config {
  199. if config == nil {
  200. config = &Config{}
  201. }
  202. versions := config.Versions
  203. if len(versions) == 0 {
  204. versions = protocol.SupportedVersions
  205. }
  206. vsa := defaultAcceptCookie
  207. if config.AcceptCookie != nil {
  208. vsa = config.AcceptCookie
  209. }
  210. handshakeTimeout := protocol.DefaultHandshakeTimeout
  211. if config.HandshakeTimeout != 0 {
  212. handshakeTimeout = config.HandshakeTimeout
  213. }
  214. idleTimeout := protocol.DefaultIdleTimeout
  215. if config.IdleTimeout != 0 {
  216. idleTimeout = config.IdleTimeout
  217. }
  218. maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
  219. if maxReceiveStreamFlowControlWindow == 0 {
  220. maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowServer
  221. }
  222. maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
  223. if maxReceiveConnectionFlowControlWindow == 0 {
  224. maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowServer
  225. }
  226. maxIncomingStreams := config.MaxIncomingStreams
  227. if maxIncomingStreams == 0 {
  228. maxIncomingStreams = protocol.DefaultMaxIncomingStreams
  229. } else if maxIncomingStreams < 0 {
  230. maxIncomingStreams = 0
  231. }
  232. maxIncomingUniStreams := config.MaxIncomingUniStreams
  233. if maxIncomingUniStreams == 0 {
  234. maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
  235. } else if maxIncomingUniStreams < 0 {
  236. maxIncomingUniStreams = 0
  237. }
  238. connIDLen := config.ConnectionIDLength
  239. if connIDLen == 0 {
  240. connIDLen = protocol.DefaultConnectionIDLength
  241. }
  242. for _, v := range versions {
  243. if v == protocol.Version44 {
  244. connIDLen = protocol.ConnectionIDLenGQUIC
  245. }
  246. }
  247. return &Config{
  248. Versions: versions,
  249. HandshakeTimeout: handshakeTimeout,
  250. IdleTimeout: idleTimeout,
  251. AcceptCookie: vsa,
  252. KeepAlive: config.KeepAlive,
  253. MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
  254. MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
  255. MaxIncomingStreams: maxIncomingStreams,
  256. MaxIncomingUniStreams: maxIncomingUniStreams,
  257. ConnectionIDLength: connIDLen,
  258. }
  259. }
  260. // Accept returns newly openend sessions
  261. func (s *server) Accept() (Session, error) {
  262. var sess Session
  263. select {
  264. case sess = <-s.sessionQueue:
  265. return sess, nil
  266. case <-s.errorChan:
  267. return nil, s.serverError
  268. }
  269. }
  270. // Close the server
  271. func (s *server) Close() error {
  272. s.mutex.Lock()
  273. defer s.mutex.Unlock()
  274. if s.closed {
  275. return nil
  276. }
  277. return s.closeWithMutex()
  278. }
  279. func (s *server) closeWithMutex() error {
  280. s.sessionHandler.CloseServer()
  281. if s.serverError == nil {
  282. s.serverError = errors.New("server closed")
  283. }
  284. var err error
  285. // If the server was started with ListenAddr, we created the packet conn.
  286. // We need to close it in order to make the go routine reading from that conn return.
  287. if s.createdPacketConn {
  288. err = s.conn.Close()
  289. }
  290. s.closed = true
  291. close(s.errorChan)
  292. return err
  293. }
  294. func (s *server) closeWithError(e error) error {
  295. s.mutex.Lock()
  296. defer s.mutex.Unlock()
  297. if s.closed {
  298. return nil
  299. }
  300. s.serverError = e
  301. return s.closeWithMutex()
  302. }
  303. // Addr returns the server's network address
  304. func (s *server) Addr() net.Addr {
  305. return s.conn.LocalAddr()
  306. }
  307. func (s *server) handlePacket(p *receivedPacket) {
  308. if err := s.handlePacketImpl(p); err != nil {
  309. s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
  310. }
  311. }
  312. func (s *server) handlePacketImpl(p *receivedPacket) error {
  313. hdr := p.header
  314. if hdr.VersionFlag || hdr.IsLongHeader {
  315. // send a Version Negotiation Packet if the client is speaking a different protocol version
  316. if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
  317. return s.sendVersionNegotiationPacket(p)
  318. }
  319. }
  320. if hdr.Type == protocol.PacketTypeInitial && hdr.Version.UsesTLS() {
  321. go s.serverTLS.HandleInitial(p)
  322. return nil
  323. }
  324. // TODO(#943): send Stateless Reset, if this an IETF QUIC packet
  325. if !hdr.VersionFlag && !hdr.Version.UsesIETFHeaderFormat() {
  326. _, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), p.remoteAddr)
  327. return err
  328. }
  329. // This is (potentially) a Client Hello.
  330. // Make sure it has the minimum required size before spending any more ressources on it.
  331. if len(p.data) < protocol.MinClientHelloSize {
  332. return errors.New("dropping small packet for unknown connection")
  333. }
  334. var destConnID, srcConnID protocol.ConnectionID
  335. if hdr.Version.UsesIETFHeaderFormat() {
  336. srcConnID = hdr.DestConnectionID
  337. } else {
  338. destConnID = hdr.DestConnectionID
  339. srcConnID = hdr.DestConnectionID
  340. }
  341. s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, hdr.Version, p.remoteAddr)
  342. sess, err := s.newSession(
  343. &conn{pconn: s.conn, currentAddr: p.remoteAddr},
  344. s.sessionRunner,
  345. hdr.Version,
  346. destConnID,
  347. srcConnID,
  348. s.scfg,
  349. s.tlsConf,
  350. s.config,
  351. s.logger,
  352. )
  353. if err != nil {
  354. return err
  355. }
  356. s.sessionHandler.Add(hdr.DestConnectionID, newServerSession(sess, s.config, s.logger))
  357. go sess.run()
  358. sess.handlePacket(p)
  359. return nil
  360. }
  361. func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
  362. hdr := p.header
  363. s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
  364. var data []byte
  365. if hdr.IsPublicHeader {
  366. data = wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions)
  367. } else {
  368. var err error
  369. data, err = wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
  370. if err != nil {
  371. return err
  372. }
  373. }
  374. _, err := s.conn.WriteTo(data, p.remoteAddr)
  375. return err
  376. }