client.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. package quic
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "strings"
  9. "sync"
  10. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  11. "github.com/Psiphon-Labs/quic-go/internal/utils"
  12. "github.com/Psiphon-Labs/quic-go/internal/wire"
  13. )
  14. type client struct {
  15. mutex sync.Mutex
  16. conn connection
  17. // If the client is created with DialAddr, we create a packet conn.
  18. // If it is started with Dial, we take a packet conn as a parameter.
  19. createdPacketConn bool
  20. packetHandlers packetHandlerManager
  21. versionNegotiated utils.AtomicBool // has the server accepted our version
  22. receivedVersionNegotiationPacket bool
  23. negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
  24. tlsConf *tls.Config
  25. config *Config
  26. srcConnID protocol.ConnectionID
  27. destConnID protocol.ConnectionID
  28. initialPacketNumber protocol.PacketNumber
  29. initialVersion protocol.VersionNumber
  30. version protocol.VersionNumber
  31. handshakeChan chan struct{}
  32. session quicSession
  33. logger utils.Logger
  34. }
  35. var _ packetHandler = &client{}
  36. var (
  37. // make it possible to mock connection ID generation in the tests
  38. generateConnectionID = protocol.GenerateConnectionID
  39. generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
  40. )
  // DialAddr establishes a new QUIC connection to the server at addr.
  // It creates its own UDP connection and closes it when the QUIC session is closed.
  // Unless tlsConf.ServerName is set, the hostname for SNI is taken from addr.
  // It is DialAddrContext with context.Background().
  func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) {
  	ctx := context.Background()
  	return DialAddrContext(ctx, addr, tlsConf, config)
  }
  // DialAddrContext is like DialAddr, but the dial (up to the completion of the
  // handshake) is bounded by ctx.
  func DialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (Session, error) {
  	remote, err := net.ResolveUDPAddr("udp", addr)
  	if err != nil {
  		return nil, err
  	}
  	// Bind the wildcard address on an ephemeral port.
  	local := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
  	pconn, err := net.ListenUDP("udp", local)
  	if err != nil {
  		return nil, err
  	}
  	// true: the package created pconn and is responsible for closing it.
  	return dialContext(ctx, pconn, remote, addr, tlsConf, config, true)
  }
  // Dial establishes a new QUIC connection to remoteAddr over pconn.
  // The same PacketConn can be shared by several calls to Dial and Listen;
  // QUIC connection IDs are used to demultiplex the connections.
  // host is used for SNI unless tlsConf.ServerName is already set.
  // The tls.Config must define an application protocol (using NextProtos).
  // Dial is DialContext with context.Background().
  func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, tlsConf *tls.Config, config *Config) (Session, error) {
  	ctx := context.Background()
  	return DialContext(ctx, pconn, remoteAddr, host, tlsConf, config)
  }
  83. // DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
  84. // See Dial for details.
  85. func DialContext(
  86. ctx context.Context,
  87. pconn net.PacketConn,
  88. remoteAddr net.Addr,
  89. host string,
  90. tlsConf *tls.Config,
  91. config *Config,
  92. ) (Session, error) {
  93. return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
  94. }
  // dialContext is the shared implementation of DialAddrContext and DialContext.
  // It fills in config defaults, registers pconn with the multiplexer, creates
  // the client and runs the handshake. It returns the session once the handshake
  // has completed.
  //
  // createdPacketConn reports whether pconn was created by this package
  // (DialAddr) rather than supplied by the caller. In that case the package owns
  // pconn, so it is released on every error path. This includes the early
  // failures that happen before any session exists that could release it.
  func dialContext(
  	ctx context.Context,
  	pconn net.PacketConn,
  	remoteAddr net.Addr,
  	host string,
  	tlsConf *tls.Config,
  	config *Config,
  	createdPacketConn bool,
  ) (Session, error) {
  	if tlsConf == nil {
  		if createdPacketConn {
  			// Best effort: the missing tls.Config is the error worth reporting.
  			_ = pconn.Close()
  		}
  		return nil, errors.New("quic: tls.Config not set")
  	}
  	config = populateClientConfig(config, createdPacketConn)
  	packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey)
  	if err != nil {
  		if createdPacketConn {
  			// pconn was not registered; close it directly (best effort).
  			_ = pconn.Close()
  		}
  		return nil, err
  	}
  	c, err := newClient(pconn, remoteAddr, config, tlsConf, host, createdPacketConn)
  	if err != nil {
  		if createdPacketConn {
  			// pconn is registered with the multiplexer now: tear it down the same
  			// way establishSecureConnection does for an owned packet conn.
  			packetHandlers.Close()
  		}
  		return nil, err
  	}
  	c.packetHandlers = packetHandlers
  	// No cleanup is needed for dial errors. Once a session has been started,
  	// establishSecureConnection closes packetHandlers (for an owned packet conn)
  	// when the session's run loop returns.
  	if err := c.dial(ctx); err != nil {
  		return nil, err
  	}
  	return c.session, nil
  }
  // newClient creates a client that talks to remoteAddr over pconn.
  //
  // config is expected to be populated already (see populateClientConfig). A nil
  // config, an empty version list or an invalid version is reported as an error.
  // When tlsConf.ServerName is empty, host is used as the TLS server name, with
  // any port stripped. The caller's tlsConf is never modified: the client keeps
  // and edits a clone. A source connection ID (from config.ConnectionIDLength)
  // and a destination connection ID for the Initial packet are generated.
  func newClient(
  	pconn net.PacketConn,
  	remoteAddr net.Addr,
  	config *Config,
  	tlsConf *tls.Config,
  	host string,
  	createdPacketConn bool,
  ) (*client, error) {
  	// Clone so that filling in ServerName below does not mutate a tls.Config
  	// owned by the caller. The caller may reuse it, possibly concurrently,
  	// to dial other hosts.
  	if tlsConf == nil {
  		tlsConf = &tls.Config{}
  	} else {
  		tlsConf = tlsConf.Clone()
  	}
  	if tlsConf.ServerName == "" {
  		sni := host
  		if strings.IndexByte(sni, ':') != -1 {
  			var err error
  			sni, _, err = net.SplitHostPort(sni)
  			if err != nil {
  				return nil, err
  			}
  		}
  		tlsConf.ServerName = sni
  	}
  	// config is dereferenced below (ConnectionIDLength, Versions[0]); report
  	// misuse as an error instead of panicking.
  	if config == nil {
  		return nil, errors.New("quic: config not set")
  	}
  	if len(config.Versions) == 0 {
  		return nil, errors.New("quic: no QUIC versions configured")
  	}
  	// check that all versions are actually supported
  	for _, v := range config.Versions {
  		if !protocol.IsValidVersion(v) {
  			return nil, fmt.Errorf("%s is not a valid QUIC version", v)
  		}
  	}
  	srcConnID, err := generateConnectionID(config.ConnectionIDLength)
  	if err != nil {
  		return nil, err
  	}
  	destConnID, err := generateConnectionIDForInitial()
  	if err != nil {
  		return nil, err
  	}
  	c := &client{
  		srcConnID:         srcConnID,
  		destConnID:        destConnID,
  		conn:              &conn{pconn: pconn, currentAddr: remoteAddr},
  		createdPacketConn: createdPacketConn,
  		tlsConf:           tlsConf,
  		config:            config,
  		version:           config.Versions[0],
  		handshakeChan:     make(chan struct{}),
  		logger:            utils.DefaultLogger.WithPrefix("client"),
  	}
  	return c, nil
  }
  173. // populateClientConfig populates fields in the quic.Config with their default values, if none are set
  174. // it may be called with nil
  175. func populateClientConfig(config *Config, createdPacketConn bool) *Config {
  176. if config == nil {
  177. config = &Config{}
  178. }
  179. versions := config.Versions
  180. if len(versions) == 0 {
  181. versions = protocol.SupportedVersions
  182. }
  183. handshakeTimeout := protocol.DefaultHandshakeTimeout
  184. if config.HandshakeTimeout != 0 {
  185. handshakeTimeout = config.HandshakeTimeout
  186. }
  187. idleTimeout := protocol.DefaultIdleTimeout
  188. if config.IdleTimeout != 0 {
  189. idleTimeout = config.IdleTimeout
  190. }
  191. maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
  192. if maxReceiveStreamFlowControlWindow == 0 {
  193. maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
  194. }
  195. maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
  196. if maxReceiveConnectionFlowControlWindow == 0 {
  197. maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
  198. }
  199. maxIncomingStreams := config.MaxIncomingStreams
  200. if maxIncomingStreams == 0 {
  201. maxIncomingStreams = protocol.DefaultMaxIncomingStreams
  202. } else if maxIncomingStreams < 0 {
  203. maxIncomingStreams = 0
  204. }
  205. maxIncomingUniStreams := config.MaxIncomingUniStreams
  206. if maxIncomingUniStreams == 0 {
  207. maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
  208. } else if maxIncomingUniStreams < 0 {
  209. maxIncomingUniStreams = 0
  210. }
  211. connIDLen := config.ConnectionIDLength
  212. if connIDLen == 0 && !createdPacketConn {
  213. connIDLen = protocol.DefaultConnectionIDLength
  214. }
  215. return &Config{
  216. Versions: versions,
  217. HandshakeTimeout: handshakeTimeout,
  218. IdleTimeout: idleTimeout,
  219. ConnectionIDLength: connIDLen,
  220. MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
  221. MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
  222. MaxIncomingStreams: maxIncomingStreams,
  223. MaxIncomingUniStreams: maxIncomingUniStreams,
  224. KeepAlive: config.KeepAlive,
  225. StatelessResetKey: config.StatelessResetKey,
  226. QuicTracer: config.QuicTracer,
  227. TokenStore: config.TokenStore,
  228. }
  229. }
  // dial creates a session for the current version and runs it until the
  // handshake completes or fails. If the session was closed because of a
  // Version Negotiation packet (errCloseForRecreating), the attempt is repeated
  // with a new session for the newly selected version.
  func (c *client) dial(ctx context.Context) error {
  	for {
  		c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
  		c.createNewTLSSession(c.version)
  		if err := c.establishSecureConnection(ctx); err != errCloseForRecreating {
  			return err
  		}
  	}
  }
  // establishSecureConnection starts c.session's run loop in a goroutine and
  // waits for the first of:
  //   - the handshake completing: returns nil (the session keeps running);
  //   - the run loop returning: returns its error, which is
  //     errCloseForRecreating when the server sent a Version Negotiation
  //     packet that made the client switch versions;
  //   - ctx being done: closes the session and returns ctx.Err().
  //
  // For a package-created packet conn, the packet handlers are closed when the
  // run loop returns, unless the session is only being recreated.
  func (c *client) establishSecureConnection(ctx context.Context) error {
  	// Buffered so the goroutine never blocks if nobody is receiving any more.
  	runErr := make(chan error, 1)
  	go func() {
  		err := c.session.run() // blocks until the session is closed
  		if c.createdPacketConn && err != errCloseForRecreating {
  			c.packetHandlers.Close()
  		}
  		runErr <- err
  	}()

  	select {
  	case <-c.session.HandshakeComplete().Done():
  		return nil
  	case err := <-runErr:
  		return err
  	case <-ctx.Done():
  		// The session will send a PeerGoingAway error to the server.
  		c.session.Close()
  		return ctx.Err()
  	}
  }
  265. func (c *client) handlePacket(p *receivedPacket) {
  266. if wire.IsVersionNegotiationPacket(p.data) {
  267. go c.handleVersionNegotiationPacket(p)
  268. return
  269. }
  270. // this is the first packet we are receiving
  271. // since it is not a Version Negotiation Packet, this means the server supports the suggested version
  272. if !c.versionNegotiated.Get() {
  273. c.versionNegotiated.Set(true)
  274. }
  275. c.session.handlePacket(p)
  276. }
  // handleVersionNegotiationPacket processes a Version Negotiation packet from the
  // server. handlePacket runs it in its own goroutine; it holds c.mutex throughout.
  //
  // The packet is ignored in four cases:
  //   - it cannot be parsed;
  //   - a Version Negotiation packet was already acted upon;
  //   - the server already accepted our version;
  //   - it lists the version we offered.
  // If none of the server's versions is acceptable, the session is destroyed.
  // Otherwise the switch to the chosen version is recorded and the session is
  // closed with closeForRecreating, so that dial starts a session for the new
  // version.
  func (c *client) handleVersionNegotiationPacket(p *receivedPacket) {
  	c.mutex.Lock()
  	defer c.mutex.Unlock()

  	hdr, _, _, err := wire.ParsePacket(p.data, 0)
  	if err != nil {
  		c.logger.Debugf("Error parsing Version Negotiation packet: %s", err)
  		return
  	}

  	// ignore delayed / duplicated version negotiation packets
  	if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() {
  		c.logger.Debugf("Received a delayed Version Negotiation packet.")
  		return
  	}

  	for _, v := range hdr.SupportedVersions {
  		if v == c.version {
  			// The Version Negotiation packet contains the version that we offered.
  			// This might be a packet sent by an attacker (or by a terribly broken server implementation).
  			c.logger.Debugf("Ignoring a Version Negotiation packet that lists the version we offered (%s).", c.version)
  			return
  		}
  	}

  	c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
  	newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
  	if !ok {
  		//nolint:stylecheck
  		c.session.destroy(fmt.Errorf("No compatible QUIC version found. We support %s, server offered %s", c.config.Versions, hdr.SupportedVersions))
  		c.logger.Debugf("No compatible QUIC version found.")
  		return
  	}
  	c.receivedVersionNegotiationPacket = true
  	c.negotiatedVersions = hdr.SupportedVersions

  	// switch to negotiated version
  	c.initialVersion = c.version
  	c.version = newVersion
  	// destConnID is not regenerated here; the next session reuses it.
  	c.logger.Infof("Switching to QUIC version %s. Keeping destination connection ID %s.", newVersion, c.destConnID)
  	c.initialPacketNumber = c.session.closeForRecreating()
  }
  313. func (c *client) createNewTLSSession(_ protocol.VersionNumber) {
  314. c.mutex.Lock()
  315. c.session = newClientSession(
  316. c.conn,
  317. c.packetHandlers,
  318. c.destConnID,
  319. c.srcConnID,
  320. c.config,
  321. c.tlsConf,
  322. c.initialPacketNumber,
  323. c.initialVersion,
  324. c.logger,
  325. c.version,
  326. )
  327. c.mutex.Unlock()
  328. // It's not possible to use the stateless reset token for the client's (first) connection ID,
  329. // since there's no way to securely communicate it to the server.
  330. c.packetHandlers.Add(c.srcConnID, c)
  331. }
  332. func (c *client) Close() error {
  333. c.mutex.Lock()
  334. defer c.mutex.Unlock()
  335. if c.session == nil {
  336. return nil
  337. }
  338. return c.session.Close()
  339. }
  340. func (c *client) destroy(e error) {
  341. c.mutex.Lock()
  342. defer c.mutex.Unlock()
  343. if c.session == nil {
  344. return
  345. }
  346. c.session.destroy(e)
  347. }
  348. func (c *client) GetVersion() protocol.VersionNumber {
  349. c.mutex.Lock()
  350. v := c.version
  351. c.mutex.Unlock()
  352. return v
  353. }
  354. func (c *client) getPerspective() protocol.Perspective {
  355. return protocol.PerspectiveClient
  356. }