handshake-layer.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. package mint
  2. import (
  3. "fmt"
  4. "io"
  5. "net"
  6. )
  7. const (
  8. handshakeHeaderLenTLS = 4 // handshake message header length
  9. handshakeHeaderLenDTLS = 12 // handshake message header length
  10. maxHandshakeMessageLen = 1 << 24 // max handshake message length
  11. )
  12. // struct {
  13. // HandshakeType msg_type; /* handshake type */
  14. // uint24 length; /* bytes in message */
  15. // select (HandshakeType) {
  16. // ...
  17. // } body;
  18. // } Handshake;
  19. //
  20. // We do the select{...} part in a different layer, so we treat the
  21. // actual message body as opaque:
  22. //
  23. // struct {
  24. // HandshakeType msg_type;
  25. // opaque msg<0..2^24-1>
  26. // } Handshake;
  27. //
  28. type HandshakeMessage struct {
  29. msgType HandshakeType
  30. seq uint32
  31. body []byte
  32. datagram bool
  33. offset uint32 // Used for DTLS
  34. length uint32
  35. cipher *cipherState
  36. }
  37. // Note: This could be done with the `syntax` module, using the simplified
  38. // syntax as discussed above. However, since this is so simple, there's not
  39. // much benefit to doing so.
  40. // When datagram is set, we marshal this as a whole DTLS record.
  41. func (hm *HandshakeMessage) Marshal() []byte {
  42. if hm == nil {
  43. return []byte{}
  44. }
  45. fragLen := len(hm.body)
  46. var data []byte
  47. if hm.datagram {
  48. data = make([]byte, handshakeHeaderLenDTLS+fragLen)
  49. } else {
  50. data = make([]byte, handshakeHeaderLenTLS+fragLen)
  51. }
  52. tmp := data
  53. tmp = encodeUint(uint64(hm.msgType), 1, tmp)
  54. tmp = encodeUint(uint64(hm.length), 3, tmp)
  55. if hm.datagram {
  56. tmp = encodeUint(uint64(hm.seq), 2, tmp)
  57. tmp = encodeUint(uint64(hm.offset), 3, tmp)
  58. tmp = encodeUint(uint64(fragLen), 3, tmp)
  59. }
  60. copy(tmp, hm.body)
  61. return data
  62. }
  63. func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
  64. logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
  65. var body HandshakeMessageBody
  66. switch hm.msgType {
  67. case HandshakeTypeClientHello:
  68. body = new(ClientHelloBody)
  69. case HandshakeTypeServerHello:
  70. body = new(ServerHelloBody)
  71. case HandshakeTypeEncryptedExtensions:
  72. body = new(EncryptedExtensionsBody)
  73. case HandshakeTypeCertificate:
  74. body = new(CertificateBody)
  75. case HandshakeTypeCertificateRequest:
  76. body = new(CertificateRequestBody)
  77. case HandshakeTypeCertificateVerify:
  78. body = new(CertificateVerifyBody)
  79. case HandshakeTypeFinished:
  80. body = &FinishedBody{VerifyDataLen: len(hm.body)}
  81. case HandshakeTypeNewSessionTicket:
  82. body = new(NewSessionTicketBody)
  83. case HandshakeTypeKeyUpdate:
  84. body = new(KeyUpdateBody)
  85. case HandshakeTypeEndOfEarlyData:
  86. body = new(EndOfEarlyDataBody)
  87. default:
  88. return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
  89. }
  90. err := safeUnmarshal(body, hm.body)
  91. return body, err
  92. }
  93. func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
  94. data, err := body.Marshal()
  95. if err != nil {
  96. return nil, err
  97. }
  98. m := &HandshakeMessage{
  99. msgType: body.Type(),
  100. body: data,
  101. seq: h.msgSeq,
  102. datagram: h.datagram,
  103. length: uint32(len(data)),
  104. }
  105. h.msgSeq++
  106. return m, nil
  107. }
  108. type HandshakeLayer struct {
  109. ctx *HandshakeContext // The handshake we are attached to
  110. nonblocking bool // Should we operate in nonblocking mode
  111. conn *RecordLayer // Used for reading/writing records
  112. frame *frameReader // The buffered frame reader
  113. datagram bool // Is this DTLS?
  114. msgSeq uint32 // The DTLS message sequence number
  115. queued []*HandshakeMessage // In/out queue
  116. sent []*HandshakeMessage // Sent messages for DTLS
  117. recvdRecords []uint64 // Records we have received.
  118. maxFragmentLen int
  119. }
  120. type handshakeLayerFrameDetails struct {
  121. datagram bool
  122. }
  123. func (d handshakeLayerFrameDetails) headerLen() int {
  124. if d.datagram {
  125. return handshakeHeaderLenDTLS
  126. }
  127. return handshakeHeaderLenTLS
  128. }
  129. func (d handshakeLayerFrameDetails) defaultReadLen() int {
  130. return d.headerLen() + maxFragmentLen
  131. }
  132. func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
  133. logf(logTypeIO, "Header=%x", hdr)
  134. // The length of this fragment (as opposed to the message)
  135. // is always the last three bytes for both TLS and DTLS
  136. val, _ := decodeUint(hdr[len(hdr)-3:], 3)
  137. return int(val), nil
  138. }
  139. func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
  140. h := HandshakeLayer{}
  141. h.ctx = c
  142. h.conn = r
  143. h.datagram = false
  144. h.frame = newFrameReader(&handshakeLayerFrameDetails{false})
  145. h.maxFragmentLen = maxFragmentLen
  146. return &h
  147. }
  148. func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
  149. h := HandshakeLayer{}
  150. h.ctx = c
  151. h.conn = r
  152. h.datagram = true
  153. h.frame = newFrameReader(&handshakeLayerFrameDetails{true})
  154. h.maxFragmentLen = initialMtu // Not quite right
  155. return &h
  156. }
  157. func (h *HandshakeLayer) readRecord() error {
  158. logf(logTypeVerbose, "Trying to read record")
  159. pt, err := h.conn.readRecordAnyEpoch()
  160. if err != nil {
  161. return err
  162. }
  163. switch pt.contentType {
  164. case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck:
  165. default:
  166. return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
  167. }
  168. if pt.contentType == RecordTypeAck {
  169. if !h.datagram {
  170. return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS")
  171. }
  172. logf(logTypeIO, "read ACK")
  173. return h.ctx.processAck(pt.fragment)
  174. }
  175. if pt.contentType == RecordTypeAlert {
  176. logf(logTypeIO, "read alert %v", pt.fragment[1])
  177. if len(pt.fragment) < 2 {
  178. h.sendAlert(AlertUnexpectedMessage)
  179. return io.EOF
  180. }
  181. return Alert(pt.fragment[1])
  182. }
  183. assert(h.ctx.hIn.conn != nil)
  184. if pt.epoch != h.ctx.hIn.conn.cipher.epoch {
  185. // This is out of order but we're dropping it.
  186. // TODO(ekr@rtfm.com): If server, need to retransmit Finished.
  187. if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData {
  188. return nil
  189. }
  190. // Anything else shouldn't happen.
  191. return AlertIllegalParameter
  192. }
  193. h.recvdRecords = append(h.recvdRecords, pt.seq)
  194. h.frame.addChunk(pt.fragment)
  195. return nil
  196. }
  197. // sendAlert sends a TLS alert message.
  198. func (h *HandshakeLayer) sendAlert(err Alert) error {
  199. tmp := make([]byte, 2)
  200. tmp[0] = AlertLevelError
  201. tmp[1] = byte(err)
  202. h.conn.WriteRecord(&TLSPlaintext{
  203. contentType: RecordTypeAlert,
  204. fragment: tmp},
  205. )
  206. // closeNotify is a special case in that it isn't an error:
  207. if err != AlertCloseNotify {
  208. return &net.OpError{Op: "local error", Err: err}
  209. }
  210. return nil
  211. }
  212. func (h *HandshakeLayer) noteMessageDelivered(seq uint32) {
  213. h.msgSeq = seq + 1
  214. var i int
  215. var m *HandshakeMessage
  216. for i, m = range h.queued {
  217. if m.seq > seq {
  218. break
  219. }
  220. }
  221. h.queued = h.queued[i:]
  222. }
  223. func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) {
  224. if hm.seq < h.msgSeq {
  225. return nil, nil
  226. }
  227. // TODO(ekr@rtfm.com): Send an ACK immediately if we got something
  228. // out of order.
  229. h.ctx.receivedHandshakeMessage()
  230. if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
  231. // TODO(ekr@rtfm.com): Check the length?
  232. // This is complete.
  233. h.noteMessageDelivered(hm.seq)
  234. return hm, nil
  235. }
  236. // Now insert sorted.
  237. var i int
  238. for i = 0; i < len(h.queued); i++ {
  239. f := h.queued[i]
  240. if hm.seq < f.seq {
  241. break
  242. }
  243. if hm.offset < f.offset {
  244. break
  245. }
  246. }
  247. tmp := make([]*HandshakeMessage, 0, len(h.queued)+1)
  248. tmp = append(tmp, h.queued[:i]...)
  249. tmp = append(tmp, hm)
  250. tmp = append(tmp, h.queued[i:]...)
  251. h.queued = tmp
  252. return h.checkMessageAvailable()
  253. }
  254. func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) {
  255. if len(h.queued) == 0 {
  256. return nil, nil
  257. }
  258. hm := h.queued[0]
  259. if hm.seq != h.msgSeq {
  260. return nil, nil
  261. }
  262. if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
  263. // TODO(ekr@rtfm.com): Check the length?
  264. // This is complete.
  265. h.noteMessageDelivered(hm.seq)
  266. return hm, nil
  267. }
  268. // OK, this at least might complete the message.
  269. end := uint32(0)
  270. buf := make([]byte, hm.length)
  271. for _, f := range h.queued {
  272. // Out of fragments
  273. if f.seq > hm.seq {
  274. break
  275. }
  276. if f.length != uint32(len(buf)) {
  277. return nil, fmt.Errorf("Mismatched DTLS length")
  278. }
  279. if f.offset > end {
  280. break
  281. }
  282. if f.offset+uint32(len(f.body)) > end {
  283. // OK, this is adding something we don't know about
  284. copy(buf[f.offset:], f.body)
  285. end = f.offset + uint32(len(f.body))
  286. if end == hm.length {
  287. h2 := *hm
  288. h2.offset = 0
  289. h2.body = buf
  290. h.noteMessageDelivered(hm.seq)
  291. return &h2, nil
  292. }
  293. }
  294. }
  295. return nil, nil
  296. }
  297. func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
  298. var hdr, body []byte
  299. var err error
  300. hm, err := h.checkMessageAvailable()
  301. if err != nil {
  302. return nil, err
  303. }
  304. if hm != nil {
  305. return hm, nil
  306. }
  307. for {
  308. logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder))
  309. if h.frame.needed() > 0 {
  310. logf(logTypeVerbose, "Trying to read a new record")
  311. err = h.readRecord()
  312. if err != nil && (h.nonblocking || err != AlertWouldBlock) {
  313. return nil, err
  314. }
  315. }
  316. hdr, body, err = h.frame.process()
  317. if err == nil {
  318. break
  319. }
  320. if err != nil && (h.nonblocking || err != AlertWouldBlock) {
  321. return nil, err
  322. }
  323. }
  324. logf(logTypeHandshake, "read handshake message")
  325. hm = &HandshakeMessage{}
  326. hm.msgType = HandshakeType(hdr[0])
  327. hm.datagram = h.datagram
  328. hm.body = make([]byte, len(body))
  329. copy(hm.body, body)
  330. logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
  331. if h.datagram {
  332. tmp, hdr := decodeUint(hdr[1:], 3)
  333. hm.length = uint32(tmp)
  334. tmp, hdr = decodeUint(hdr, 2)
  335. hm.seq = uint32(tmp)
  336. tmp, hdr = decodeUint(hdr, 3)
  337. hm.offset = uint32(tmp)
  338. return h.newFragmentReceived(hm)
  339. }
  340. hm.length = uint32(len(body))
  341. return hm, nil
  342. }
  343. func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error {
  344. hm.cipher = h.conn.cipher
  345. h.queued = append(h.queued, hm)
  346. return nil
  347. }
  348. func (h *HandshakeLayer) SendQueuedMessages() (int, error) {
  349. logf(logTypeHandshake, "Sending outgoing messages")
  350. count, err := h.WriteMessages(h.queued)
  351. if !h.datagram {
  352. h.ClearQueuedMessages()
  353. }
  354. return count, err
  355. }
  356. func (h *HandshakeLayer) ClearQueuedMessages() {
  357. logf(logTypeHandshake, "Clearing outgoing hs message queue")
  358. h.queued = nil
  359. }
  360. func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) {
  361. var buf []byte
  362. // Figure out if we're going to want the full header or just
  363. // the body
  364. hdrlen := 0
  365. if hm.datagram {
  366. hdrlen = handshakeHeaderLenDTLS
  367. } else if start == 0 {
  368. hdrlen = handshakeHeaderLenTLS
  369. }
  370. // Compute the amount of body we can fit in
  371. room -= hdrlen
  372. if room == 0 {
  373. // This works because we are doing one record per
  374. // message
  375. panic("Too short max fragment len")
  376. }
  377. bodylen := len(hm.body) - start
  378. if bodylen > room {
  379. bodylen = room
  380. }
  381. body := hm.body[start : start+bodylen]
  382. // Now see if this chunk has been ACKed. This doesn't produce ideal
  383. // retransmission but is simple.
  384. if h.ctx.fragmentAcked(hm.seq, start, bodylen) {
  385. logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen)
  386. return false, start + bodylen, nil
  387. }
  388. // Encode the data.
  389. if hdrlen > 0 {
  390. hm2 := *hm
  391. hm2.offset = uint32(start)
  392. hm2.body = body
  393. buf = hm2.Marshal()
  394. hm = &hm2
  395. } else {
  396. buf = body
  397. }
  398. if h.datagram {
  399. // Remember that we sent this.
  400. h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{
  401. hm.seq,
  402. start,
  403. len(body),
  404. h.conn.cipher.combineSeq(true),
  405. false,
  406. })
  407. }
  408. return true, start + bodylen, h.conn.writeRecordWithPadding(
  409. &TLSPlaintext{
  410. contentType: RecordTypeHandshake,
  411. fragment: buf,
  412. },
  413. hm.cipher, 0)
  414. }
  415. func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) {
  416. start := int(0)
  417. if len(hm.body) > maxHandshakeMessageLen {
  418. return 0, fmt.Errorf("Tried to write a handshake message that's too long")
  419. }
  420. written := 0
  421. wrote := false
  422. // Always make one pass through to allow EOED (which is empty).
  423. for {
  424. var err error
  425. wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen)
  426. if err != nil {
  427. return 0, err
  428. }
  429. if wrote {
  430. written++
  431. }
  432. if start >= len(hm.body) {
  433. break
  434. }
  435. }
  436. return written, nil
  437. }
  438. func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) {
  439. written := 0
  440. for _, hm := range hms {
  441. logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
  442. wrote, err := h.WriteMessage(hm)
  443. if err != nil {
  444. return 0, err
  445. }
  446. written += wrote
  447. }
  448. return written, nil
  449. }
  450. func encodeUint(v uint64, size int, out []byte) []byte {
  451. for i := size - 1; i >= 0; i-- {
  452. out[i] = byte(v & 0xff)
  453. v >>= 8
  454. }
  455. return out[size:]
  456. }
  457. func decodeUint(in []byte, size int) (uint64, []byte) {
  458. val := uint64(0)
  459. for i := 0; i < size; i++ {
  460. val <<= 8
  461. val += uint64(in[i])
  462. }
  463. return val, in[size:]
  464. }
  465. type marshalledPDU interface {
  466. Marshal() ([]byte, error)
  467. Unmarshal(data []byte) (int, error)
  468. }
  469. func safeUnmarshal(pdu marshalledPDU, data []byte) error {
  470. read, err := pdu.Unmarshal(data)
  471. if err != nil {
  472. return err
  473. }
  474. if len(data) != read {
  475. return fmt.Errorf("Invalid encoding: Extra data not consumed")
  476. }
  477. return nil
  478. }