client.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. package quic
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "strings"
  9. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  10. "github.com/Psiphon-Labs/quic-go/internal/utils"
  11. "github.com/Psiphon-Labs/quic-go/logging"
  12. )
  13. type client struct {
  14. conn sendConn
  15. // If the client is created with DialAddr, we create a packet conn.
  16. // If it is started with Dial, we take a packet conn as a parameter.
  17. createdPacketConn bool
  18. use0RTT bool
  19. packetHandlers packetHandlerManager
  20. tlsConf *tls.Config
  21. config *Config
  22. srcConnID protocol.ConnectionID
  23. destConnID protocol.ConnectionID
  24. initialPacketNumber protocol.PacketNumber
  25. hasNegotiatedVersion bool
  26. version protocol.VersionNumber
  27. handshakeChan chan struct{}
  28. session quicSession
  29. tracer logging.ConnectionTracer
  30. tracingID uint64
  31. logger utils.Logger
  32. }
  33. var (
  34. // make it possible to mock connection ID generation in the tests
  35. generateConnectionID = protocol.GenerateConnectionID
  36. generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
  37. )
  38. // DialAddr establishes a new QUIC connection to a server.
  39. // It uses a new UDP connection and closes this connection when the QUIC session is closed.
  40. // The hostname for SNI is taken from the given address.
  41. // The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
  42. func DialAddr(
  43. addr string,
  44. tlsConf *tls.Config,
  45. config *Config,
  46. ) (Session, error) {
  47. return DialAddrContext(context.Background(), addr, tlsConf, config)
  48. }
  49. // DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
  50. // It uses a new UDP connection and closes this connection when the QUIC session is closed.
  51. // The hostname for SNI is taken from the given address.
  52. // The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
  53. func DialAddrEarly(
  54. addr string,
  55. tlsConf *tls.Config,
  56. config *Config,
  57. ) (EarlySession, error) {
  58. return DialAddrEarlyContext(context.Background(), addr, tlsConf, config)
  59. }
  60. // DialAddrEarlyContext establishes a new 0-RTT QUIC connection to a server using provided context.
  61. // See DialAddrEarly for details
  62. func DialAddrEarlyContext(
  63. ctx context.Context,
  64. addr string,
  65. tlsConf *tls.Config,
  66. config *Config,
  67. ) (EarlySession, error) {
  68. sess, err := dialAddrContext(ctx, addr, tlsConf, config, true)
  69. if err != nil {
  70. return nil, err
  71. }
  72. utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early session")
  73. return sess, nil
  74. }
  75. // DialAddrContext establishes a new QUIC connection to a server using the provided context.
  76. // See DialAddr for details.
  77. func DialAddrContext(
  78. ctx context.Context,
  79. addr string,
  80. tlsConf *tls.Config,
  81. config *Config,
  82. ) (Session, error) {
  83. return dialAddrContext(ctx, addr, tlsConf, config, false)
  84. }
  85. func dialAddrContext(
  86. ctx context.Context,
  87. addr string,
  88. tlsConf *tls.Config,
  89. config *Config,
  90. use0RTT bool,
  91. ) (quicSession, error) {
  92. udpAddr, err := net.ResolveUDPAddr("udp", addr)
  93. if err != nil {
  94. return nil, err
  95. }
  96. udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
  97. if err != nil {
  98. return nil, err
  99. }
  100. return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true)
  101. }
  102. // Dial establishes a new QUIC connection to a server using a net.PacketConn. If
  103. // the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
  104. // does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
  105. // and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
  106. // packets. The same PacketConn can be used for multiple calls to Dial and
  107. // Listen, QUIC connection IDs are used for demultiplexing the different
  108. // connections. The host parameter is used for SNI. The tls.Config must define
  109. // an application protocol (using NextProtos).
  110. func Dial(
  111. pconn net.PacketConn,
  112. remoteAddr net.Addr,
  113. host string,
  114. tlsConf *tls.Config,
  115. config *Config,
  116. ) (Session, error) {
  117. return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false)
  118. }
  119. // DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
  120. // The same PacketConn can be used for multiple calls to Dial and Listen,
  121. // QUIC connection IDs are used for demultiplexing the different connections.
  122. // The host parameter is used for SNI.
  123. // The tls.Config must define an application protocol (using NextProtos).
  124. func DialEarly(
  125. pconn net.PacketConn,
  126. remoteAddr net.Addr,
  127. host string,
  128. tlsConf *tls.Config,
  129. config *Config,
  130. ) (EarlySession, error) {
  131. return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
  132. }
  133. // DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
  134. // See DialEarly for details.
  135. func DialEarlyContext(
  136. ctx context.Context,
  137. pconn net.PacketConn,
  138. remoteAddr net.Addr,
  139. host string,
  140. tlsConf *tls.Config,
  141. config *Config,
  142. ) (EarlySession, error) {
  143. return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false)
  144. }
  145. // DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
  146. // See Dial for details.
  147. func DialContext(
  148. ctx context.Context,
  149. pconn net.PacketConn,
  150. remoteAddr net.Addr,
  151. host string,
  152. tlsConf *tls.Config,
  153. config *Config,
  154. ) (Session, error) {
  155. return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false)
  156. }
  157. func dialContext(
  158. ctx context.Context,
  159. pconn net.PacketConn,
  160. remoteAddr net.Addr,
  161. host string,
  162. tlsConf *tls.Config,
  163. config *Config,
  164. use0RTT bool,
  165. createdPacketConn bool,
  166. ) (quicSession, error) {
  167. if tlsConf == nil {
  168. return nil, errors.New("quic: tls.Config not set")
  169. }
  170. if err := validateConfig(config); err != nil {
  171. return nil, err
  172. }
  173. config = populateClientConfig(config, createdPacketConn)
  174. packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer)
  175. if err != nil {
  176. return nil, err
  177. }
  178. c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn)
  179. if err != nil {
  180. return nil, err
  181. }
  182. c.packetHandlers = packetHandlers
  183. c.tracingID = nextSessionTracingID()
  184. if c.config.Tracer != nil {
  185. c.tracer = c.config.Tracer.TracerForConnection(
  186. context.WithValue(ctx, SessionTracingKey, c.tracingID),
  187. protocol.PerspectiveClient,
  188. c.destConnID,
  189. )
  190. }
  191. if c.tracer != nil {
  192. c.tracer.StartedConnection(c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID)
  193. }
  194. if err := c.dial(ctx); err != nil {
  195. return nil, err
  196. }
  197. return c.session, nil
  198. }
  199. func newClient(
  200. pconn net.PacketConn,
  201. remoteAddr net.Addr,
  202. config *Config,
  203. tlsConf *tls.Config,
  204. host string,
  205. use0RTT bool,
  206. createdPacketConn bool,
  207. ) (*client, error) {
  208. if tlsConf == nil {
  209. tlsConf = &tls.Config{}
  210. }
  211. if tlsConf.ServerName == "" {
  212. sni := host
  213. if strings.IndexByte(sni, ':') != -1 {
  214. var err error
  215. sni, _, err = net.SplitHostPort(sni)
  216. if err != nil {
  217. return nil, err
  218. }
  219. }
  220. tlsConf.ServerName = sni
  221. }
  222. // check that all versions are actually supported
  223. if config != nil {
  224. for _, v := range config.Versions {
  225. if !protocol.IsValidVersion(v) {
  226. return nil, fmt.Errorf("%s is not a valid QUIC version", v)
  227. }
  228. }
  229. }
  230. srcConnID, err := generateConnectionID(config.ConnectionIDLength)
  231. if err != nil {
  232. return nil, err
  233. }
  234. destConnID, err := generateConnectionIDForInitial()
  235. if err != nil {
  236. return nil, err
  237. }
  238. c := &client{
  239. srcConnID: srcConnID,
  240. destConnID: destConnID,
  241. conn: newSendPconn(pconn, remoteAddr),
  242. createdPacketConn: createdPacketConn,
  243. use0RTT: use0RTT,
  244. tlsConf: tlsConf,
  245. config: config,
  246. version: config.Versions[0],
  247. handshakeChan: make(chan struct{}),
  248. logger: utils.DefaultLogger.WithPrefix("client"),
  249. }
  250. return c, nil
  251. }
  252. func (c *client) dial(ctx context.Context) error {
  253. c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
  254. c.session = newClientSession(
  255. c.conn,
  256. c.packetHandlers,
  257. c.destConnID,
  258. c.srcConnID,
  259. c.config,
  260. c.tlsConf,
  261. c.initialPacketNumber,
  262. c.use0RTT,
  263. c.hasNegotiatedVersion,
  264. c.tracer,
  265. c.tracingID,
  266. c.logger,
  267. c.version,
  268. )
  269. c.packetHandlers.Add(c.srcConnID, c.session)
  270. errorChan := make(chan error, 1)
  271. go func() {
  272. err := c.session.run() // returns as soon as the session is closed
  273. if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn {
  274. c.packetHandlers.Destroy()
  275. }
  276. errorChan <- err
  277. }()
  278. // only set when we're using 0-RTT
  279. // Otherwise, earlySessionChan will be nil. Receiving from a nil chan blocks forever.
  280. var earlySessionChan <-chan struct{}
  281. if c.use0RTT {
  282. earlySessionChan = c.session.earlySessionReady()
  283. }
  284. select {
  285. case <-ctx.Done():
  286. c.session.shutdown()
  287. return ctx.Err()
  288. case err := <-errorChan:
  289. var recreateErr *errCloseForRecreating
  290. if errors.As(err, &recreateErr) {
  291. c.initialPacketNumber = recreateErr.nextPacketNumber
  292. c.version = recreateErr.nextVersion
  293. c.hasNegotiatedVersion = true
  294. return c.dial(ctx)
  295. }
  296. return err
  297. case <-earlySessionChan:
  298. // ready to send 0-RTT data
  299. return nil
  300. case <-c.session.HandshakeComplete().Done():
  301. // handshake successfully completed
  302. return nil
  303. }
  304. }