record-layer.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. package mint
  2. import (
  3. "crypto/cipher"
  4. "fmt"
  5. "io"
  6. "sync"
  7. )
  8. const (
  9. sequenceNumberLen = 8 // sequence number length
  10. recordHeaderLenTLS = 5 // record header length (TLS)
  11. recordHeaderLenDTLS = 13 // record header length (DTLS)
  12. maxFragmentLen = 1 << 14 // max number of bytes in a record
  13. )
  14. type DecryptError string
  15. func (err DecryptError) Error() string {
  16. return string(err)
  17. }
  18. type direction uint8
  19. const (
  20. directionWrite = direction(1)
  21. directionRead = direction(2)
  22. )
  23. // struct {
  24. // ContentType type;
  25. // ProtocolVersion record_version [0301 for CH, 0303 for others]
  26. // uint16 length;
  27. // opaque fragment[TLSPlaintext.length];
  28. // } TLSPlaintext;
  29. type TLSPlaintext struct {
  30. // Omitted: record_version (static)
  31. // Omitted: length (computed from fragment)
  32. contentType RecordType
  33. epoch Epoch
  34. seq uint64
  35. fragment []byte
  36. }
  37. type cipherState struct {
  38. epoch Epoch // DTLS epoch
  39. ivLength int // Length of the seq and nonce fields
  40. seq uint64 // Zero-padded sequence number
  41. iv []byte // Buffer for the IV
  42. cipher cipher.AEAD // AEAD cipher
  43. }
  44. type RecordLayer struct {
  45. sync.Mutex
  46. label string
  47. direction direction
  48. version uint16 // The current version number
  49. conn io.ReadWriter // The underlying connection
  50. frame *frameReader // The buffered frame reader
  51. nextData []byte // The next record to send
  52. cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
  53. cachedError error // Error on the last record read
  54. cipher *cipherState
  55. readCiphers map[Epoch]*cipherState
  56. datagram bool
  57. }
  58. type recordLayerFrameDetails struct {
  59. datagram bool
  60. }
  61. func (d recordLayerFrameDetails) headerLen() int {
  62. if d.datagram {
  63. return recordHeaderLenDTLS
  64. }
  65. return recordHeaderLenTLS
  66. }
  67. func (d recordLayerFrameDetails) defaultReadLen() int {
  68. return d.headerLen() + maxFragmentLen
  69. }
  70. func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
  71. return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil
  72. }
  73. func newCipherStateNull() *cipherState {
  74. return &cipherState{EpochClear, 0, 0, nil, nil}
  75. }
  76. func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) {
  77. cipher, err := factory(key)
  78. if err != nil {
  79. return nil, err
  80. }
  81. return &cipherState{epoch, len(iv), 0, iv, cipher}, nil
  82. }
  83. func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer {
  84. r := RecordLayer{}
  85. r.label = ""
  86. r.direction = dir
  87. r.conn = conn
  88. r.frame = newFrameReader(recordLayerFrameDetails{false})
  89. r.cipher = newCipherStateNull()
  90. r.version = tls10Version
  91. return &r
  92. }
  93. func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer {
  94. r := RecordLayer{}
  95. r.label = ""
  96. r.direction = dir
  97. r.conn = conn
  98. r.frame = newFrameReader(recordLayerFrameDetails{true})
  99. r.cipher = newCipherStateNull()
  100. r.readCiphers = make(map[Epoch]*cipherState, 0)
  101. r.readCiphers[0] = r.cipher
  102. r.datagram = true
  103. return &r
  104. }
  105. func (r *RecordLayer) SetVersion(v uint16) {
  106. r.version = v
  107. }
  108. func (r *RecordLayer) ResetClear(seq uint64) {
  109. r.cipher = newCipherStateNull()
  110. r.cipher.seq = seq
  111. }
  112. func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error {
  113. cipher, err := newCipherStateAead(epoch, factory, key, iv)
  114. if err != nil {
  115. return err
  116. }
  117. r.cipher = cipher
  118. if r.datagram && r.direction == directionRead {
  119. r.readCiphers[epoch] = cipher
  120. }
  121. return nil
  122. }
  123. // TODO(ekr@rtfm.com): This is never used, which is a bug.
  124. func (r *RecordLayer) DiscardReadKey(epoch Epoch) {
  125. if !r.datagram {
  126. return
  127. }
  128. _, ok := r.readCiphers[epoch]
  129. assert(ok)
  130. delete(r.readCiphers, epoch)
  131. }
  132. func (c *cipherState) combineSeq(datagram bool) uint64 {
  133. seq := c.seq
  134. if datagram {
  135. seq |= uint64(c.epoch) << 48
  136. }
  137. return seq
  138. }
  139. func (c *cipherState) computeNonce(seq uint64) []byte {
  140. nonce := make([]byte, len(c.iv))
  141. copy(nonce, c.iv)
  142. s := seq
  143. offset := len(c.iv)
  144. for i := 0; i < 8; i++ {
  145. nonce[(offset-i)-1] ^= byte(s & 0xff)
  146. s >>= 8
  147. }
  148. logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce)
  149. return nonce
  150. }
  151. func (c *cipherState) incrementSequenceNumber() {
  152. if c.seq >= (1<<48 - 1) {
  153. // Not allowed to let sequence number wrap.
  154. // Instead, must renegotiate before it does.
  155. // Not likely enough to bother. This is the
  156. // DTLS limit.
  157. panic("TLS: sequence number wraparound")
  158. }
  159. c.seq++
  160. }
  161. func (c *cipherState) overhead() int {
  162. if c.cipher == nil {
  163. return 0
  164. }
  165. return c.cipher.Overhead()
  166. }
  167. func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext {
  168. assert(r.direction == directionWrite)
  169. logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq)
  170. // Expand the fragment to hold contentType, padding, and overhead
  171. originalLen := len(pt.fragment)
  172. plaintextLen := originalLen + 1 + padLen
  173. ciphertextLen := plaintextLen + cipher.overhead()
  174. // Assemble the revised plaintext
  175. out := &TLSPlaintext{
  176. contentType: RecordTypeApplicationData,
  177. fragment: make([]byte, ciphertextLen),
  178. }
  179. copy(out.fragment, pt.fragment)
  180. out.fragment[originalLen] = byte(pt.contentType)
  181. for i := 1; i <= padLen; i++ {
  182. out.fragment[originalLen+i] = 0
  183. }
  184. // Encrypt the fragment
  185. payload := out.fragment[:plaintextLen]
  186. cipher.cipher.Seal(payload[:0], cipher.computeNonce(seq), payload, nil)
  187. return out
  188. }
  189. func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) {
  190. assert(r.direction == directionRead)
  191. logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq)
  192. if len(pt.fragment) < r.cipher.overhead() {
  193. msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead())
  194. return nil, 0, DecryptError(msg)
  195. }
  196. decryptLen := len(pt.fragment) - r.cipher.overhead()
  197. out := &TLSPlaintext{
  198. contentType: pt.contentType,
  199. fragment: make([]byte, decryptLen),
  200. }
  201. // Decrypt
  202. _, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil)
  203. if err != nil {
  204. logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt)
  205. return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
  206. }
  207. // Find the padding boundary
  208. padLen := 0
  209. for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ {
  210. }
  211. // Transfer the content type
  212. newLen := decryptLen - padLen - 1
  213. out.contentType = RecordType(out.fragment[newLen])
  214. // Truncate the message to remove contentType, padding, overhead
  215. out.fragment = out.fragment[:newLen]
  216. out.seq = seq
  217. return out, padLen, nil
  218. }
  219. func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
  220. var pt *TLSPlaintext
  221. var err error
  222. for {
  223. pt, err = r.nextRecord(false)
  224. if err == nil {
  225. break
  226. }
  227. if !block || err != AlertWouldBlock {
  228. return 0, err
  229. }
  230. }
  231. return pt.contentType, nil
  232. }
  233. func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
  234. pt, err := r.nextRecord(false)
  235. // Consume the cached record if there was one
  236. r.cachedRecord = nil
  237. r.cachedError = nil
  238. return pt, err
  239. }
  240. func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) {
  241. pt, err := r.nextRecord(true)
  242. // Consume the cached record if there was one
  243. r.cachedRecord = nil
  244. r.cachedError = nil
  245. return pt, err
  246. }
  247. func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) {
  248. cipher := r.cipher
  249. if r.cachedRecord != nil {
  250. logf(logTypeIO, "%s Returning cached record", r.label)
  251. return r.cachedRecord, r.cachedError
  252. }
  253. // Loop until one of three things happens:
  254. //
  255. // 1. We get a frame
  256. // 2. We try to read off the socket and get nothing, in which case
  257. // returnAlertWouldBlock
  258. // 3. We get an error.
  259. var err error
  260. err = AlertWouldBlock
  261. var header, body []byte
  262. for err != nil {
  263. if r.frame.needed() > 0 {
  264. buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen)
  265. n, err := r.conn.Read(buf)
  266. if err != nil {
  267. logf(logTypeIO, "%s Error reading, %v", r.label, err)
  268. return nil, err
  269. }
  270. if n == 0 {
  271. return nil, AlertWouldBlock
  272. }
  273. logf(logTypeIO, "%s Read %v bytes", r.label, n)
  274. buf = buf[:n]
  275. r.frame.addChunk(buf)
  276. }
  277. header, body, err = r.frame.process()
  278. // Loop around onAlertWouldBlock to see if some
  279. // data is now available.
  280. if err != nil && err != AlertWouldBlock {
  281. return nil, err
  282. }
  283. }
  284. pt := &TLSPlaintext{}
  285. // Validate content type
  286. switch RecordType(header[0]) {
  287. default:
  288. return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
  289. case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck:
  290. pt.contentType = RecordType(header[0])
  291. }
  292. // Validate version
  293. if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
  294. return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
  295. }
  296. // Validate size < max
  297. size := (int(header[len(header)-2]) << 8) + int(header[len(header)-1])
  298. if size > maxFragmentLen+256 {
  299. return nil, fmt.Errorf("tls.record: Ciphertext size too big")
  300. }
  301. pt.fragment = make([]byte, size)
  302. copy(pt.fragment, body)
  303. // TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data.
  304. // Attempt to decrypt fragment
  305. seq := cipher.seq
  306. if r.datagram {
  307. // TODO(ekr@rtfm.com): Handle duplicates.
  308. seq, _ = decodeUint(header[3:11], 8)
  309. epoch := Epoch(seq >> 48)
  310. // Look up the cipher suite from the epoch
  311. c, ok := r.readCiphers[epoch]
  312. if !ok {
  313. logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch)
  314. return nil, AlertWouldBlock
  315. }
  316. if epoch != cipher.epoch {
  317. logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch,
  318. cipher.epoch, allowOldEpoch)
  319. if !allowOldEpoch {
  320. return nil, AlertWouldBlock
  321. }
  322. cipher = c
  323. }
  324. }
  325. if cipher.cipher != nil {
  326. logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment)
  327. pt, _, err = r.decrypt(pt, seq)
  328. if err != nil {
  329. logf(logTypeIO, "%s Decryption failed", r.label)
  330. return nil, err
  331. }
  332. }
  333. pt.epoch = cipher.epoch
  334. // Check that plaintext length is not too long
  335. if len(pt.fragment) > maxFragmentLen {
  336. return nil, fmt.Errorf("tls.record: Plaintext size too big")
  337. }
  338. logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment)
  339. r.cachedRecord = pt
  340. cipher.incrementSequenceNumber()
  341. return pt, nil
  342. }
  343. func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
  344. return r.writeRecordWithPadding(pt, r.cipher, 0)
  345. }
  346. func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
  347. return r.writeRecordWithPadding(pt, r.cipher, padLen)
  348. }
  349. func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error {
  350. seq := cipher.combineSeq(r.datagram)
  351. if cipher.cipher != nil {
  352. logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
  353. pt = r.encrypt(cipher, seq, pt, padLen)
  354. } else if padLen > 0 {
  355. return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
  356. }
  357. if len(pt.fragment) > maxFragmentLen {
  358. return fmt.Errorf("tls.record: Record size too big")
  359. }
  360. length := len(pt.fragment)
  361. var header []byte
  362. if !r.datagram {
  363. header = []byte{byte(pt.contentType),
  364. byte(r.version >> 8), byte(r.version & 0xff),
  365. byte(length >> 8), byte(length)}
  366. } else {
  367. header = make([]byte, 13)
  368. version := dtlsConvertVersion(r.version)
  369. copy(header, []byte{byte(pt.contentType),
  370. byte(version >> 8), byte(version & 0xff),
  371. })
  372. encodeUint(seq, 8, header[3:])
  373. encodeUint(uint64(length), 2, header[11:])
  374. }
  375. record := append(header, pt.fragment...)
  376. logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
  377. cipher.incrementSequenceNumber()
  378. _, err := r.conn.Write(record)
  379. return err
  380. }