context.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package srtp
  4. import (
  5. "fmt"
  6. "github.com/pion/transport/v2/replaydetector"
  7. )
  8. const (
  9. labelSRTPEncryption = 0x00
  10. labelSRTPAuthenticationTag = 0x01
  11. labelSRTPSalt = 0x02
  12. labelSRTCPEncryption = 0x03
  13. labelSRTCPAuthenticationTag = 0x04
  14. labelSRTCPSalt = 0x05
  15. maxSequenceNumber = 65535
  16. maxROC = (1 << 32) - 1
  17. seqNumMedian = 1 << 15
  18. seqNumMax = 1 << 16
  19. srtcpIndexSize = 4
  20. )
  21. // Encrypt/Decrypt state for a single SRTP SSRC
  22. type srtpSSRCState struct {
  23. ssrc uint32
  24. rolloverHasProcessed bool
  25. index uint64
  26. replayDetector replaydetector.ReplayDetector
  27. }
  28. // Encrypt/Decrypt state for a single SRTCP SSRC
  29. type srtcpSSRCState struct {
  30. srtcpIndex uint32
  31. ssrc uint32
  32. replayDetector replaydetector.ReplayDetector
  33. }
  34. // Context represents a SRTP cryptographic context.
  35. // Context can only be used for one-way operations.
  36. // it must either used ONLY for encryption or ONLY for decryption.
  37. // Note that Context does not provide any concurrency protection:
  38. // access to a Context from multiple goroutines requires external
  39. // synchronization.
  40. type Context struct {
  41. cipher srtpCipher
  42. srtpSSRCStates map[uint32]*srtpSSRCState
  43. srtcpSSRCStates map[uint32]*srtcpSSRCState
  44. newSRTCPReplayDetector func() replaydetector.ReplayDetector
  45. newSRTPReplayDetector func() replaydetector.ReplayDetector
  46. }
  47. // CreateContext creates a new SRTP Context.
  48. //
  49. // CreateContext receives variable number of ContextOption-s.
  50. // Passing multiple options which set the same parameter let the last one valid.
  51. // Following example create SRTP Context with replay protection with window size of 256.
  52. //
  53. // decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
  54. func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) {
  55. keyLen, err := profile.keyLen()
  56. if err != nil {
  57. return nil, err
  58. }
  59. saltLen, err := profile.saltLen()
  60. if err != nil {
  61. return nil, err
  62. }
  63. if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
  64. return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen)
  65. } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
  66. return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
  67. }
  68. c = &Context{
  69. srtpSSRCStates: map[uint32]*srtpSSRCState{},
  70. srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
  71. }
  72. switch profile {
  73. case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
  74. c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt)
  75. case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80:
  76. c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt)
  77. default:
  78. return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile)
  79. }
  80. if err != nil {
  81. return nil, err
  82. }
  83. for _, o := range append(
  84. []ContextOption{ // Default options
  85. SRTPNoReplayProtection(),
  86. SRTCPNoReplayProtection(),
  87. },
  88. opts..., // User specified options
  89. ) {
  90. if errOpt := o(c); errOpt != nil {
  91. return nil, errOpt
  92. }
  93. }
  94. return c, nil
  95. }
  96. // https://tools.ietf.org/html/rfc3550#appendix-A.1
  97. func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, diff int32, overflow bool) {
  98. seq := int32(sequenceNumber)
  99. localRoc := uint32(s.index >> 16)
  100. localSeq := int32(s.index & (seqNumMax - 1))
  101. guessRoc := localRoc
  102. var difference int32
  103. if s.rolloverHasProcessed {
  104. // When localROC is equal to 0, and entering seq-localSeq > seqNumMedian
  105. // judgment, it will cause guessRoc calculation error
  106. if s.index > seqNumMedian {
  107. if localSeq < seqNumMedian {
  108. if seq-localSeq > seqNumMedian {
  109. guessRoc = localRoc - 1
  110. difference = seq - localSeq - seqNumMax
  111. } else {
  112. guessRoc = localRoc
  113. difference = seq - localSeq
  114. }
  115. } else {
  116. if localSeq-seqNumMedian > seq {
  117. guessRoc = localRoc + 1
  118. difference = seq - localSeq + seqNumMax
  119. } else {
  120. guessRoc = localRoc
  121. difference = seq - localSeq
  122. }
  123. }
  124. } else {
  125. // localRoc is equal to 0
  126. difference = seq - localSeq
  127. }
  128. }
  129. return guessRoc, difference, (guessRoc == 0 && localRoc == maxROC)
  130. }
  131. func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference int32) {
  132. if !s.rolloverHasProcessed {
  133. s.index |= uint64(sequenceNumber)
  134. s.rolloverHasProcessed = true
  135. return
  136. }
  137. if difference > 0 {
  138. s.index += uint64(difference)
  139. }
  140. }
  141. func (c *Context) getSRTPSSRCState(ssrc uint32) *srtpSSRCState {
  142. s, ok := c.srtpSSRCStates[ssrc]
  143. if ok {
  144. return s
  145. }
  146. s = &srtpSSRCState{
  147. ssrc: ssrc,
  148. replayDetector: c.newSRTPReplayDetector(),
  149. }
  150. c.srtpSSRCStates[ssrc] = s
  151. return s
  152. }
  153. func (c *Context) getSRTCPSSRCState(ssrc uint32) *srtcpSSRCState {
  154. s, ok := c.srtcpSSRCStates[ssrc]
  155. if ok {
  156. return s
  157. }
  158. s = &srtcpSSRCState{
  159. ssrc: ssrc,
  160. replayDetector: c.newSRTCPReplayDetector(),
  161. }
  162. c.srtcpSSRCStates[ssrc] = s
  163. return s
  164. }
  165. // ROC returns SRTP rollover counter value of specified SSRC.
  166. func (c *Context) ROC(ssrc uint32) (uint32, bool) {
  167. s, ok := c.srtpSSRCStates[ssrc]
  168. if !ok {
  169. return 0, false
  170. }
  171. return uint32(s.index >> 16), true
  172. }
  173. // SetROC sets SRTP rollover counter value of specified SSRC.
  174. func (c *Context) SetROC(ssrc uint32, roc uint32) {
  175. s := c.getSRTPSSRCState(ssrc)
  176. s.index = uint64(roc) << 16
  177. s.rolloverHasProcessed = false
  178. }
  179. // Index returns SRTCP index value of specified SSRC.
  180. func (c *Context) Index(ssrc uint32) (uint32, bool) {
  181. s, ok := c.srtcpSSRCStates[ssrc]
  182. if !ok {
  183. return 0, false
  184. }
  185. return s.srtcpIndex, true
  186. }
  187. // SetIndex sets SRTCP index value of specified SSRC.
  188. func (c *Context) SetIndex(ssrc uint32, index uint32) {
  189. s := c.getSRTCPSSRCState(ssrc)
  190. s.srtcpIndex = index % (maxSRTCPIndex + 1)
  191. }