| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- package srtp
- import (
- "errors"
- "io"
- "net"
- "sync"
- "time"
- "github.com/pion/logging"
- "github.com/pion/transport/v2/packetio"
- )
// streamSession is what a session needs from the concrete session type that
// drives it: the read loop launched by start passes every packet read from the
// underlying connection to decrypt.
type streamSession interface {
	// decrypt handles one packet read from the underlying connection. A
	// non-nil error is logged by the read loop, which then keeps reading.
	decrypt([]byte) error

	// write — NOTE(review): not called in the code visible here; presumably
	// sends a packet on the underlying connection. Confirm against callers.
	write([]byte) (int, error)

	// Close has the io.Closer signature and is the exported shutdown hook of
	// the implementing type.
	Close() error
}
- type session struct {
- localContextMutex sync.Mutex
- localContext, remoteContext *Context
- localOptions, remoteOptions []ContextOption
- newStream chan readStream
- acceptStreamTimeout time.Time
- started chan interface{}
- closed chan interface{}
- readStreamsClosed bool
- readStreams map[uint32]readStream
- readStreamsLock sync.Mutex
- log logging.LeveledLogger
- bufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser
- nextConn net.Conn
- }
- // Config is used to configure a session.
- // You can provide either a KeyingMaterialExporter to export keys
- // or directly pass the keys themselves.
- // After a Config is passed to a session it must not be modified.
- type Config struct {
- Keys SessionKeys
- Profile ProtectionProfile
- BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser
- LoggerFactory logging.LoggerFactory
- AcceptStreamTimeout time.Time
- // List of local/remote context options.
- // ReplayProtection is enabled on remote context by default.
- // Default replay protection window size is 64.
- LocalOptions, RemoteOptions []ContextOption
- }
// SessionKeys bundles the master keys and master salts, for both the local and
// the remote side, that are required to set up an SRTP session.
type SessionKeys struct {
	// Master key and master salt of the local side.
	LocalMasterKey, LocalMasterSalt []byte
	// Master key and master salt of the remote side.
	RemoteMasterKey, RemoteMasterSalt []byte
}
- func (s *session) getOrCreateReadStream(ssrc uint32, child streamSession, proto func() readStream) (readStream, bool) {
- s.readStreamsLock.Lock()
- defer s.readStreamsLock.Unlock()
- if s.readStreamsClosed {
- return nil, false
- }
- r, ok := s.readStreams[ssrc]
- if ok {
- return r, false
- }
- // Create the readStream.
- r = proto()
- if err := r.init(child, ssrc); err != nil {
- return nil, false
- }
- s.readStreams[ssrc] = r
- return r, true
- }
- func (s *session) removeReadStream(ssrc uint32) {
- s.readStreamsLock.Lock()
- defer s.readStreamsLock.Unlock()
- if s.readStreamsClosed {
- return
- }
- delete(s.readStreams, ssrc)
- }
- func (s *session) close() error {
- if s.nextConn == nil {
- return nil
- } else if err := s.nextConn.Close(); err != nil {
- return err
- }
- <-s.closed
- return nil
- }
- func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error {
- var err error
- s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...)
- if err != nil {
- return err
- }
- s.remoteContext, err = CreateContext(remoteMasterKey, remoteMasterSalt, profile, s.remoteOptions...)
- if err != nil {
- return err
- }
- if err = s.nextConn.SetReadDeadline(s.acceptStreamTimeout); err != nil {
- return err
- }
- go func() {
- defer func() {
- close(s.newStream)
- s.readStreamsLock.Lock()
- s.readStreamsClosed = true
- s.readStreamsLock.Unlock()
- close(s.closed)
- }()
- b := make([]byte, 8192)
- for {
- var i int
- i, err = s.nextConn.Read(b)
- if err != nil {
- if !errors.Is(err, io.EOF) {
- s.log.Error(err.Error())
- }
- return
- }
- if err = child.decrypt(b[:i]); err != nil {
- s.log.Info(err.Error())
- }
- }
- }()
- close(s.started)
- return nil
- }
|