conn.go 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916
  1. package mint
  2. import (
  3. "crypto"
  4. "crypto/x509"
  5. "encoding/hex"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net"
  10. "reflect"
  11. "sync"
  12. "time"
  13. )
  14. type Certificate struct {
  15. Chain []*x509.Certificate
  16. PrivateKey crypto.Signer
  17. }
  18. type PreSharedKey struct {
  19. CipherSuite CipherSuite
  20. IsResumption bool
  21. Identity []byte
  22. Key []byte
  23. NextProto string
  24. ReceivedAt time.Time
  25. ExpiresAt time.Time
  26. TicketAgeAdd uint32
  27. }
  28. type PreSharedKeyCache interface {
  29. Get(string) (PreSharedKey, bool)
  30. Put(string, PreSharedKey)
  31. Size() int
  32. }
  33. // A CookieHandler can be used to give the application more fine-grained control over Cookies.
  34. // Generate receives the Conn as an argument, so the CookieHandler can decide when to send the cookie based on that, and offload state to the client by encoding that into the Cookie.
  35. // When the client echoes the Cookie, Validate is called. The application can then recover the state from the cookie.
  36. type CookieHandler interface {
  37. // Generate a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
  38. // If Generate returns nil, mint will not send a HelloRetryRequest.
  39. Generate(*Conn) ([]byte, error)
  40. // Validate is called when receiving a ClientHello containing a Cookie.
  41. // If validation failed, the handshake is aborted.
  42. Validate(*Conn, []byte) bool
  43. }
  44. type PSKMapCache map[string]PreSharedKey
  45. func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
  46. psk, ok = cache[key]
  47. return
  48. }
  49. func (cache *PSKMapCache) Put(key string, psk PreSharedKey) {
  50. (*cache)[key] = psk
  51. }
  52. func (cache PSKMapCache) Size() int {
  53. return len(cache)
  54. }
  55. // Config is the struct used to pass configuration settings to a TLS client or
  56. // server instance. The settings for client and server are pretty different,
  57. // but we just throw them all in here.
  58. type Config struct {
  59. // Client fields
  60. ServerName string
  61. // Server fields
  62. SendSessionTickets bool
  63. TicketLifetime uint32
  64. TicketLen int
  65. EarlyDataLifetime uint32
  66. AllowEarlyData bool
  67. // Require the client to echo a cookie.
  68. RequireCookie bool
  69. // A CookieHandler can be used to set and validate a cookie.
  70. // The cookie returned by the CookieHandler will be part of the cookie sent on the wire, and encoded using the CookieProtector.
  71. // If no CookieHandler is set, mint will always send a cookie.
  72. // The CookieHandler can be used to decide on a per-connection basis, if a cookie should be sent.
  73. CookieHandler CookieHandler
  74. // The CookieProtector is used to encrypt / decrypt cookies.
  75. // It should make sure that the Cookie cannot be read and tampered with by the client.
  76. // If non-blocking mode is used, and cookies are required, this field has to be set.
  77. // In blocking mode, a default cookie protector is used, if this is unused.
  78. CookieProtector CookieProtector
  79. // The ExtensionHandler is used to add custom extensions.
  80. ExtensionHandler AppExtensionHandler
  81. RequireClientAuth bool
  82. // Time returns the current time as the number of seconds since the epoch.
  83. // If Time is nil, TLS uses time.Now.
  84. Time func() time.Time
  85. // RootCAs defines the set of root certificate authorities
  86. // that clients use when verifying server certificates.
  87. // If RootCAs is nil, TLS uses the host's root CA set.
  88. RootCAs *x509.CertPool
  89. // InsecureSkipVerify controls whether a client verifies the
  90. // server's certificate chain and host name.
  91. // If InsecureSkipVerify is true, TLS accepts any certificate
  92. // presented by the server and any host name in that certificate.
  93. // In this mode, TLS is susceptible to man-in-the-middle attacks.
  94. // This should be used only for testing.
  95. InsecureSkipVerify bool
  96. // Shared fields
  97. Certificates []*Certificate
  98. // VerifyPeerCertificate, if not nil, is called after normal
  99. // certificate verification by either a TLS client or server. It
  100. // receives the raw ASN.1 certificates provided by the peer and also
  101. // any verified chains that normal processing found. If it returns a
  102. // non-nil error, the handshake is aborted and that error results.
  103. //
  104. // If normal verification fails then the handshake will abort before
  105. // considering this callback. If normal verification is disabled by
  106. // setting InsecureSkipVerify then this callback will be considered but
  107. // the verifiedChains argument will always be nil.
  108. VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
  109. CipherSuites []CipherSuite
  110. Groups []NamedGroup
  111. SignatureSchemes []SignatureScheme
  112. NextProtos []string
  113. PSKs PreSharedKeyCache
  114. PSKModes []PSKKeyExchangeMode
  115. NonBlocking bool
  116. UseDTLS bool
  117. // The same config object can be shared among different connections, so it
  118. // needs its own mutex
  119. mutex sync.RWMutex
  120. }
  121. // Clone returns a shallow clone of c. It is safe to clone a Config that is
  122. // being used concurrently by a TLS client or server.
  123. func (c *Config) Clone() *Config {
  124. c.mutex.Lock()
  125. defer c.mutex.Unlock()
  126. return &Config{
  127. ServerName: c.ServerName,
  128. SendSessionTickets: c.SendSessionTickets,
  129. TicketLifetime: c.TicketLifetime,
  130. TicketLen: c.TicketLen,
  131. EarlyDataLifetime: c.EarlyDataLifetime,
  132. AllowEarlyData: c.AllowEarlyData,
  133. RequireCookie: c.RequireCookie,
  134. CookieHandler: c.CookieHandler,
  135. CookieProtector: c.CookieProtector,
  136. ExtensionHandler: c.ExtensionHandler,
  137. RequireClientAuth: c.RequireClientAuth,
  138. Time: c.Time,
  139. RootCAs: c.RootCAs,
  140. InsecureSkipVerify: c.InsecureSkipVerify,
  141. Certificates: c.Certificates,
  142. VerifyPeerCertificate: c.VerifyPeerCertificate,
  143. CipherSuites: c.CipherSuites,
  144. Groups: c.Groups,
  145. SignatureSchemes: c.SignatureSchemes,
  146. NextProtos: c.NextProtos,
  147. PSKs: c.PSKs,
  148. PSKModes: c.PSKModes,
  149. NonBlocking: c.NonBlocking,
  150. UseDTLS: c.UseDTLS,
  151. }
  152. }
  153. func (c *Config) Init(isClient bool) error {
  154. c.mutex.Lock()
  155. defer c.mutex.Unlock()
  156. // Set defaults
  157. if len(c.CipherSuites) == 0 {
  158. c.CipherSuites = defaultSupportedCipherSuites
  159. }
  160. if len(c.Groups) == 0 {
  161. c.Groups = defaultSupportedGroups
  162. }
  163. if len(c.SignatureSchemes) == 0 {
  164. c.SignatureSchemes = defaultSignatureSchemes
  165. }
  166. if c.TicketLen == 0 {
  167. c.TicketLen = defaultTicketLen
  168. }
  169. if !reflect.ValueOf(c.PSKs).IsValid() {
  170. c.PSKs = &PSKMapCache{}
  171. }
  172. if len(c.PSKModes) == 0 {
  173. c.PSKModes = defaultPSKModes
  174. }
  175. return nil
  176. }
  177. func (c *Config) ValidForServer() bool {
  178. return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) ||
  179. (len(c.Certificates) > 0 &&
  180. len(c.Certificates[0].Chain) > 0 &&
  181. c.Certificates[0].PrivateKey != nil)
  182. }
  183. func (c *Config) ValidForClient() bool {
  184. return len(c.ServerName) > 0
  185. }
  186. func (c *Config) time() time.Time {
  187. t := c.Time
  188. if t == nil {
  189. t = time.Now
  190. }
  191. return t()
  192. }
  193. var (
  194. defaultSupportedCipherSuites = []CipherSuite{
  195. TLS_AES_128_GCM_SHA256,
  196. TLS_AES_256_GCM_SHA384,
  197. }
  198. defaultSupportedGroups = []NamedGroup{
  199. P256,
  200. P384,
  201. FFDHE2048,
  202. X25519,
  203. }
  204. defaultSignatureSchemes = []SignatureScheme{
  205. RSA_PSS_SHA256,
  206. RSA_PSS_SHA384,
  207. RSA_PSS_SHA512,
  208. ECDSA_P256_SHA256,
  209. ECDSA_P384_SHA384,
  210. ECDSA_P521_SHA512,
  211. }
  212. defaultTicketLen = 16
  213. defaultPSKModes = []PSKKeyExchangeMode{
  214. PSKModeKE,
  215. PSKModeDHEKE,
  216. }
  217. )
  218. type ConnectionState struct {
  219. HandshakeState State
  220. CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
  221. PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
  222. VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates
  223. NextProto string // Selected ALPN proto
  224. UsingPSK bool // Are we using PSK.
  225. UsingEarlyData bool // Did we negotiate 0-RTT.
  226. }
  227. // Conn implements the net.Conn interface, as with "crypto/tls"
  228. // * Read, Write, and Close are provided locally
  229. // * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
  230. type Conn struct {
  231. config *Config
  232. conn net.Conn
  233. isClient bool
  234. state stateConnected
  235. hState HandshakeState
  236. handshakeMutex sync.Mutex
  237. handshakeAlert Alert
  238. handshakeComplete bool
  239. readBuffer []byte
  240. in, out *RecordLayer
  241. hsCtx *HandshakeContext
  242. }
  243. func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
  244. c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}}
  245. if !config.UseDTLS {
  246. c.in = NewRecordLayerTLS(c.conn, directionRead)
  247. c.out = NewRecordLayerTLS(c.conn, directionWrite)
  248. c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in)
  249. c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out)
  250. } else {
  251. c.in = NewRecordLayerDTLS(c.conn, directionRead)
  252. c.out = NewRecordLayerDTLS(c.conn, directionWrite)
  253. c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in)
  254. c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out)
  255. c.hsCtx.timeoutMS = initialTimeout
  256. c.hsCtx.timers = newTimerSet()
  257. c.hsCtx.waitingNextFlight = true
  258. }
  259. c.in.label = c.label()
  260. c.out.label = c.label()
  261. c.hsCtx.hIn.nonblocking = c.config.NonBlocking
  262. return c
  263. }
  264. // Read up
  265. func (c *Conn) consumeRecord() error {
  266. pt, err := c.in.ReadRecord()
  267. if pt == nil {
  268. logf(logTypeIO, "extendBuffer returns error %v", err)
  269. return err
  270. }
  271. switch pt.contentType {
  272. case RecordTypeHandshake:
  273. logf(logTypeHandshake, "Received post-handshake message")
  274. // We do not support fragmentation of post-handshake handshake messages.
  275. // TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
  276. start := 0
  277. headerLen := handshakeHeaderLenTLS
  278. if c.config.UseDTLS {
  279. headerLen = handshakeHeaderLenDTLS
  280. }
  281. for start < len(pt.fragment) {
  282. if len(pt.fragment[start:]) < headerLen {
  283. return fmt.Errorf("Post-handshake handshake message too short for header")
  284. }
  285. hm := &HandshakeMessage{}
  286. hm.msgType = HandshakeType(pt.fragment[start])
  287. hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
  288. if len(pt.fragment[start+headerLen:]) < hmLen {
  289. return fmt.Errorf("Post-handshake handshake message too short for body")
  290. }
  291. hm.body = pt.fragment[start+headerLen : start+headerLen+hmLen]
  292. // XXX: If we want to support more advanced cases, e.g., post-handshake
  293. // authentication, we'll need to allow transitions other than
  294. // Connected -> Connected
  295. state, actions, alert := c.state.ProcessMessage(hm)
  296. if alert != AlertNoAlert {
  297. logf(logTypeHandshake, "Error in state transition: %v", alert)
  298. c.sendAlert(alert)
  299. return io.EOF
  300. }
  301. for _, action := range actions {
  302. alert = c.takeAction(action)
  303. if alert != AlertNoAlert {
  304. logf(logTypeHandshake, "Error during handshake actions: %v", alert)
  305. c.sendAlert(alert)
  306. return io.EOF
  307. }
  308. }
  309. var connected bool
  310. c.state, connected = state.(stateConnected)
  311. if !connected {
  312. logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
  313. c.sendAlert(alert)
  314. return io.EOF
  315. }
  316. start += headerLen + hmLen
  317. }
  318. case RecordTypeAlert:
  319. logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
  320. if len(pt.fragment) != 2 {
  321. c.sendAlert(AlertUnexpectedMessage)
  322. return io.EOF
  323. }
  324. if Alert(pt.fragment[1]) == AlertCloseNotify {
  325. return io.EOF
  326. }
  327. switch pt.fragment[0] {
  328. case AlertLevelWarning:
  329. // drop on the floor
  330. case AlertLevelError:
  331. return Alert(pt.fragment[1])
  332. default:
  333. c.sendAlert(AlertUnexpectedMessage)
  334. return io.EOF
  335. }
  336. case RecordTypeAck:
  337. if !c.hsCtx.hIn.datagram {
  338. logf(logTypeHandshake, "Received ACK in TLS mode")
  339. return AlertUnexpectedMessage
  340. }
  341. return c.hsCtx.processAck(pt.fragment)
  342. case RecordTypeApplicationData:
  343. c.readBuffer = append(c.readBuffer, pt.fragment...)
  344. logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
  345. }
  346. return err
  347. }
  348. func readPartial(in *[]byte, buffer []byte) int {
  349. logf(logTypeIO, "conn.Read input buffer now has len %d", len((*in)))
  350. read := copy(buffer, *in)
  351. *in = (*in)[read:]
  352. logf(logTypeVerbose, "Returning %v", string(buffer))
  353. return read
  354. }
  355. // Read application data up to the size of buffer. Handshake and alert records
  356. // are consumed by the Conn object directly.
  357. func (c *Conn) Read(buffer []byte) (int, error) {
  358. if _, connected := c.hState.(stateConnected); !connected {
  359. // Clients can't call Read prior to handshake completion.
  360. if c.isClient {
  361. return 0, errors.New("Read called before the handshake completed")
  362. }
  363. // Neither can servers that don't allow early data.
  364. if !c.config.AllowEarlyData {
  365. return 0, errors.New("Read called before the handshake completed")
  366. }
  367. // If there's no early data, then return WouldBlock
  368. if len(c.hsCtx.earlyData) == 0 {
  369. return 0, AlertWouldBlock
  370. }
  371. return readPartial(&c.hsCtx.earlyData, buffer), nil
  372. }
  373. // The handshake is now connected.
  374. logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
  375. if alert := c.Handshake(); alert != AlertNoAlert {
  376. return 0, alert
  377. }
  378. if len(buffer) == 0 {
  379. return 0, nil
  380. }
  381. // Run our timers.
  382. if c.config.UseDTLS {
  383. if err := c.hsCtx.timers.check(time.Now()); err != nil {
  384. return 0, AlertInternalError
  385. }
  386. }
  387. // Lock the input channel
  388. c.in.Lock()
  389. defer c.in.Unlock()
  390. for len(c.readBuffer) == 0 {
  391. err := c.consumeRecord()
  392. // err can be nil if consumeRecord processed a non app-data
  393. // record.
  394. if err != nil {
  395. if c.config.NonBlocking || err != AlertWouldBlock {
  396. logf(logTypeIO, "conn.Read returns err=%v", err)
  397. return 0, err
  398. }
  399. }
  400. }
  401. return readPartial(&c.readBuffer, buffer), nil
  402. }
  403. // Write application data
  404. func (c *Conn) Write(buffer []byte) (int, error) {
  405. // Lock the output channel
  406. c.out.Lock()
  407. defer c.out.Unlock()
  408. if !c.Writable() {
  409. return 0, errors.New("Write called before the handshake completed (and early data not in use)")
  410. }
  411. // Send full-size fragments
  412. var start int
  413. sent := 0
  414. for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
  415. err := c.out.WriteRecord(&TLSPlaintext{
  416. contentType: RecordTypeApplicationData,
  417. fragment: buffer[start : start+maxFragmentLen],
  418. })
  419. if err != nil {
  420. return sent, err
  421. }
  422. sent += maxFragmentLen
  423. }
  424. // Send a final partial fragment if necessary
  425. if start < len(buffer) {
  426. err := c.out.WriteRecord(&TLSPlaintext{
  427. contentType: RecordTypeApplicationData,
  428. fragment: buffer[start:],
  429. })
  430. if err != nil {
  431. return sent, err
  432. }
  433. sent += len(buffer[start:])
  434. }
  435. return sent, nil
  436. }
  437. // sendAlert sends a TLS alert message.
  438. // c.out.Mutex <= L.
  439. func (c *Conn) sendAlert(err Alert) error {
  440. c.handshakeMutex.Lock()
  441. defer c.handshakeMutex.Unlock()
  442. var level int
  443. switch err {
  444. case AlertNoRenegotiation, AlertCloseNotify:
  445. level = AlertLevelWarning
  446. default:
  447. level = AlertLevelError
  448. }
  449. buf := []byte{byte(err), byte(level)}
  450. c.out.WriteRecord(&TLSPlaintext{
  451. contentType: RecordTypeAlert,
  452. fragment: buf,
  453. })
  454. // close_notify and end_of_early_data are not actually errors
  455. if level == AlertLevelWarning {
  456. return &net.OpError{Op: "local error", Err: err}
  457. }
  458. return c.Close()
  459. }
  460. // Close closes the connection.
  461. func (c *Conn) Close() error {
  462. // XXX crypto/tls has an interlock with Write here. Do we need that?
  463. return c.conn.Close()
  464. }
  465. // LocalAddr returns the local network address.
  466. func (c *Conn) LocalAddr() net.Addr {
  467. return c.conn.LocalAddr()
  468. }
  469. // RemoteAddr returns the remote network address.
  470. func (c *Conn) RemoteAddr() net.Addr {
  471. return c.conn.RemoteAddr()
  472. }
  473. // SetDeadline sets the read and write deadlines associated with the connection.
  474. // A zero value for t means Read and Write will not time out.
  475. // After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
  476. func (c *Conn) SetDeadline(t time.Time) error {
  477. return c.conn.SetDeadline(t)
  478. }
  479. // SetReadDeadline sets the read deadline on the underlying connection.
  480. // A zero value for t means Read will not time out.
  481. func (c *Conn) SetReadDeadline(t time.Time) error {
  482. return c.conn.SetReadDeadline(t)
  483. }
  484. // SetWriteDeadline sets the write deadline on the underlying connection.
  485. // A zero value for t means Write will not time out.
  486. // After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
  487. func (c *Conn) SetWriteDeadline(t time.Time) error {
  488. return c.conn.SetWriteDeadline(t)
  489. }
  490. func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
  491. label := "[server]"
  492. if c.isClient {
  493. label = "[client]"
  494. }
  495. switch action := actionGeneric.(type) {
  496. case QueueHandshakeMessage:
  497. logf(logTypeHandshake, "%s queuing handshake message type=%v", label, action.Message.msgType)
  498. err := c.hsCtx.hOut.QueueMessage(action.Message)
  499. if err != nil {
  500. logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
  501. return AlertInternalError
  502. }
  503. case SendQueuedHandshake:
  504. _, err := c.hsCtx.hOut.SendQueuedMessages()
  505. if err != nil {
  506. logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
  507. return AlertInternalError
  508. }
  509. if c.config.UseDTLS {
  510. c.hsCtx.timers.start(retransmitTimerLabel,
  511. c.hsCtx.handshakeRetransmit,
  512. c.hsCtx.timeoutMS)
  513. }
  514. case RekeyIn:
  515. logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.epoch.label(), action.KeySet)
  516. err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
  517. if err != nil {
  518. logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
  519. return AlertInternalError
  520. }
  521. case RekeyOut:
  522. logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.epoch.label(), action.KeySet)
  523. err := c.out.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
  524. if err != nil {
  525. logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
  526. return AlertInternalError
  527. }
  528. case ResetOut:
  529. logf(logTypeHandshake, "%s Rekeying out to %s seq=%v", label, EpochClear, action.seq)
  530. c.out.ResetClear(action.seq)
  531. case StorePSK:
  532. logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
  533. if c.isClient {
  534. // Clients look up PSKs based on server name
  535. c.config.PSKs.Put(c.config.ServerName, action.PSK)
  536. } else {
  537. // Servers look them up based on the identity in the extension
  538. c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK)
  539. }
  540. default:
  541. logf(logTypeHandshake, "%s Unknown action type", label)
  542. assert(false)
  543. return AlertInternalError
  544. }
  545. return AlertNoAlert
  546. }
  547. func (c *Conn) HandshakeSetup() Alert {
  548. var state HandshakeState
  549. var actions []HandshakeAction
  550. var alert Alert
  551. if err := c.config.Init(c.isClient); err != nil {
  552. logf(logTypeHandshake, "Error initializing config: %v", err)
  553. return AlertInternalError
  554. }
  555. opts := ConnectionOptions{
  556. ServerName: c.config.ServerName,
  557. NextProtos: c.config.NextProtos,
  558. }
  559. if c.isClient {
  560. state, actions, alert = clientStateStart{Config: c.config, Opts: opts, hsCtx: c.hsCtx}.Next(nil)
  561. if alert != AlertNoAlert {
  562. logf(logTypeHandshake, "Error initializing client state: %v", alert)
  563. return alert
  564. }
  565. for _, action := range actions {
  566. alert = c.takeAction(action)
  567. if alert != AlertNoAlert {
  568. logf(logTypeHandshake, "Error during handshake actions: %v", alert)
  569. return alert
  570. }
  571. }
  572. } else {
  573. if c.config.RequireCookie && c.config.CookieProtector == nil {
  574. logf(logTypeHandshake, "RequireCookie set, but no CookieProtector provided. Using default cookie protector. Stateless Retry not possible.")
  575. if c.config.NonBlocking {
  576. logf(logTypeHandshake, "Not possible in non-blocking mode.")
  577. return AlertInternalError
  578. }
  579. var err error
  580. c.config.CookieProtector, err = NewDefaultCookieProtector()
  581. if err != nil {
  582. logf(logTypeHandshake, "Error initializing cookie source: %v", alert)
  583. return AlertInternalError
  584. }
  585. }
  586. state = serverStateStart{Config: c.config, conn: c, hsCtx: c.hsCtx}
  587. }
  588. c.hState = state
  589. return AlertNoAlert
  590. }
  591. type handshakeMessageReader interface {
  592. ReadMessage() (*HandshakeMessage, Alert)
  593. }
  594. type handshakeMessageReaderImpl struct {
  595. hsCtx *HandshakeContext
  596. }
  597. var _ handshakeMessageReader = &handshakeMessageReaderImpl{}
  598. func (r *handshakeMessageReaderImpl) ReadMessage() (*HandshakeMessage, Alert) {
  599. var hm *HandshakeMessage
  600. var err error
  601. for {
  602. hm, err = r.hsCtx.hIn.ReadMessage()
  603. if err == AlertWouldBlock {
  604. return nil, AlertWouldBlock
  605. }
  606. if err != nil {
  607. logf(logTypeHandshake, "Error reading message: %v", err)
  608. return nil, AlertCloseNotify
  609. }
  610. if hm != nil {
  611. break
  612. }
  613. }
  614. return hm, AlertNoAlert
  615. }
  616. // Handshake causes a TLS handshake on the connection. The `isClient` member
  617. // determines whether a client or server handshake is performed. If a
  618. // handshake has already been performed, then its result will be returned.
  619. func (c *Conn) Handshake() Alert {
  620. label := "[server]"
  621. if c.isClient {
  622. label = "[client]"
  623. }
  624. // TODO Lock handshakeMutex
  625. // TODO Remove CloseNotify hack
  626. if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify {
  627. logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert)
  628. return c.handshakeAlert
  629. }
  630. if c.handshakeComplete {
  631. return AlertNoAlert
  632. }
  633. if c.hState == nil {
  634. logf(logTypeHandshake, "%s First time through handshake (or after stateless retry), setting up", label)
  635. alert := c.HandshakeSetup()
  636. if alert != AlertNoAlert || (c.isClient && c.config.NonBlocking) {
  637. return alert
  638. }
  639. }
  640. logf(logTypeHandshake, "(Re-)entering handshake, state=%v", c.hState)
  641. state := c.hState
  642. _, connected := state.(stateConnected)
  643. hmr := &handshakeMessageReaderImpl{hsCtx: c.hsCtx}
  644. for !connected {
  645. var alert Alert
  646. var actions []HandshakeAction
  647. // Advance the state machine
  648. state, actions, alert = state.Next(hmr)
  649. if alert == AlertWouldBlock {
  650. logf(logTypeHandshake, "%s Would block reading message: %s", label, alert)
  651. // If we blocked, then run our timers to see if any have expired.
  652. if c.hsCtx.hIn.datagram {
  653. if err := c.hsCtx.timers.check(time.Now()); err != nil {
  654. return AlertInternalError
  655. }
  656. }
  657. return AlertWouldBlock
  658. }
  659. if alert == AlertCloseNotify {
  660. logf(logTypeHandshake, "%s Error reading message: %s", label, alert)
  661. c.sendAlert(AlertCloseNotify)
  662. return AlertCloseNotify
  663. }
  664. if alert != AlertNoAlert && alert != AlertStatelessRetry {
  665. logf(logTypeHandshake, "Error in state transition: %v", alert)
  666. return alert
  667. }
  668. for index, action := range actions {
  669. logf(logTypeHandshake, "%s taking next action (%d)", label, index)
  670. if alert := c.takeAction(action); alert != AlertNoAlert {
  671. logf(logTypeHandshake, "Error during handshake actions: %v", alert)
  672. c.sendAlert(alert)
  673. return alert
  674. }
  675. }
  676. c.hState = state
  677. logf(logTypeHandshake, "state is now %s", c.GetHsState())
  678. _, connected = state.(stateConnected)
  679. if connected {
  680. c.state = state.(stateConnected)
  681. c.handshakeComplete = true
  682. if !c.isClient {
  683. // Send NewSessionTicket if configured to
  684. if c.config.SendSessionTickets {
  685. actions, alert := c.state.NewSessionTicket(
  686. c.config.TicketLen,
  687. c.config.TicketLifetime,
  688. c.config.EarlyDataLifetime)
  689. for _, action := range actions {
  690. alert = c.takeAction(action)
  691. if alert != AlertNoAlert {
  692. logf(logTypeHandshake, "Error during handshake actions: %v", alert)
  693. c.sendAlert(alert)
  694. return alert
  695. }
  696. }
  697. }
  698. // If there is early data, move it into the main buffer
  699. if c.hsCtx.earlyData != nil {
  700. c.readBuffer = c.hsCtx.earlyData
  701. c.hsCtx.earlyData = nil
  702. }
  703. } else {
  704. assert(c.hsCtx.earlyData == nil)
  705. }
  706. }
  707. if c.config.NonBlocking {
  708. if alert == AlertStatelessRetry {
  709. return AlertStatelessRetry
  710. }
  711. return AlertNoAlert
  712. }
  713. }
  714. return AlertNoAlert
  715. }
  716. func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
  717. if !c.handshakeComplete {
  718. return fmt.Errorf("Cannot update keys until after handshake")
  719. }
  720. request := KeyUpdateNotRequested
  721. if requestUpdate {
  722. request = KeyUpdateRequested
  723. }
  724. // Create the key update and update state
  725. actions, alert := c.state.KeyUpdate(request)
  726. if alert != AlertNoAlert {
  727. c.sendAlert(alert)
  728. return fmt.Errorf("Alert while generating key update: %v", alert)
  729. }
  730. // Take actions (send key update and rekey)
  731. for _, action := range actions {
  732. alert = c.takeAction(action)
  733. if alert != AlertNoAlert {
  734. c.sendAlert(alert)
  735. return fmt.Errorf("Alert during key update actions: %v", alert)
  736. }
  737. }
  738. return nil
  739. }
  740. func (c *Conn) GetHsState() State {
  741. if c.hState == nil {
  742. return StateInit
  743. }
  744. return c.hState.State()
  745. }
  746. func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
  747. _, connected := c.hState.(stateConnected)
  748. if !connected {
  749. return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
  750. }
  751. if c.state.exporterSecret == nil {
  752. return nil, fmt.Errorf("Internal error: no exporter secret")
  753. }
  754. h0 := c.state.cryptoParams.Hash.New().Sum(nil)
  755. tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0)
  756. hc := c.state.cryptoParams.Hash.New().Sum(context)
  757. return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
  758. }
  759. func (c *Conn) ConnectionState() ConnectionState {
  760. state := ConnectionState{
  761. HandshakeState: c.GetHsState(),
  762. }
  763. if c.handshakeComplete {
  764. state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
  765. state.NextProto = c.state.Params.NextProto
  766. state.VerifiedChains = c.state.verifiedChains
  767. state.PeerCertificates = c.state.peerCertificates
  768. state.UsingPSK = c.state.Params.UsingPSK
  769. state.UsingEarlyData = c.state.Params.UsingEarlyData
  770. }
  771. return state
  772. }
  773. func (c *Conn) Writable() bool {
  774. // If we're connected, we're writable.
  775. if _, connected := c.hState.(stateConnected); connected {
  776. return true
  777. }
  778. // If we're a client in 0-RTT, then we're writable.
  779. if c.isClient && c.out.cipher.epoch == EpochEarlyData {
  780. return true
  781. }
  782. return false
  783. }
  784. func (c *Conn) label() string {
  785. if c.isClient {
  786. return "client"
  787. }
  788. return "server"
  789. }