// client.go (9.2 KB)
  1. package quic
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "strings"
  9. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  10. "github.com/Psiphon-Labs/quic-go/internal/utils"
  11. "github.com/Psiphon-Labs/quic-go/logging"
  12. )
  13. type client struct {
  14. sconn sendConn
  15. // If the client is created with DialAddr, we create a packet conn.
  16. // If it is started with Dial, we take a packet conn as a parameter.
  17. createdPacketConn bool
  18. use0RTT bool
  19. packetHandlers packetHandlerManager
  20. tlsConf *tls.Config
  21. config *Config
  22. srcConnID protocol.ConnectionID
  23. destConnID protocol.ConnectionID
  24. initialPacketNumber protocol.PacketNumber
  25. hasNegotiatedVersion bool
  26. version protocol.VersionNumber
  27. handshakeChan chan struct{}
  28. conn quicConn
  29. tracer logging.ConnectionTracer
  30. tracingID uint64
  31. logger utils.Logger
  32. }
  // generateConnectionIDForInitial is a package-level indirection over
  // protocol.GenerateConnectionIDForInitial. It exists so that tests can replace
  // the generation of the initial connection ID with a mock.
  var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
  35. // DialAddr establishes a new QUIC connection to a server.
  36. // It uses a new UDP connection and closes this connection when the QUIC connection is closed.
  37. // The hostname for SNI is taken from the given address.
  38. // The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
  39. func DialAddr(
  40. addr string,
  41. tlsConf *tls.Config,
  42. config *Config,
  43. ) (Connection, error) {
  44. return DialAddrContext(context.Background(), addr, tlsConf, config)
  45. }
  46. // DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
  47. // It uses a new UDP connection and closes this connection when the QUIC connection is closed.
  48. // The hostname for SNI is taken from the given address.
  49. // The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
  50. func DialAddrEarly(
  51. addr string,
  52. tlsConf *tls.Config,
  53. config *Config,
  54. ) (EarlyConnection, error) {
  55. return DialAddrEarlyContext(context.Background(), addr, tlsConf, config)
  56. }
  57. // DialAddrEarlyContext establishes a new 0-RTT QUIC connection to a server using provided context.
  58. // See DialAddrEarly for details
  59. func DialAddrEarlyContext(
  60. ctx context.Context,
  61. addr string,
  62. tlsConf *tls.Config,
  63. config *Config,
  64. ) (EarlyConnection, error) {
  65. conn, err := dialAddrContext(ctx, addr, tlsConf, config, true)
  66. if err != nil {
  67. return nil, err
  68. }
  69. utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection")
  70. return conn, nil
  71. }
  72. // DialAddrContext establishes a new QUIC connection to a server using the provided context.
  73. // See DialAddr for details.
  74. func DialAddrContext(
  75. ctx context.Context,
  76. addr string,
  77. tlsConf *tls.Config,
  78. config *Config,
  79. ) (Connection, error) {
  80. return dialAddrContext(ctx, addr, tlsConf, config, false)
  81. }
  82. func dialAddrContext(
  83. ctx context.Context,
  84. addr string,
  85. tlsConf *tls.Config,
  86. config *Config,
  87. use0RTT bool,
  88. ) (quicConn, error) {
  89. udpAddr, err := net.ResolveUDPAddr("udp", addr)
  90. if err != nil {
  91. return nil, err
  92. }
  93. udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
  94. if err != nil {
  95. return nil, err
  96. }
  97. return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true)
  98. }
  99. // Dial establishes a new QUIC connection to a server using a net.PacketConn. If
  100. // the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
  101. // does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
  102. // and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
  103. // packets. The same PacketConn can be used for multiple calls to Dial and
  104. // Listen, QUIC connection IDs are used for demultiplexing the different
  105. // connections. The host parameter is used for SNI. The tls.Config must define
  106. // an application protocol (using NextProtos).
  107. func Dial(
  108. pconn net.PacketConn,
  109. remoteAddr net.Addr,
  110. host string,
  111. tlsConf *tls.Config,
  112. config *Config,
  113. ) (Connection, error) {
  114. return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false)
  115. }
  116. // DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
  117. // The same PacketConn can be used for multiple calls to Dial and Listen,
  118. // QUIC connection IDs are used for demultiplexing the different connections.
  119. // The host parameter is used for SNI.
  120. // The tls.Config must define an application protocol (using NextProtos).
  121. func DialEarly(
  122. pconn net.PacketConn,
  123. remoteAddr net.Addr,
  124. host string,
  125. tlsConf *tls.Config,
  126. config *Config,
  127. ) (EarlyConnection, error) {
  128. return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
  129. }
  130. // DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
  131. // See DialEarly for details.
  132. func DialEarlyContext(
  133. ctx context.Context,
  134. pconn net.PacketConn,
  135. remoteAddr net.Addr,
  136. host string,
  137. tlsConf *tls.Config,
  138. config *Config,
  139. ) (EarlyConnection, error) {
  140. return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false)
  141. }
  142. // DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
  143. // See Dial for details.
  144. func DialContext(
  145. ctx context.Context,
  146. pconn net.PacketConn,
  147. remoteAddr net.Addr,
  148. host string,
  149. tlsConf *tls.Config,
  150. config *Config,
  151. ) (Connection, error) {
  152. return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false)
  153. }
  154. func dialContext(
  155. ctx context.Context,
  156. pconn net.PacketConn,
  157. remoteAddr net.Addr,
  158. host string,
  159. tlsConf *tls.Config,
  160. config *Config,
  161. use0RTT bool,
  162. createdPacketConn bool,
  163. ) (quicConn, error) {
  164. if tlsConf == nil {
  165. return nil, errors.New("quic: tls.Config not set")
  166. }
  167. if err := validateConfig(config); err != nil {
  168. return nil, err
  169. }
  170. config = populateClientConfig(config, createdPacketConn)
  171. packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
  172. if err != nil {
  173. return nil, err
  174. }
  175. c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn)
  176. if err != nil {
  177. return nil, err
  178. }
  179. c.packetHandlers = packetHandlers
  180. c.tracingID = nextConnTracingID()
  181. if c.config.Tracer != nil {
  182. c.tracer = c.config.Tracer.TracerForConnection(
  183. context.WithValue(ctx, ConnectionTracingKey, c.tracingID),
  184. protocol.PerspectiveClient,
  185. c.destConnID,
  186. )
  187. }
  188. if c.tracer != nil {
  189. c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID)
  190. }
  191. if err := c.dial(ctx); err != nil {
  192. return nil, err
  193. }
  194. return c.conn, nil
  195. }
  196. func newClient(
  197. pconn net.PacketConn,
  198. remoteAddr net.Addr,
  199. config *Config,
  200. tlsConf *tls.Config,
  201. host string,
  202. use0RTT bool,
  203. createdPacketConn bool,
  204. ) (*client, error) {
  205. if tlsConf == nil {
  206. tlsConf = &tls.Config{}
  207. } else {
  208. tlsConf = tlsConf.Clone()
  209. }
  210. if tlsConf.ServerName == "" {
  211. sni := host
  212. if strings.IndexByte(sni, ':') != -1 {
  213. var err error
  214. sni, _, err = net.SplitHostPort(sni)
  215. if err != nil {
  216. return nil, err
  217. }
  218. }
  219. tlsConf.ServerName = sni
  220. }
  221. // check that all versions are actually supported
  222. if config != nil {
  223. for _, v := range config.Versions {
  224. if !protocol.IsValidVersion(v) {
  225. return nil, fmt.Errorf("%s is not a valid QUIC version", v)
  226. }
  227. }
  228. }
  229. srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
  230. if err != nil {
  231. return nil, err
  232. }
  233. destConnID, err := generateConnectionIDForInitial()
  234. if err != nil {
  235. return nil, err
  236. }
  237. c := &client{
  238. srcConnID: srcConnID,
  239. destConnID: destConnID,
  240. sconn: newSendPconn(pconn, remoteAddr),
  241. createdPacketConn: createdPacketConn,
  242. use0RTT: use0RTT,
  243. tlsConf: tlsConf,
  244. config: config,
  245. version: config.Versions[0],
  246. handshakeChan: make(chan struct{}),
  247. logger: utils.DefaultLogger.WithPrefix("client"),
  248. }
  249. return c, nil
  250. }
  251. func (c *client) dial(ctx context.Context) error {
  252. c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
  253. c.conn = newClientConnection(
  254. c.sconn,
  255. c.packetHandlers,
  256. c.destConnID,
  257. c.srcConnID,
  258. c.config,
  259. c.tlsConf,
  260. c.initialPacketNumber,
  261. c.use0RTT,
  262. c.hasNegotiatedVersion,
  263. c.tracer,
  264. c.tracingID,
  265. c.logger,
  266. c.version,
  267. )
  268. c.packetHandlers.Add(c.srcConnID, c.conn)
  269. errorChan := make(chan error, 1)
  270. go func() {
  271. err := c.conn.run() // returns as soon as the connection is closed
  272. if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn {
  273. c.packetHandlers.Destroy()
  274. }
  275. errorChan <- err
  276. }()
  277. // only set when we're using 0-RTT
  278. // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
  279. var earlyConnChan <-chan struct{}
  280. if c.use0RTT {
  281. earlyConnChan = c.conn.earlyConnReady()
  282. }
  283. select {
  284. case <-ctx.Done():
  285. c.conn.shutdown()
  286. return ctx.Err()
  287. case err := <-errorChan:
  288. var recreateErr *errCloseForRecreating
  289. if errors.As(err, &recreateErr) {
  290. c.initialPacketNumber = recreateErr.nextPacketNumber
  291. c.version = recreateErr.nextVersion
  292. c.hasNegotiatedVersion = true
  293. return c.dial(ctx)
  294. }
  295. return err
  296. case <-earlyConnChan:
  297. // ready to send 0-RTT data
  298. return nil
  299. case <-c.conn.HandshakeComplete().Done():
  300. // handshake successfully completed
  301. return nil
  302. }
  303. }