| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458 |
- package mint
- import (
- "crypto/cipher"
- "fmt"
- "io"
- "sync"
- )
- const (
- sequenceNumberLen = 8 // sequence number length
- recordHeaderLenTLS = 5 // record header length (TLS)
- recordHeaderLenDTLS = 13 // record header length (DTLS)
- maxFragmentLen = 1 << 14 // max number of bytes in a record
- )
- type DecryptError string
- func (err DecryptError) Error() string {
- return string(err)
- }
- type direction uint8
- const (
- directionWrite = direction(1)
- directionRead = direction(2)
- )
- // struct {
- // ContentType type;
- // ProtocolVersion record_version [0301 for CH, 0303 for others]
- // uint16 length;
- // opaque fragment[TLSPlaintext.length];
- // } TLSPlaintext;
- type TLSPlaintext struct {
- // Omitted: record_version (static)
- // Omitted: length (computed from fragment)
- contentType RecordType
- epoch Epoch
- seq uint64
- fragment []byte
- }
- type cipherState struct {
- epoch Epoch // DTLS epoch
- ivLength int // Length of the seq and nonce fields
- seq uint64 // Zero-padded sequence number
- iv []byte // Buffer for the IV
- cipher cipher.AEAD // AEAD cipher
- }
- type RecordLayer struct {
- sync.Mutex
- label string
- direction direction
- version uint16 // The current version number
- conn io.ReadWriter // The underlying connection
- frame *frameReader // The buffered frame reader
- nextData []byte // The next record to send
- cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
- cachedError error // Error on the last record read
- cipher *cipherState
- readCiphers map[Epoch]*cipherState
- datagram bool
- }
- type recordLayerFrameDetails struct {
- datagram bool
- }
- func (d recordLayerFrameDetails) headerLen() int {
- if d.datagram {
- return recordHeaderLenDTLS
- }
- return recordHeaderLenTLS
- }
- func (d recordLayerFrameDetails) defaultReadLen() int {
- return d.headerLen() + maxFragmentLen
- }
- func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
- return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil
- }
- func newCipherStateNull() *cipherState {
- return &cipherState{EpochClear, 0, 0, nil, nil}
- }
- func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) {
- cipher, err := factory(key)
- if err != nil {
- return nil, err
- }
- return &cipherState{epoch, len(iv), 0, iv, cipher}, nil
- }
- func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer {
- r := RecordLayer{}
- r.label = ""
- r.direction = dir
- r.conn = conn
- r.frame = newFrameReader(recordLayerFrameDetails{false})
- r.cipher = newCipherStateNull()
- r.version = tls10Version
- return &r
- }
- func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer {
- r := RecordLayer{}
- r.label = ""
- r.direction = dir
- r.conn = conn
- r.frame = newFrameReader(recordLayerFrameDetails{true})
- r.cipher = newCipherStateNull()
- r.readCiphers = make(map[Epoch]*cipherState, 0)
- r.readCiphers[0] = r.cipher
- r.datagram = true
- return &r
- }
- func (r *RecordLayer) SetVersion(v uint16) {
- r.version = v
- }
- func (r *RecordLayer) ResetClear(seq uint64) {
- r.cipher = newCipherStateNull()
- r.cipher.seq = seq
- }
- func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error {
- cipher, err := newCipherStateAead(epoch, factory, key, iv)
- if err != nil {
- return err
- }
- r.cipher = cipher
- if r.datagram && r.direction == directionRead {
- r.readCiphers[epoch] = cipher
- }
- return nil
- }
- // TODO(ekr@rtfm.com): This is never used, which is a bug.
- func (r *RecordLayer) DiscardReadKey(epoch Epoch) {
- if !r.datagram {
- return
- }
- _, ok := r.readCiphers[epoch]
- assert(ok)
- delete(r.readCiphers, epoch)
- }
- func (c *cipherState) combineSeq(datagram bool) uint64 {
- seq := c.seq
- if datagram {
- seq |= uint64(c.epoch) << 48
- }
- return seq
- }
- func (c *cipherState) computeNonce(seq uint64) []byte {
- nonce := make([]byte, len(c.iv))
- copy(nonce, c.iv)
- s := seq
- offset := len(c.iv)
- for i := 0; i < 8; i++ {
- nonce[(offset-i)-1] ^= byte(s & 0xff)
- s >>= 8
- }
- logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce)
- return nonce
- }
- func (c *cipherState) incrementSequenceNumber() {
- if c.seq >= (1<<48 - 1) {
- // Not allowed to let sequence number wrap.
- // Instead, must renegotiate before it does.
- // Not likely enough to bother. This is the
- // DTLS limit.
- panic("TLS: sequence number wraparound")
- }
- c.seq++
- }
- func (c *cipherState) overhead() int {
- if c.cipher == nil {
- return 0
- }
- return c.cipher.Overhead()
- }
- func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext {
- assert(r.direction == directionWrite)
- logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq)
- // Expand the fragment to hold contentType, padding, and overhead
- originalLen := len(pt.fragment)
- plaintextLen := originalLen + 1 + padLen
- ciphertextLen := plaintextLen + cipher.overhead()
- // Assemble the revised plaintext
- out := &TLSPlaintext{
- contentType: RecordTypeApplicationData,
- fragment: make([]byte, ciphertextLen),
- }
- copy(out.fragment, pt.fragment)
- out.fragment[originalLen] = byte(pt.contentType)
- for i := 1; i <= padLen; i++ {
- out.fragment[originalLen+i] = 0
- }
- // Encrypt the fragment
- payload := out.fragment[:plaintextLen]
- cipher.cipher.Seal(payload[:0], cipher.computeNonce(seq), payload, nil)
- return out
- }
- func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) {
- assert(r.direction == directionRead)
- logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq)
- if len(pt.fragment) < r.cipher.overhead() {
- msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead())
- return nil, 0, DecryptError(msg)
- }
- decryptLen := len(pt.fragment) - r.cipher.overhead()
- out := &TLSPlaintext{
- contentType: pt.contentType,
- fragment: make([]byte, decryptLen),
- }
- // Decrypt
- _, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil)
- if err != nil {
- logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt)
- return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
- }
- // Find the padding boundary
- padLen := 0
- for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ {
- }
- // Transfer the content type
- newLen := decryptLen - padLen - 1
- out.contentType = RecordType(out.fragment[newLen])
- // Truncate the message to remove contentType, padding, overhead
- out.fragment = out.fragment[:newLen]
- out.seq = seq
- return out, padLen, nil
- }
- func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
- var pt *TLSPlaintext
- var err error
- for {
- pt, err = r.nextRecord(false)
- if err == nil {
- break
- }
- if !block || err != AlertWouldBlock {
- return 0, err
- }
- }
- return pt.contentType, nil
- }
- func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
- pt, err := r.nextRecord(false)
- // Consume the cached record if there was one
- r.cachedRecord = nil
- r.cachedError = nil
- return pt, err
- }
- func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) {
- pt, err := r.nextRecord(true)
- // Consume the cached record if there was one
- r.cachedRecord = nil
- r.cachedError = nil
- return pt, err
- }
- func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) {
- cipher := r.cipher
- if r.cachedRecord != nil {
- logf(logTypeIO, "%s Returning cached record", r.label)
- return r.cachedRecord, r.cachedError
- }
- // Loop until one of three things happens:
- //
- // 1. We get a frame
- // 2. We try to read off the socket and get nothing, in which case
- // returnAlertWouldBlock
- // 3. We get an error.
- var err error
- err = AlertWouldBlock
- var header, body []byte
- for err != nil {
- if r.frame.needed() > 0 {
- buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen)
- n, err := r.conn.Read(buf)
- if err != nil {
- logf(logTypeIO, "%s Error reading, %v", r.label, err)
- return nil, err
- }
- if n == 0 {
- return nil, AlertWouldBlock
- }
- logf(logTypeIO, "%s Read %v bytes", r.label, n)
- buf = buf[:n]
- r.frame.addChunk(buf)
- }
- header, body, err = r.frame.process()
- // Loop around onAlertWouldBlock to see if some
- // data is now available.
- if err != nil && err != AlertWouldBlock {
- return nil, err
- }
- }
- pt := &TLSPlaintext{}
- // Validate content type
- switch RecordType(header[0]) {
- default:
- return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
- case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck:
- pt.contentType = RecordType(header[0])
- }
- // Validate version
- if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
- return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
- }
- // Validate size < max
- size := (int(header[len(header)-2]) << 8) + int(header[len(header)-1])
- if size > maxFragmentLen+256 {
- return nil, fmt.Errorf("tls.record: Ciphertext size too big")
- }
- pt.fragment = make([]byte, size)
- copy(pt.fragment, body)
- // TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data.
- // Attempt to decrypt fragment
- seq := cipher.seq
- if r.datagram {
- // TODO(ekr@rtfm.com): Handle duplicates.
- seq, _ = decodeUint(header[3:11], 8)
- epoch := Epoch(seq >> 48)
- // Look up the cipher suite from the epoch
- c, ok := r.readCiphers[epoch]
- if !ok {
- logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch)
- return nil, AlertWouldBlock
- }
- if epoch != cipher.epoch {
- logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch,
- cipher.epoch, allowOldEpoch)
- if !allowOldEpoch {
- return nil, AlertWouldBlock
- }
- cipher = c
- }
- }
- if cipher.cipher != nil {
- logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment)
- pt, _, err = r.decrypt(pt, seq)
- if err != nil {
- logf(logTypeIO, "%s Decryption failed", r.label)
- return nil, err
- }
- }
- pt.epoch = cipher.epoch
- // Check that plaintext length is not too long
- if len(pt.fragment) > maxFragmentLen {
- return nil, fmt.Errorf("tls.record: Plaintext size too big")
- }
- logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment)
- r.cachedRecord = pt
- cipher.incrementSequenceNumber()
- return pt, nil
- }
- func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
- return r.writeRecordWithPadding(pt, r.cipher, 0)
- }
- func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
- return r.writeRecordWithPadding(pt, r.cipher, padLen)
- }
- func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error {
- seq := cipher.combineSeq(r.datagram)
- if cipher.cipher != nil {
- logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
- pt = r.encrypt(cipher, seq, pt, padLen)
- } else if padLen > 0 {
- return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
- }
- if len(pt.fragment) > maxFragmentLen {
- return fmt.Errorf("tls.record: Record size too big")
- }
- length := len(pt.fragment)
- var header []byte
- if !r.datagram {
- header = []byte{byte(pt.contentType),
- byte(r.version >> 8), byte(r.version & 0xff),
- byte(length >> 8), byte(length)}
- } else {
- header = make([]byte, 13)
- version := dtlsConvertVersion(r.version)
- copy(header, []byte{byte(pt.contentType),
- byte(version >> 8), byte(version & 0xff),
- })
- encodeUint(seq, 8, header[3:])
- encodeUint(uint64(length), 2, header[11:])
- }
- record := append(header, pt.fragment...)
- logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
- cipher.incrementSequenceNumber()
- _, err := r.conn.Write(record)
- return err
- }
|