client.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. package quic
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  9. "github.com/Psiphon-Labs/quic-go/internal/utils"
  10. "github.com/Psiphon-Labs/quic-go/logging"
  11. )
  12. type client struct {
  13. sconn sendConn
  14. // If the client is created with DialAddr, we create a packet conn.
  15. // If it is started with Dial, we take a packet conn as a parameter.
  16. createdPacketConn bool
  17. use0RTT bool
  18. packetHandlers packetHandlerManager
  19. tlsConf *tls.Config
  20. config *Config
  21. srcConnID protocol.ConnectionID
  22. destConnID protocol.ConnectionID
  23. initialPacketNumber protocol.PacketNumber
  24. hasNegotiatedVersion bool
  25. version protocol.VersionNumber
  26. handshakeChan chan struct{}
  27. conn quicConn
  28. tracer logging.ConnectionTracer
  29. tracingID uint64
  30. logger utils.Logger
  31. }
  32. // make it possible to mock connection ID for initial generation in the tests
  33. var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
  34. // DialAddr establishes a new QUIC connection to a server.
  35. // It uses a new UDP connection and closes this connection when the QUIC connection is closed.
  36. // The hostname for SNI is taken from the given address.
  37. // The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
  38. func DialAddr(
  39. addr string,
  40. tlsConf *tls.Config,
  41. config *Config,
  42. ) (Connection, error) {
  43. return DialAddrContext(context.Background(), addr, tlsConf, config)
  44. }
  45. // DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
  46. // It uses a new UDP connection and closes this connection when the QUIC connection is closed.
  47. // The hostname for SNI is taken from the given address.
  48. // The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
  49. func DialAddrEarly(
  50. addr string,
  51. tlsConf *tls.Config,
  52. config *Config,
  53. ) (EarlyConnection, error) {
  54. return DialAddrEarlyContext(context.Background(), addr, tlsConf, config)
  55. }
  56. // DialAddrEarlyContext establishes a new 0-RTT QUIC connection to a server using provided context.
  57. // See DialAddrEarly for details
  58. func DialAddrEarlyContext(
  59. ctx context.Context,
  60. addr string,
  61. tlsConf *tls.Config,
  62. config *Config,
  63. ) (EarlyConnection, error) {
  64. conn, err := dialAddrContext(ctx, addr, tlsConf, config, true)
  65. if err != nil {
  66. return nil, err
  67. }
  68. utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection")
  69. return conn, nil
  70. }
  71. // DialAddrContext establishes a new QUIC connection to a server using the provided context.
  72. // See DialAddr for details.
  73. func DialAddrContext(
  74. ctx context.Context,
  75. addr string,
  76. tlsConf *tls.Config,
  77. config *Config,
  78. ) (Connection, error) {
  79. return dialAddrContext(ctx, addr, tlsConf, config, false)
  80. }
  81. func dialAddrContext(
  82. ctx context.Context,
  83. addr string,
  84. tlsConf *tls.Config,
  85. config *Config,
  86. use0RTT bool,
  87. ) (quicConn, error) {
  88. udpAddr, err := net.ResolveUDPAddr("udp", addr)
  89. if err != nil {
  90. return nil, err
  91. }
  92. udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
  93. if err != nil {
  94. return nil, err
  95. }
  96. return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true)
  97. }
  98. // Dial establishes a new QUIC connection to a server using a net.PacketConn. If
  99. // the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
  100. // does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
  101. // and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
  102. // packets. The same PacketConn can be used for multiple calls to Dial and
  103. // Listen, QUIC connection IDs are used for demultiplexing the different
  104. // connections. The host parameter is used for SNI. The tls.Config must define
  105. // an application protocol (using NextProtos).
  106. func Dial(
  107. pconn net.PacketConn,
  108. remoteAddr net.Addr,
  109. host string,
  110. tlsConf *tls.Config,
  111. config *Config,
  112. ) (Connection, error) {
  113. return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false)
  114. }
  115. // DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
  116. // The same PacketConn can be used for multiple calls to Dial and Listen,
  117. // QUIC connection IDs are used for demultiplexing the different connections.
  118. // The host parameter is used for SNI.
  119. // The tls.Config must define an application protocol (using NextProtos).
  120. func DialEarly(
  121. pconn net.PacketConn,
  122. remoteAddr net.Addr,
  123. host string,
  124. tlsConf *tls.Config,
  125. config *Config,
  126. ) (EarlyConnection, error) {
  127. return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
  128. }
  129. // DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
  130. // See DialEarly for details.
  131. func DialEarlyContext(
  132. ctx context.Context,
  133. pconn net.PacketConn,
  134. remoteAddr net.Addr,
  135. host string,
  136. tlsConf *tls.Config,
  137. config *Config,
  138. ) (EarlyConnection, error) {
  139. return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false)
  140. }
  141. // DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
  142. // See Dial for details.
  143. func DialContext(
  144. ctx context.Context,
  145. pconn net.PacketConn,
  146. remoteAddr net.Addr,
  147. host string,
  148. tlsConf *tls.Config,
  149. config *Config,
  150. ) (Connection, error) {
  151. return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false)
  152. }
  153. func dialContext(
  154. ctx context.Context,
  155. pconn net.PacketConn,
  156. remoteAddr net.Addr,
  157. host string,
  158. tlsConf *tls.Config,
  159. config *Config,
  160. use0RTT bool,
  161. createdPacketConn bool,
  162. ) (quicConn, error) {
  163. if tlsConf == nil {
  164. return nil, errors.New("quic: tls.Config not set")
  165. }
  166. if err := validateConfig(config); err != nil {
  167. return nil, err
  168. }
  169. config = populateClientConfig(config, createdPacketConn)
  170. packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
  171. if err != nil {
  172. return nil, err
  173. }
  174. c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn)
  175. if err != nil {
  176. return nil, err
  177. }
  178. c.packetHandlers = packetHandlers
  179. c.tracingID = nextConnTracingID()
  180. if c.config.Tracer != nil {
  181. c.tracer = c.config.Tracer.TracerForConnection(
  182. context.WithValue(ctx, ConnectionTracingKey, c.tracingID),
  183. protocol.PerspectiveClient,
  184. c.destConnID,
  185. )
  186. }
  187. if c.tracer != nil {
  188. c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID)
  189. }
  190. if err := c.dial(ctx); err != nil {
  191. return nil, err
  192. }
  193. return c.conn, nil
  194. }
  195. func newClient(
  196. pconn net.PacketConn,
  197. remoteAddr net.Addr,
  198. config *Config,
  199. tlsConf *tls.Config,
  200. host string,
  201. use0RTT bool,
  202. createdPacketConn bool,
  203. ) (*client, error) {
  204. if tlsConf == nil {
  205. tlsConf = &tls.Config{}
  206. } else {
  207. tlsConf = tlsConf.Clone()
  208. }
  209. if tlsConf.ServerName == "" {
  210. sni, _, err := net.SplitHostPort(host)
  211. if err != nil {
  212. // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
  213. sni = host
  214. }
  215. tlsConf.ServerName = sni
  216. }
  217. // check that all versions are actually supported
  218. if config != nil {
  219. for _, v := range config.Versions {
  220. if !protocol.IsValidVersion(v) {
  221. return nil, fmt.Errorf("%s is not a valid QUIC version", v)
  222. }
  223. }
  224. }
  225. srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
  226. if err != nil {
  227. return nil, err
  228. }
  229. destConnID, err := generateConnectionIDForInitial()
  230. if err != nil {
  231. return nil, err
  232. }
  233. c := &client{
  234. srcConnID: srcConnID,
  235. destConnID: destConnID,
  236. sconn: newSendPconn(pconn, remoteAddr),
  237. createdPacketConn: createdPacketConn,
  238. use0RTT: use0RTT,
  239. tlsConf: tlsConf,
  240. config: config,
  241. version: config.Versions[0],
  242. handshakeChan: make(chan struct{}),
  243. logger: utils.DefaultLogger.WithPrefix("client"),
  244. }
  245. return c, nil
  246. }
  247. func (c *client) dial(ctx context.Context) error {
  248. c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
  249. c.conn = newClientConnection(
  250. c.sconn,
  251. c.packetHandlers,
  252. c.destConnID,
  253. c.srcConnID,
  254. c.config,
  255. c.tlsConf,
  256. c.initialPacketNumber,
  257. c.use0RTT,
  258. c.hasNegotiatedVersion,
  259. c.tracer,
  260. c.tracingID,
  261. c.logger,
  262. c.version,
  263. )
  264. c.packetHandlers.Add(c.srcConnID, c.conn)
  265. errorChan := make(chan error, 1)
  266. go func() {
  267. err := c.conn.run() // returns as soon as the connection is closed
  268. if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn {
  269. c.packetHandlers.Destroy()
  270. }
  271. errorChan <- err
  272. }()
  273. // only set when we're using 0-RTT
  274. // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
  275. var earlyConnChan <-chan struct{}
  276. if c.use0RTT {
  277. earlyConnChan = c.conn.earlyConnReady()
  278. }
  279. select {
  280. case <-ctx.Done():
  281. c.conn.shutdown()
  282. return ctx.Err()
  283. case err := <-errorChan:
  284. var recreateErr *errCloseForRecreating
  285. if errors.As(err, &recreateErr) {
  286. c.initialPacketNumber = recreateErr.nextPacketNumber
  287. c.version = recreateErr.nextVersion
  288. c.hasNegotiatedVersion = true
  289. return c.dial(ctx)
  290. }
  291. return err
  292. case <-earlyConnChan:
  293. // ready to send 0-RTT data
  294. return nil
  295. case <-c.conn.HandshakeComplete().Done():
  296. // handshake successfully completed
  297. return nil
  298. }
  299. }