state.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  import (
	"bytes"
	"encoding/gob"
	"errors"
	"sync/atomic"

	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
	"github.com/pion/dtls/v2/pkg/crypto/prf"
	"github.com/pion/dtls/v2/pkg/protocol/handshake"
	"github.com/pion/transport/v2/replaydetector"
)
  13. // State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
  14. type State struct {
  15. localEpoch, remoteEpoch atomic.Value
  16. localSequenceNumber []uint64 // uint48
  17. localRandom, remoteRandom handshake.Random
  18. masterSecret []byte
  19. cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen
  20. srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile
  21. PeerCertificates [][]byte
  22. IdentityHint []byte
  23. SessionID []byte
  24. isClient bool
  25. preMasterSecret []byte
  26. extendedMasterSecret bool
  27. namedCurve elliptic.Curve
  28. localKeypair *elliptic.Keypair
  29. cookie []byte
  30. handshakeSendSequence int
  31. handshakeRecvSequence int
  32. serverName string
  33. remoteRequestedCertificate bool // Did we get a CertificateRequest
  34. localCertificatesVerify []byte // cache CertificateVerify
  35. localVerifyData []byte // cached VerifyData
  36. localKeySignature []byte // cached keySignature
  37. peerCertificatesVerified bool
  38. replayDetector []replaydetector.ReplayDetector
  39. peerSupportedProtocols []string
  40. NegotiatedProtocol string
  41. }
  42. type serializedState struct {
  43. LocalEpoch uint16
  44. RemoteEpoch uint16
  45. LocalRandom [handshake.RandomLength]byte
  46. RemoteRandom [handshake.RandomLength]byte
  47. CipherSuiteID uint16
  48. MasterSecret []byte
  49. SequenceNumber uint64
  50. SRTPProtectionProfile uint16
  51. PeerCertificates [][]byte
  52. IdentityHint []byte
  53. SessionID []byte
  54. IsClient bool
  55. }
  56. func (s *State) clone() *State {
  57. serialized := s.serialize()
  58. state := &State{}
  59. state.deserialize(*serialized)
  60. return state
  61. }
  // serialize snapshots the persistable subset of s into a serializedState.
//
// Preconditions: a cipher suite must have been chosen (s.cipherSuite is nil
// before that, and calling ID on a nil interface panics), and
// localSequenceNumber must have an entry for the current local epoch.
// Slice fields are shared with s, not copied.
func (s *State) serialize() *serializedState {
	localRnd := s.localRandom.MarshalFixed()
	remoteRnd := s.remoteRandom.MarshalFixed()

	// Load the local epoch exactly once. It lives in an atomic.Value (so it is
	// presumably updated concurrently), and LocalEpoch must name the same epoch
	// whose sequence number is stored in SequenceNumber. Reading it twice could
	// pair one epoch with another epoch's sequence number.
	epoch := s.getLocalEpoch()

	return &serializedState{
		LocalEpoch:            epoch,
		RemoteEpoch:           s.getRemoteEpoch(),
		CipherSuiteID:         uint16(s.cipherSuite.ID()),
		MasterSecret:          s.masterSecret,
		SequenceNumber:        atomic.LoadUint64(&s.localSequenceNumber[epoch]),
		LocalRandom:           localRnd,
		RemoteRandom:          remoteRnd,
		SRTPProtectionProfile: uint16(s.getSRTPProtectionProfile()),
		PeerCertificates:      s.PeerCertificates,
		IdentityHint:          s.IdentityHint,
		SessionID:             s.SessionID,
		IsClient:              s.isClient,
	}
}
  82. func (s *State) deserialize(serialized serializedState) {
  83. // Set epoch values
  84. epoch := serialized.LocalEpoch
  85. s.localEpoch.Store(serialized.LocalEpoch)
  86. s.remoteEpoch.Store(serialized.RemoteEpoch)
  87. for len(s.localSequenceNumber) <= int(epoch) {
  88. s.localSequenceNumber = append(s.localSequenceNumber, uint64(0))
  89. }
  90. // Set random values
  91. localRandom := &handshake.Random{}
  92. localRandom.UnmarshalFixed(serialized.LocalRandom)
  93. s.localRandom = *localRandom
  94. remoteRandom := &handshake.Random{}
  95. remoteRandom.UnmarshalFixed(serialized.RemoteRandom)
  96. s.remoteRandom = *remoteRandom
  97. s.isClient = serialized.IsClient
  98. // Set master secret
  99. s.masterSecret = serialized.MasterSecret
  100. // Set cipher suite
  101. s.cipherSuite = cipherSuiteForID(CipherSuiteID(serialized.CipherSuiteID), nil)
  102. atomic.StoreUint64(&s.localSequenceNumber[epoch], serialized.SequenceNumber)
  103. s.setSRTPProtectionProfile(SRTPProtectionProfile(serialized.SRTPProtectionProfile))
  104. // Set remote certificate
  105. s.PeerCertificates = serialized.PeerCertificates
  106. s.IdentityHint = serialized.IdentityHint
  107. s.SessionID = serialized.SessionID
  108. }
  109. func (s *State) initCipherSuite() error {
  110. if s.cipherSuite.IsInitialized() {
  111. return nil
  112. }
  113. localRandom := s.localRandom.MarshalFixed()
  114. remoteRandom := s.remoteRandom.MarshalFixed()
  115. var err error
  116. if s.isClient {
  117. err = s.cipherSuite.Init(s.masterSecret, localRandom[:], remoteRandom[:], true)
  118. } else {
  119. err = s.cipherSuite.Init(s.masterSecret, remoteRandom[:], localRandom[:], false)
  120. }
  121. if err != nil {
  122. return err
  123. }
  124. return nil
  125. }
  // MarshalBinary implements encoding.BinaryMarshaler. It gob-encodes the
// serializable subset of the state (the fields of serializedState).
func (s *State) MarshalBinary() ([]byte, error) {
	serialized := s.serialize()

	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(*serialized); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
  // UnmarshalBinary implements encoding.BinaryUnmarshaler. It gob-decodes data
// (as produced by MarshalBinary), restores the carried fields into s, and then
// initializes the restored cipher suite.
func (s *State) UnmarshalBinary(data []byte) error {
	var serialized serializedState
	dec := gob.NewDecoder(bytes.NewBuffer(data))
	if err := dec.Decode(&serialized); err != nil {
		return err
	}

	s.deserialize(serialized)
	return s.initCipherSuite()
}
  146. // ExportKeyingMaterial returns length bytes of exported key material in a new
  147. // slice as defined in RFC 5705.
  148. // This allows protocols to use DTLS for key establishment, but
  149. // then use some of the keying material for their own purposes
  150. func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
  151. if s.getLocalEpoch() == 0 {
  152. return nil, errHandshakeInProgress
  153. } else if len(context) != 0 {
  154. return nil, errContextUnsupported
  155. } else if _, ok := invalidKeyingLabels()[label]; ok {
  156. return nil, errReservedExportKeyingMaterial
  157. }
  158. localRandom := s.localRandom.MarshalFixed()
  159. remoteRandom := s.remoteRandom.MarshalFixed()
  160. seed := []byte(label)
  161. if s.isClient {
  162. seed = append(append(seed, localRandom[:]...), remoteRandom[:]...)
  163. } else {
  164. seed = append(append(seed, remoteRandom[:]...), localRandom[:]...)
  165. }
  166. return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc())
  167. }
  168. func (s *State) getRemoteEpoch() uint16 {
  169. if remoteEpoch, ok := s.remoteEpoch.Load().(uint16); ok {
  170. return remoteEpoch
  171. }
  172. return 0
  173. }
  174. func (s *State) getLocalEpoch() uint16 {
  175. if localEpoch, ok := s.localEpoch.Load().(uint16); ok {
  176. return localEpoch
  177. }
  178. return 0
  179. }
  180. func (s *State) setSRTPProtectionProfile(profile SRTPProtectionProfile) {
  181. s.srtpProtectionProfile.Store(profile)
  182. }
  183. func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile {
  184. if val, ok := s.srtpProtectionProfile.Load().(SRTPProtectionProfile); ok {
  185. return val
  186. }
  187. return 0
  188. }
  189. // [Psiphon]
  190. // Conjure DTLS support, from: https://github.com/mingyech/dtls/commit/a56eccc1
  191. //
  192. // RemoteRandomBytes returns the random bytes from the client or server hello
  193. func (s *State) RemoteRandomBytes() [handshake.RandomBytesLength]byte {
  194. return s.remoteRandom.RandomBytes
  195. }