| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- package srtp
- import (
- "fmt"
- "github.com/pion/transport/v2/replaydetector"
- )
// Key-derivation labels used when deriving SRTP/SRTCP session keys from the
// master key (see RFC 3711 Section 4.3).
const (
	labelSRTPEncryption        = 0x00
	labelSRTPAuthenticationTag = 0x01
	labelSRTPSalt              = 0x02

	labelSRTCPEncryption        = 0x03
	labelSRTCPAuthenticationTag = 0x04
	labelSRTCPSalt              = 0x05
)

// Sequence-number and rollover-counter arithmetic.
const (
	seqNumMax         = 1 << 16       // number of distinct 16-bit sequence numbers
	seqNumMedian      = seqNumMax / 2 // half the sequence space; used to detect wraps
	maxSequenceNumber = seqNumMax - 1 // largest 16-bit sequence number
	maxROC            = 1<<32 - 1     // largest 32-bit rollover counter
	srtcpIndexSize    = 4             // presumably the SRTCP index field width in bytes — confirm
)
// srtpSSRCState holds the encrypt/decrypt state for a single SRTP SSRC.
type srtpSSRCState struct {
	ssrc uint32
	// rolloverHasProcessed is false until the first sequence number has been
	// folded into index by updateRolloverCount; SetROC resets it to false.
	rolloverHasProcessed bool
	// index packs the rollover counter (ROC) above the low 16 bits and, in the
	// low 16 bits, the sequence number of the furthest-ahead packet tracked so
	// far (updateRolloverCount only ever moves it forward).
	index          uint64
	replayDetector replaydetector.ReplayDetector
}
// srtcpSSRCState holds the encrypt/decrypt state for a single SRTCP SSRC.
type srtcpSSRCState struct {
	// srtcpIndex is the SRTCP index for this SSRC; SetIndex reduces values
	// modulo maxSRTCPIndex+1 before storing them here.
	srtcpIndex     uint32
	ssrc           uint32
	replayDetector replaydetector.ReplayDetector
}
// Context represents a SRTP cryptographic context.
// Context can only be used for one-way operations:
// it must be used either ONLY for encryption or ONLY for decryption.
// Note that Context does not provide any concurrency protection:
// access to a Context from multiple goroutines requires external
// synchronization.
type Context struct {
	cipher srtpCipher

	// Per-SSRC state, created lazily by getSRTPSSRCState / getSRTCPSSRCState.
	srtpSSRCStates  map[uint32]*srtpSSRCState
	srtcpSSRCStates map[uint32]*srtcpSSRCState

	// Factories invoked to build the replay detector of each new per-SSRC
	// state. Presumably populated by the ContextOption values that
	// CreateContext applies (defaults first, then user options) — confirm.
	newSRTCPReplayDetector func() replaydetector.ReplayDetector
	newSRTPReplayDetector  func() replaydetector.ReplayDetector
}
- // CreateContext creates a new SRTP Context.
- //
- // CreateContext receives variable number of ContextOption-s.
- // Passing multiple options which set the same parameter let the last one valid.
- // Following example create SRTP Context with replay protection with window size of 256.
- //
- // decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
- func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) {
- keyLen, err := profile.keyLen()
- if err != nil {
- return nil, err
- }
- saltLen, err := profile.saltLen()
- if err != nil {
- return nil, err
- }
- if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
- return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen)
- } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
- return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
- }
- c = &Context{
- srtpSSRCStates: map[uint32]*srtpSSRCState{},
- srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
- }
- switch profile {
- case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
- c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt)
- case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80:
- c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt)
- default:
- return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile)
- }
- if err != nil {
- return nil, err
- }
- for _, o := range append(
- []ContextOption{ // Default options
- SRTPNoReplayProtection(),
- SRTCPNoReplayProtection(),
- },
- opts..., // User specified options
- ) {
- if errOpt := o(c); errOpt != nil {
- return nil, errOpt
- }
- }
- return c, nil
- }
// nextRolloverCount estimates the rollover counter (ROC) to use for
// sequenceNumber, relative to the state tracked in s.index. The receiver is
// not modified; callers apply the result with updateRolloverCount.
//
// Return values:
//   - roc: the guessed ROC for sequenceNumber.
//   - diff: the signed distance from the locally tracked sequence number to
//     sequenceNumber, adjusted by ±seqNumMax when a wrap is assumed. A
//     positive value means sequenceNumber is ahead of the local state.
//   - overflow: true when the guess wrapped the ROC from maxROC back to 0.
//
// If no sequence number has been processed yet (rolloverHasProcessed is
// false), the current ROC is returned with diff 0.
//
// See https://tools.ietf.org/html/rfc3550#appendix-A.1
func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, diff int32, overflow bool) {
	seq := int32(sequenceNumber)
	// s.index packs the ROC above the low 16 bits and the sequence number in them.
	localRoc := uint32(s.index >> 16)
	localSeq := int32(s.index & (seqNumMax - 1))

	guessRoc := localRoc
	var difference int32

	if s.rolloverHasProcessed {
		// Wrap detection is only applied once the index has moved past
		// seqNumMedian. Below that, localRoc is 0, and guessing a backward wrap
		// (localRoc - 1) would underflow and produce a wrong guessRoc.
		if s.index > seqNumMedian {
			if localSeq < seqNumMedian {
				// Local state is in the lower half of the sequence space. A seq
				// more than half the space ahead is treated as a late packet
				// from the previous ROC.
				if seq-localSeq > seqNumMedian {
					guessRoc = localRoc - 1
					difference = seq - localSeq - seqNumMax
				} else {
					guessRoc = localRoc
					difference = seq - localSeq
				}
			} else {
				// Local state is in the upper half. A seq more than half the
				// space behind is treated as having wrapped into the next ROC.
				if localSeq-seqNumMedian > seq {
					guessRoc = localRoc + 1
					difference = seq - localSeq + seqNumMax
				} else {
					guessRoc = localRoc
					difference = seq - localSeq
				}
			}
		} else {
			// s.index <= seqNumMedian, so localRoc is 0 and no wrap is assumed.
			difference = seq - localSeq
		}
	}

	// guessRoc is 0 while localRoc is maxROC only when localRoc+1 wrapped.
	return guessRoc, difference, (guessRoc == 0 && localRoc == maxROC)
}
- func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference int32) {
- if !s.rolloverHasProcessed {
- s.index |= uint64(sequenceNumber)
- s.rolloverHasProcessed = true
- return
- }
- if difference > 0 {
- s.index += uint64(difference)
- }
- }
- func (c *Context) getSRTPSSRCState(ssrc uint32) *srtpSSRCState {
- s, ok := c.srtpSSRCStates[ssrc]
- if ok {
- return s
- }
- s = &srtpSSRCState{
- ssrc: ssrc,
- replayDetector: c.newSRTPReplayDetector(),
- }
- c.srtpSSRCStates[ssrc] = s
- return s
- }
- func (c *Context) getSRTCPSSRCState(ssrc uint32) *srtcpSSRCState {
- s, ok := c.srtcpSSRCStates[ssrc]
- if ok {
- return s
- }
- s = &srtcpSSRCState{
- ssrc: ssrc,
- replayDetector: c.newSRTCPReplayDetector(),
- }
- c.srtcpSSRCStates[ssrc] = s
- return s
- }
- // ROC returns SRTP rollover counter value of specified SSRC.
- func (c *Context) ROC(ssrc uint32) (uint32, bool) {
- s, ok := c.srtpSSRCStates[ssrc]
- if !ok {
- return 0, false
- }
- return uint32(s.index >> 16), true
- }
- // SetROC sets SRTP rollover counter value of specified SSRC.
- func (c *Context) SetROC(ssrc uint32, roc uint32) {
- s := c.getSRTPSSRCState(ssrc)
- s.index = uint64(roc) << 16
- s.rolloverHasProcessed = false
- }
- // Index returns SRTCP index value of specified SSRC.
- func (c *Context) Index(ssrc uint32) (uint32, bool) {
- s, ok := c.srtcpSSRCStates[ssrc]
- if !ok {
- return 0, false
- }
- return s.srtcpIndex, true
- }
- // SetIndex sets SRTCP index value of specified SSRC.
- func (c *Context) SetIndex(ssrc uint32, index uint32) {
- s := c.getSRTCPSSRCState(ssrc)
- s.srtcpIndex = index % (maxSRTCPIndex + 1)
- }
|