dtls.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. package mint
  2. import (
  3. "fmt"
  4. "github.com/bifurcation/mint/syntax"
  5. "time"
  6. )
  // Handshake tuning defaults and the names of the handshake timers.
const (
	// initialMtu is the assumed datagram MTU in bytes. sendAck uses it to
	// cap how many 8-byte record numbers fit into a single ACK.
	initialMtu = 1200

	// initialTimeout is the starting retransmission timeout.
	// NOTE(review): presumably milliseconds (cf. HandshakeContext.timeoutMS) — confirm.
	initialTimeout = 100

	// Labels under which the handshake timers are registered in h.timers.
	retransmitTimerLabel = "handshake retransmit"
	ackTimerLabel        = "ack timer"
)
  // SentHandshakeFragment remembers one transmitted handshake fragment so
// that incoming ACKs can be matched back to it (see processAck) and so
// retransmission can skip already-acknowledged ranges (see fragmentAcked).
type SentHandshakeFragment struct {
	seq                uint32 // sequence number of the handshake message this fragment belongs to
	offset, fragLength int    // the fragment spans [offset, offset+fragLength) within that message
	record             uint64 // record number the fragment was sent in; compared against ACKed record numbers
	acked              bool   // set once an ACK listing record arrives
}
  // DtlsAck is the body of a DTLS ACK record: the list of record numbers
// being acknowledged. The `tls:"head=2"` tag makes the syntax package
// encode the list behind a 2-byte length prefix; each entry is a uint64.
type DtlsAck struct {
	RecordNumbers []uint64 `tls:"head=2"`
}
  26. func wireVersion(h *HandshakeLayer) uint16 {
  27. if h.datagram {
  28. return dtls12WireVersion
  29. }
  30. return tls12Version
  31. }
  32. func dtlsConvertVersion(version uint16) uint16 {
  33. if version == tls12Version {
  34. return dtls12WireVersion
  35. }
  36. if version == tls10Version {
  37. return 0xfeff
  38. }
  39. panic(fmt.Sprintf("Internal error, unexpected version=%d", version))
  40. }
  41. // TODO(ekr@rtfm.com): Move these to state-machine.go
  42. func (h *HandshakeContext) handshakeRetransmit() error {
  43. if _, err := h.hOut.SendQueuedMessages(); err != nil {
  44. return err
  45. }
  46. h.timers.start(retransmitTimerLabel,
  47. h.handshakeRetransmit,
  48. h.timeoutMS)
  49. // TODO(ekr@rtfm.com): Back off timer
  50. return nil
  51. }
  52. func (h *HandshakeContext) sendAck() error {
  53. toack := h.hIn.recvdRecords
  54. count := (initialMtu - 2) / 8 // TODO(ekr@rtfm.com): Current MTU
  55. if len(toack) > count {
  56. toack = toack[:count]
  57. }
  58. logf(logTypeHandshake, "Sending ACK: [%x]", toack)
  59. ack := &DtlsAck{toack}
  60. body, err := syntax.Marshal(&ack)
  61. if err != nil {
  62. return err
  63. }
  64. err = h.hOut.conn.WriteRecord(&TLSPlaintext{
  65. contentType: RecordTypeAck,
  66. fragment: body,
  67. })
  68. if err != nil {
  69. return err
  70. }
  71. return nil
  72. }
  73. func (h *HandshakeContext) processAck(data []byte) error {
  74. // Cancel the retransmit timer because we will be resending
  75. // and possibly re-arming later.
  76. h.timers.cancel(retransmitTimerLabel)
  77. ack := &DtlsAck{}
  78. read, err := syntax.Unmarshal(data, &ack)
  79. if err != nil {
  80. return err
  81. }
  82. if len(data) != read {
  83. return fmt.Errorf("Invalid encoding: Extra data not consumed")
  84. }
  85. logf(logTypeHandshake, "ACK: [%x]", ack.RecordNumbers)
  86. for _, r := range ack.RecordNumbers {
  87. for _, m := range h.sentFragments {
  88. if r == m.record {
  89. logf(logTypeHandshake, "Marking %v %v(%v) as acked",
  90. m.seq, m.offset, m.fragLength)
  91. m.acked = true
  92. }
  93. }
  94. }
  95. count, err := h.hOut.SendQueuedMessages()
  96. if err != nil {
  97. return err
  98. }
  99. if count == 0 {
  100. logf(logTypeHandshake, "All messages ACKed")
  101. h.hOut.ClearQueuedMessages()
  102. return nil
  103. }
  104. // Reset the timer
  105. h.timers.start(retransmitTimerLabel,
  106. h.handshakeRetransmit,
  107. h.timeoutMS)
  108. return nil
  109. }
  110. func (c *Conn) GetDTLSTimeout() (bool, time.Duration) {
  111. return c.hsCtx.timers.remaining()
  112. }
  113. func (h *HandshakeContext) receivedHandshakeMessage() {
  114. logf(logTypeHandshake, "%p Received handshake, waiting for start of flight = %v", h, h.waitingNextFlight)
  115. // This just enables tests.
  116. if h.hIn == nil {
  117. return
  118. }
  119. if !h.hIn.datagram {
  120. return
  121. }
  122. if h.waitingNextFlight {
  123. logf(logTypeHandshake, "Received the start of the flight")
  124. // Clear the outgoing DTLS queue and terminate the retransmit timer
  125. h.hOut.ClearQueuedMessages()
  126. h.timers.cancel(retransmitTimerLabel)
  127. // OK, we're not waiting any more.
  128. h.waitingNextFlight = false
  129. }
  130. // Now pre-emptively arm the ACK timer if it's not armed already.
  131. // We'll automatically dis-arm it at the end of the handshake.
  132. if h.timers.getTimer(ackTimerLabel) == nil {
  133. h.timers.start(ackTimerLabel, h.sendAck, h.timeoutMS/4)
  134. }
  135. }
  136. func (h *HandshakeContext) receivedEndOfFlight() {
  137. logf(logTypeHandshake, "%p Received the end of the flight", h)
  138. if !h.hIn.datagram {
  139. return
  140. }
  141. // Empty incoming queue
  142. h.hIn.queued = nil
  143. // Note that we are waiting for the next flight.
  144. h.waitingNextFlight = true
  145. // Clear the ACK queue.
  146. h.hIn.recvdRecords = nil
  147. // Disarm the ACK timer
  148. h.timers.cancel(ackTimerLabel)
  149. }
  150. func (h *HandshakeContext) receivedFinalFlight() {
  151. logf(logTypeHandshake, "%p Received final flight", h)
  152. if !h.hIn.datagram {
  153. return
  154. }
  155. // Disarm the ACK timer
  156. h.timers.cancel(ackTimerLabel)
  157. // But send an ACK immediately.
  158. h.sendAck()
  159. }
  160. func (h *HandshakeContext) fragmentAcked(seq uint32, offset int, fraglen int) bool {
  161. logf(logTypeHandshake, "Looking to see if fragment %v %v(%v) was acked", seq, offset, fraglen)
  162. for _, f := range h.sentFragments {
  163. if !f.acked {
  164. continue
  165. }
  166. if f.seq != seq {
  167. continue
  168. }
  169. if f.offset > offset {
  170. continue
  171. }
  172. // At this point, we know that the stored fragment starts
  173. // at or before what we want to send, so check where the end
  174. // is.
  175. if f.offset+f.fragLength < offset+fraglen {
  176. continue
  177. }
  178. return true
  179. }
  180. return false
  181. }