context.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. package srtp
  2. import (
  3. "fmt"
  4. "github.com/pion/transport/v2/replaydetector"
  5. )
  // Key-derivation labels, matching the label values of RFC 3711 Section 4.3.1
  // for the SRTP and SRTCP encryption key, authentication key and salt.
  const (
  	labelSRTPEncryption         = 0x00
  	labelSRTPAuthenticationTag  = 0x01
  	labelSRTPSalt               = 0x02
  	labelSRTCPEncryption        = 0x03
  	labelSRTCPAuthenticationTag = 0x04
  	labelSRTCPSalt              = 0x05
  )

  // Limits used when estimating the SRTP packet index from a 16-bit sequence
  // number and a 32-bit rollover counter (ROC), plus the byte length of the
  // SRTCP index field.
  const (
  	seqNumMedian      = 1 << 15
  	seqNumMax         = 1 << 16
  	maxSequenceNumber = seqNumMax - 1
  	maxROC            = 1<<32 - 1
  	srtcpIndexSize    = 4
  )
  19. // Encrypt/Decrypt state for a single SRTP SSRC
  20. type srtpSSRCState struct {
  21. ssrc uint32
  22. index uint64
  23. rolloverHasProcessed bool
  24. replayDetector replaydetector.ReplayDetector
  25. }
  26. // Encrypt/Decrypt state for a single SRTCP SSRC
  27. type srtcpSSRCState struct {
  28. srtcpIndex uint32
  29. ssrc uint32
  30. replayDetector replaydetector.ReplayDetector
  31. }
  32. // Context represents a SRTP cryptographic context.
  33. // Context can only be used for one-way operations.
  34. // it must either used ONLY for encryption or ONLY for decryption.
  35. // Note that Context does not provide any concurrency protection:
  36. // access to a Context from multiple goroutines requires external
  37. // synchronization.
  38. type Context struct {
  39. cipher srtpCipher
  40. srtpSSRCStates map[uint32]*srtpSSRCState
  41. srtcpSSRCStates map[uint32]*srtcpSSRCState
  42. newSRTCPReplayDetector func() replaydetector.ReplayDetector
  43. newSRTPReplayDetector func() replaydetector.ReplayDetector
  44. }
  45. // CreateContext creates a new SRTP Context.
  46. //
  47. // CreateContext receives variable number of ContextOption-s.
  48. // Passing multiple options which set the same parameter let the last one valid.
  49. // Following example create SRTP Context with replay protection with window size of 256.
  50. //
  51. // decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
  52. func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) {
  53. keyLen, err := profile.keyLen()
  54. if err != nil {
  55. return nil, err
  56. }
  57. saltLen, err := profile.saltLen()
  58. if err != nil {
  59. return nil, err
  60. }
  61. if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
  62. return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen)
  63. } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
  64. return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
  65. }
  66. c = &Context{
  67. srtpSSRCStates: map[uint32]*srtpSSRCState{},
  68. srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
  69. }
  70. switch profile {
  71. case ProtectionProfileAeadAes128Gcm:
  72. c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt)
  73. case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80:
  74. c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt)
  75. default:
  76. return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile)
  77. }
  78. if err != nil {
  79. return nil, err
  80. }
  81. for _, o := range append(
  82. []ContextOption{ // Default options
  83. SRTPNoReplayProtection(),
  84. SRTCPNoReplayProtection(),
  85. },
  86. opts..., // User specified options
  87. ) {
  88. if errOpt := o(c); errOpt != nil {
  89. return nil, errOpt
  90. }
  91. }
  92. return c, nil
  93. }
  94. // https://tools.ietf.org/html/rfc3550#appendix-A.1
  95. func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, diff int32, overflow bool) {
  96. seq := int32(sequenceNumber)
  97. localRoc := uint32(s.index >> 16)
  98. localSeq := int32(s.index & (seqNumMax - 1))
  99. guessRoc := localRoc
  100. var difference int32
  101. if s.rolloverHasProcessed {
  102. // When localROC is equal to 0, and entering seq-localSeq > seqNumMedian
  103. // judgment, it will cause guessRoc calculation error
  104. if s.index > seqNumMedian {
  105. if localSeq < seqNumMedian {
  106. if seq-localSeq > seqNumMedian {
  107. guessRoc = localRoc - 1
  108. difference = seq - localSeq - seqNumMax
  109. } else {
  110. guessRoc = localRoc
  111. difference = seq - localSeq
  112. }
  113. } else {
  114. if localSeq-seqNumMedian > seq {
  115. guessRoc = localRoc + 1
  116. difference = seq - localSeq + seqNumMax
  117. } else {
  118. guessRoc = localRoc
  119. difference = seq - localSeq
  120. }
  121. }
  122. } else {
  123. // localRoc is equal to 0
  124. difference = seq - localSeq
  125. }
  126. }
  127. return guessRoc, difference, (guessRoc == 0 && localRoc == maxROC)
  128. }
  129. func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference int32) {
  130. if !s.rolloverHasProcessed {
  131. s.index |= uint64(sequenceNumber)
  132. s.rolloverHasProcessed = true
  133. return
  134. }
  135. if difference > 0 {
  136. s.index += uint64(difference)
  137. }
  138. }
  139. func (c *Context) getSRTPSSRCState(ssrc uint32) *srtpSSRCState {
  140. s, ok := c.srtpSSRCStates[ssrc]
  141. if ok {
  142. return s
  143. }
  144. s = &srtpSSRCState{
  145. ssrc: ssrc,
  146. replayDetector: c.newSRTPReplayDetector(),
  147. }
  148. c.srtpSSRCStates[ssrc] = s
  149. return s
  150. }
  151. func (c *Context) getSRTCPSSRCState(ssrc uint32) *srtcpSSRCState {
  152. s, ok := c.srtcpSSRCStates[ssrc]
  153. if ok {
  154. return s
  155. }
  156. s = &srtcpSSRCState{
  157. ssrc: ssrc,
  158. replayDetector: c.newSRTCPReplayDetector(),
  159. }
  160. c.srtcpSSRCStates[ssrc] = s
  161. return s
  162. }
  163. // ROC returns SRTP rollover counter value of specified SSRC.
  164. func (c *Context) ROC(ssrc uint32) (uint32, bool) {
  165. s, ok := c.srtpSSRCStates[ssrc]
  166. if !ok {
  167. return 0, false
  168. }
  169. return uint32(s.index >> 16), true
  170. }
  171. // SetROC sets SRTP rollover counter value of specified SSRC.
  172. func (c *Context) SetROC(ssrc uint32, roc uint32) {
  173. s := c.getSRTPSSRCState(ssrc)
  174. s.index = uint64(roc) << 16
  175. s.rolloverHasProcessed = false
  176. }
  177. // Index returns SRTCP index value of specified SSRC.
  178. func (c *Context) Index(ssrc uint32) (uint32, bool) {
  179. s, ok := c.srtcpSSRCStates[ssrc]
  180. if !ok {
  181. return 0, false
  182. }
  183. return s.srtcpIndex, true
  184. }
  185. // SetIndex sets SRTCP index value of specified SSRC.
  186. func (c *Context) SetIndex(ssrc uint32, index uint32) {
  187. s := c.getSRTCPSSRCState(ssrc)
  188. s.srtcpIndex = index % (maxSRTCPIndex + 1)
  189. }