| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551 |
- package mint
- import (
- "fmt"
- "io"
- "net"
- )
// Wire-format sizes used by the handshake layer.
const (
	// handshakeHeaderLenTLS is the TLS handshake header size:
	// msg_type (1 byte) + length (3 bytes).
	handshakeHeaderLenTLS = 1 + 3

	// handshakeHeaderLenDTLS is the DTLS handshake header size:
	// msg_type (1) + length (3) + message_seq (2) + fragment_offset (3) +
	// fragment_length (3).
	handshakeHeaderLenDTLS = 1 + 3 + 2 + 3 + 3

	// maxHandshakeMessageLen is 2^24, one more than the largest value the
	// 3-byte handshake length field can hold.
	maxHandshakeMessageLen = 1 << 24
)
- // struct {
- // HandshakeType msg_type; /* handshake type */
- // uint24 length; /* bytes in message */
- // select (HandshakeType) {
- // ...
- // } body;
- // } Handshake;
- //
- // We do the select{...} part in a different layer, so we treat the
- // actual message body as opaque:
- //
- // struct {
- // HandshakeType msg_type;
- // opaque msg<0..2^24-1>
- // } Handshake;
- //
- type HandshakeMessage struct {
- msgType HandshakeType
- seq uint32
- body []byte
- datagram bool
- offset uint32 // Used for DTLS
- length uint32
- cipher *cipherState
- }
- // Note: This could be done with the `syntax` module, using the simplified
- // syntax as discussed above. However, since this is so simple, there's not
- // much benefit to doing so.
- // When datagram is set, we marshal this as a whole DTLS record.
- func (hm *HandshakeMessage) Marshal() []byte {
- if hm == nil {
- return []byte{}
- }
- fragLen := len(hm.body)
- var data []byte
- if hm.datagram {
- data = make([]byte, handshakeHeaderLenDTLS+fragLen)
- } else {
- data = make([]byte, handshakeHeaderLenTLS+fragLen)
- }
- tmp := data
- tmp = encodeUint(uint64(hm.msgType), 1, tmp)
- tmp = encodeUint(uint64(hm.length), 3, tmp)
- if hm.datagram {
- tmp = encodeUint(uint64(hm.seq), 2, tmp)
- tmp = encodeUint(uint64(hm.offset), 3, tmp)
- tmp = encodeUint(uint64(fragLen), 3, tmp)
- }
- copy(tmp, hm.body)
- return data
- }
- func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
- logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
- var body HandshakeMessageBody
- switch hm.msgType {
- case HandshakeTypeClientHello:
- body = new(ClientHelloBody)
- case HandshakeTypeServerHello:
- body = new(ServerHelloBody)
- case HandshakeTypeEncryptedExtensions:
- body = new(EncryptedExtensionsBody)
- case HandshakeTypeCertificate:
- body = new(CertificateBody)
- case HandshakeTypeCertificateRequest:
- body = new(CertificateRequestBody)
- case HandshakeTypeCertificateVerify:
- body = new(CertificateVerifyBody)
- case HandshakeTypeFinished:
- body = &FinishedBody{VerifyDataLen: len(hm.body)}
- case HandshakeTypeNewSessionTicket:
- body = new(NewSessionTicketBody)
- case HandshakeTypeKeyUpdate:
- body = new(KeyUpdateBody)
- case HandshakeTypeEndOfEarlyData:
- body = new(EndOfEarlyDataBody)
- default:
- return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
- }
- err := safeUnmarshal(body, hm.body)
- return body, err
- }
- func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
- data, err := body.Marshal()
- if err != nil {
- return nil, err
- }
- m := &HandshakeMessage{
- msgType: body.Type(),
- body: data,
- seq: h.msgSeq,
- datagram: h.datagram,
- length: uint32(len(data)),
- }
- h.msgSeq++
- return m, nil
- }
- type HandshakeLayer struct {
- ctx *HandshakeContext // The handshake we are attached to
- nonblocking bool // Should we operate in nonblocking mode
- conn *RecordLayer // Used for reading/writing records
- frame *frameReader // The buffered frame reader
- datagram bool // Is this DTLS?
- msgSeq uint32 // The DTLS message sequence number
- queued []*HandshakeMessage // In/out queue
- sent []*HandshakeMessage // Sent messages for DTLS
- recvdRecords []uint64 // Records we have received.
- maxFragmentLen int
- }
- type handshakeLayerFrameDetails struct {
- datagram bool
- }
- func (d handshakeLayerFrameDetails) headerLen() int {
- if d.datagram {
- return handshakeHeaderLenDTLS
- }
- return handshakeHeaderLenTLS
- }
- func (d handshakeLayerFrameDetails) defaultReadLen() int {
- return d.headerLen() + maxFragmentLen
- }
- func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
- logf(logTypeIO, "Header=%x", hdr)
- // The length of this fragment (as opposed to the message)
- // is always the last three bytes for both TLS and DTLS
- val, _ := decodeUint(hdr[len(hdr)-3:], 3)
- return int(val), nil
- }
- func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
- h := HandshakeLayer{}
- h.ctx = c
- h.conn = r
- h.datagram = false
- h.frame = newFrameReader(&handshakeLayerFrameDetails{false})
- h.maxFragmentLen = maxFragmentLen
- return &h
- }
- func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
- h := HandshakeLayer{}
- h.ctx = c
- h.conn = r
- h.datagram = true
- h.frame = newFrameReader(&handshakeLayerFrameDetails{true})
- h.maxFragmentLen = initialMtu // Not quite right
- return &h
- }
- func (h *HandshakeLayer) readRecord() error {
- logf(logTypeVerbose, "Trying to read record")
- pt, err := h.conn.readRecordAnyEpoch()
- if err != nil {
- return err
- }
- switch pt.contentType {
- case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck:
- default:
- return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
- }
- if pt.contentType == RecordTypeAck {
- if !h.datagram {
- return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS")
- }
- logf(logTypeIO, "read ACK")
- return h.ctx.processAck(pt.fragment)
- }
- if pt.contentType == RecordTypeAlert {
- logf(logTypeIO, "read alert %v", pt.fragment[1])
- if len(pt.fragment) < 2 {
- h.sendAlert(AlertUnexpectedMessage)
- return io.EOF
- }
- return Alert(pt.fragment[1])
- }
- assert(h.ctx.hIn.conn != nil)
- if pt.epoch != h.ctx.hIn.conn.cipher.epoch {
- // This is out of order but we're dropping it.
- // TODO(ekr@rtfm.com): If server, need to retransmit Finished.
- if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData {
- return nil
- }
- // Anything else shouldn't happen.
- return AlertIllegalParameter
- }
- h.recvdRecords = append(h.recvdRecords, pt.seq)
- h.frame.addChunk(pt.fragment)
- return nil
- }
- // sendAlert sends a TLS alert message.
- func (h *HandshakeLayer) sendAlert(err Alert) error {
- tmp := make([]byte, 2)
- tmp[0] = AlertLevelError
- tmp[1] = byte(err)
- h.conn.WriteRecord(&TLSPlaintext{
- contentType: RecordTypeAlert,
- fragment: tmp},
- )
- // closeNotify is a special case in that it isn't an error:
- if err != AlertCloseNotify {
- return &net.OpError{Op: "local error", Err: err}
- }
- return nil
- }
- func (h *HandshakeLayer) noteMessageDelivered(seq uint32) {
- h.msgSeq = seq + 1
- var i int
- var m *HandshakeMessage
- for i, m = range h.queued {
- if m.seq > seq {
- break
- }
- }
- h.queued = h.queued[i:]
- }
- func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) {
- if hm.seq < h.msgSeq {
- return nil, nil
- }
- // TODO(ekr@rtfm.com): Send an ACK immediately if we got something
- // out of order.
- h.ctx.receivedHandshakeMessage()
- if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
- // TODO(ekr@rtfm.com): Check the length?
- // This is complete.
- h.noteMessageDelivered(hm.seq)
- return hm, nil
- }
- // Now insert sorted.
- var i int
- for i = 0; i < len(h.queued); i++ {
- f := h.queued[i]
- if hm.seq < f.seq {
- break
- }
- if hm.offset < f.offset {
- break
- }
- }
- tmp := make([]*HandshakeMessage, 0, len(h.queued)+1)
- tmp = append(tmp, h.queued[:i]...)
- tmp = append(tmp, hm)
- tmp = append(tmp, h.queued[i:]...)
- h.queued = tmp
- return h.checkMessageAvailable()
- }
- func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) {
- if len(h.queued) == 0 {
- return nil, nil
- }
- hm := h.queued[0]
- if hm.seq != h.msgSeq {
- return nil, nil
- }
- if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
- // TODO(ekr@rtfm.com): Check the length?
- // This is complete.
- h.noteMessageDelivered(hm.seq)
- return hm, nil
- }
- // OK, this at least might complete the message.
- end := uint32(0)
- buf := make([]byte, hm.length)
- for _, f := range h.queued {
- // Out of fragments
- if f.seq > hm.seq {
- break
- }
- if f.length != uint32(len(buf)) {
- return nil, fmt.Errorf("Mismatched DTLS length")
- }
- if f.offset > end {
- break
- }
- if f.offset+uint32(len(f.body)) > end {
- // OK, this is adding something we don't know about
- copy(buf[f.offset:], f.body)
- end = f.offset + uint32(len(f.body))
- if end == hm.length {
- h2 := *hm
- h2.offset = 0
- h2.body = buf
- h.noteMessageDelivered(hm.seq)
- return &h2, nil
- }
- }
- }
- return nil, nil
- }
- func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
- var hdr, body []byte
- var err error
- hm, err := h.checkMessageAvailable()
- if err != nil {
- return nil, err
- }
- if hm != nil {
- return hm, nil
- }
- for {
- logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder))
- if h.frame.needed() > 0 {
- logf(logTypeVerbose, "Trying to read a new record")
- err = h.readRecord()
- if err != nil && (h.nonblocking || err != AlertWouldBlock) {
- return nil, err
- }
- }
- hdr, body, err = h.frame.process()
- if err == nil {
- break
- }
- if err != nil && (h.nonblocking || err != AlertWouldBlock) {
- return nil, err
- }
- }
- logf(logTypeHandshake, "read handshake message")
- hm = &HandshakeMessage{}
- hm.msgType = HandshakeType(hdr[0])
- hm.datagram = h.datagram
- hm.body = make([]byte, len(body))
- copy(hm.body, body)
- logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
- if h.datagram {
- tmp, hdr := decodeUint(hdr[1:], 3)
- hm.length = uint32(tmp)
- tmp, hdr = decodeUint(hdr, 2)
- hm.seq = uint32(tmp)
- tmp, hdr = decodeUint(hdr, 3)
- hm.offset = uint32(tmp)
- return h.newFragmentReceived(hm)
- }
- hm.length = uint32(len(body))
- return hm, nil
- }
- func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error {
- hm.cipher = h.conn.cipher
- h.queued = append(h.queued, hm)
- return nil
- }
- func (h *HandshakeLayer) SendQueuedMessages() (int, error) {
- logf(logTypeHandshake, "Sending outgoing messages")
- count, err := h.WriteMessages(h.queued)
- if !h.datagram {
- h.ClearQueuedMessages()
- }
- return count, err
- }
- func (h *HandshakeLayer) ClearQueuedMessages() {
- logf(logTypeHandshake, "Clearing outgoing hs message queue")
- h.queued = nil
- }
- func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) {
- var buf []byte
- // Figure out if we're going to want the full header or just
- // the body
- hdrlen := 0
- if hm.datagram {
- hdrlen = handshakeHeaderLenDTLS
- } else if start == 0 {
- hdrlen = handshakeHeaderLenTLS
- }
- // Compute the amount of body we can fit in
- room -= hdrlen
- if room == 0 {
- // This works because we are doing one record per
- // message
- panic("Too short max fragment len")
- }
- bodylen := len(hm.body) - start
- if bodylen > room {
- bodylen = room
- }
- body := hm.body[start : start+bodylen]
- // Now see if this chunk has been ACKed. This doesn't produce ideal
- // retransmission but is simple.
- if h.ctx.fragmentAcked(hm.seq, start, bodylen) {
- logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen)
- return false, start + bodylen, nil
- }
- // Encode the data.
- if hdrlen > 0 {
- hm2 := *hm
- hm2.offset = uint32(start)
- hm2.body = body
- buf = hm2.Marshal()
- hm = &hm2
- } else {
- buf = body
- }
- if h.datagram {
- // Remember that we sent this.
- h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{
- hm.seq,
- start,
- len(body),
- h.conn.cipher.combineSeq(true),
- false,
- })
- }
- return true, start + bodylen, h.conn.writeRecordWithPadding(
- &TLSPlaintext{
- contentType: RecordTypeHandshake,
- fragment: buf,
- },
- hm.cipher, 0)
- }
- func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) {
- start := int(0)
- if len(hm.body) > maxHandshakeMessageLen {
- return 0, fmt.Errorf("Tried to write a handshake message that's too long")
- }
- written := 0
- wrote := false
- // Always make one pass through to allow EOED (which is empty).
- for {
- var err error
- wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen)
- if err != nil {
- return 0, err
- }
- if wrote {
- written++
- }
- if start >= len(hm.body) {
- break
- }
- }
- return written, nil
- }
- func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) {
- written := 0
- for _, hm := range hms {
- logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
- wrote, err := h.WriteMessage(hm)
- if err != nil {
- return 0, err
- }
- written += wrote
- }
- return written, nil
- }
// encodeUint writes the low size bytes of v into the front of out in
// big-endian order and returns the rest of out after those bytes. Higher
// bytes of v that do not fit are dropped. It panics if out is shorter than
// size.
func encodeUint(v uint64, size int, out []byte) []byte {
	// Fill from the last byte backwards, peeling off 8 bits at a time.
	for pos := size; pos > 0; pos-- {
		out[pos-1] = byte(v)
		v >>= 8
	}
	return out[size:]
}
// decodeUint reads a size-byte big-endian unsigned integer from the front of
// in. It returns the value and the bytes that follow. It panics if in is
// shorter than size.
func decodeUint(in []byte, size int) (uint64, []byte) {
	var acc uint64
	for i := 0; i < size; i++ {
		acc = acc<<8 | uint64(in[i])
	}
	return acc, in[size:]
}
// marshalledPDU is a wire structure that can encode itself, and can decode
// itself from a byte slice while reporting how many bytes it consumed.
type marshalledPDU interface {
	Marshal() ([]byte, error)
	Unmarshal(data []byte) (int, error)
}

// safeUnmarshal decodes data into pdu and requires the whole slice to be
// consumed; leftover bytes are an error. Errors from pdu.Unmarshal are
// returned unchanged.
func safeUnmarshal(pdu marshalledPDU, data []byte) error {
	consumed, err := pdu.Unmarshal(data)
	switch {
	case err != nil:
		return err
	case consumed != len(data):
		return fmt.Errorf("Invalid encoding: Extra data not consumed")
	default:
		return nil
	}
}
|