session.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package srtp
  4. import (
  5. "errors"
  6. "io"
  7. "net"
  8. "sync"
  9. "time"
  10. "github.com/pion/logging"
  11. "github.com/pion/transport/v2/packetio"
  12. )
  13. type streamSession interface {
  14. Close() error
  15. write([]byte) (int, error)
  16. decrypt([]byte) error
  17. }
  18. type session struct {
  19. localContextMutex sync.Mutex
  20. localContext, remoteContext *Context
  21. localOptions, remoteOptions []ContextOption
  22. newStream chan readStream
  23. acceptStreamTimeout time.Time
  24. started chan interface{}
  25. closed chan interface{}
  26. readStreamsClosed bool
  27. readStreams map[uint32]readStream
  28. readStreamsLock sync.Mutex
  29. log logging.LeveledLogger
  30. bufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser
  31. nextConn net.Conn
  32. }
  33. // Config is used to configure a session.
  34. // You can provide either a KeyingMaterialExporter to export keys
  35. // or directly pass the keys themselves.
  36. // After a Config is passed to a session it must not be modified.
  37. type Config struct {
  38. Keys SessionKeys
  39. Profile ProtectionProfile
  40. BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser
  41. LoggerFactory logging.LoggerFactory
  42. AcceptStreamTimeout time.Time
  43. // List of local/remote context options.
  44. // ReplayProtection is enabled on remote context by default.
  45. // Default replay protection window size is 64.
  46. LocalOptions, RemoteOptions []ContextOption
  47. }
  48. // SessionKeys bundles the keys required to setup an SRTP session
  49. type SessionKeys struct {
  50. LocalMasterKey []byte
  51. LocalMasterSalt []byte
  52. RemoteMasterKey []byte
  53. RemoteMasterSalt []byte
  54. }
  55. func (s *session) getOrCreateReadStream(ssrc uint32, child streamSession, proto func() readStream) (readStream, bool) {
  56. s.readStreamsLock.Lock()
  57. defer s.readStreamsLock.Unlock()
  58. if s.readStreamsClosed {
  59. return nil, false
  60. }
  61. r, ok := s.readStreams[ssrc]
  62. if ok {
  63. return r, false
  64. }
  65. // Create the readStream.
  66. r = proto()
  67. if err := r.init(child, ssrc); err != nil {
  68. return nil, false
  69. }
  70. s.readStreams[ssrc] = r
  71. return r, true
  72. }
  73. func (s *session) removeReadStream(ssrc uint32) {
  74. s.readStreamsLock.Lock()
  75. defer s.readStreamsLock.Unlock()
  76. if s.readStreamsClosed {
  77. return
  78. }
  79. delete(s.readStreams, ssrc)
  80. }
  81. func (s *session) close() error {
  82. if s.nextConn == nil {
  83. return nil
  84. } else if err := s.nextConn.Close(); err != nil {
  85. return err
  86. }
  87. <-s.closed
  88. return nil
  89. }
  90. func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error {
  91. var err error
  92. s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...)
  93. if err != nil {
  94. return err
  95. }
  96. s.remoteContext, err = CreateContext(remoteMasterKey, remoteMasterSalt, profile, s.remoteOptions...)
  97. if err != nil {
  98. return err
  99. }
  100. if err = s.nextConn.SetReadDeadline(s.acceptStreamTimeout); err != nil {
  101. return err
  102. }
  103. go func() {
  104. defer func() {
  105. close(s.newStream)
  106. s.readStreamsLock.Lock()
  107. s.readStreamsClosed = true
  108. s.readStreamsLock.Unlock()
  109. close(s.closed)
  110. }()
  111. b := make([]byte, 8192)
  112. for {
  113. var i int
  114. i, err = s.nextConn.Read(b)
  115. if err != nil {
  116. if !errors.Is(err, io.EOF) {
  117. s.log.Error(err.Error())
  118. }
  119. return
  120. }
  121. if err = child.decrypt(b[:i]); err != nil {
  122. s.log.Info(err.Error())
  123. }
  124. }
  125. }()
  126. close(s.started)
  127. return nil
  128. }