dtlstransport.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. //go:build !js
  4. // +build !js
  5. package webrtc
  6. import (
  7. "crypto/ecdsa"
  8. "crypto/elliptic"
  9. "crypto/rand"
  10. "crypto/tls"
  11. "crypto/x509"
  12. "errors"
  13. "fmt"
  14. "strings"
  15. "sync"
  16. "sync/atomic"
  17. "time"
  18. "github.com/pion/dtls/v2"
  19. "github.com/pion/dtls/v2/pkg/crypto/fingerprint"
  20. "github.com/pion/interceptor"
  21. "github.com/pion/logging"
  22. "github.com/pion/rtcp"
  23. "github.com/pion/srtp/v2"
  24. "github.com/pion/webrtc/v3/internal/mux"
  25. "github.com/pion/webrtc/v3/internal/util"
  26. "github.com/pion/webrtc/v3/pkg/rtcerr"
  27. )
  28. // DTLSTransport allows an application access to information about the DTLS
  29. // transport over which RTP and RTCP packets are sent and received by
  30. // RTPSender and RTPReceiver, as well other data such as SCTP packets sent
  31. // and received by data channels.
  32. type DTLSTransport struct {
  33. lock sync.RWMutex
  34. iceTransport *ICETransport
  35. certificates []Certificate
  36. remoteParameters DTLSParameters
  37. remoteCertificate []byte
  38. state DTLSTransportState
  39. srtpProtectionProfile srtp.ProtectionProfile
  40. onStateChangeHandler func(DTLSTransportState)
  41. conn *dtls.Conn
  42. srtpSession, srtcpSession atomic.Value
  43. srtpEndpoint, srtcpEndpoint *mux.Endpoint
  44. simulcastStreams []*srtp.ReadStreamSRTP
  45. srtpReady chan struct{}
  46. dtlsMatcher mux.MatchFunc
  47. api *API
  48. log logging.LeveledLogger
  49. }
  50. // NewDTLSTransport creates a new DTLSTransport.
  51. // This constructor is part of the ORTC API. It is not
  52. // meant to be used together with the basic WebRTC API.
  53. func (api *API) NewDTLSTransport(transport *ICETransport, certificates []Certificate) (*DTLSTransport, error) {
  54. t := &DTLSTransport{
  55. iceTransport: transport,
  56. api: api,
  57. state: DTLSTransportStateNew,
  58. dtlsMatcher: mux.MatchDTLS,
  59. srtpReady: make(chan struct{}),
  60. log: api.settingEngine.LoggerFactory.NewLogger("DTLSTransport"),
  61. }
  62. if len(certificates) > 0 {
  63. now := time.Now()
  64. for _, x509Cert := range certificates {
  65. if !x509Cert.Expires().IsZero() && now.After(x509Cert.Expires()) {
  66. return nil, &rtcerr.InvalidAccessError{Err: ErrCertificateExpired}
  67. }
  68. t.certificates = append(t.certificates, x509Cert)
  69. }
  70. } else {
  71. sk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  72. if err != nil {
  73. return nil, &rtcerr.UnknownError{Err: err}
  74. }
  75. certificate, err := GenerateCertificate(sk)
  76. if err != nil {
  77. return nil, err
  78. }
  79. t.certificates = []Certificate{*certificate}
  80. }
  81. return t, nil
  82. }
  83. // ICETransport returns the currently-configured *ICETransport or nil
  84. // if one has not been configured
  85. func (t *DTLSTransport) ICETransport() *ICETransport {
  86. t.lock.RLock()
  87. defer t.lock.RUnlock()
  88. return t.iceTransport
  89. }
  90. // onStateChange requires the caller holds the lock
  91. func (t *DTLSTransport) onStateChange(state DTLSTransportState) {
  92. t.state = state
  93. handler := t.onStateChangeHandler
  94. if handler != nil {
  95. handler(state)
  96. }
  97. }
  98. // OnStateChange sets a handler that is fired when the DTLS
  99. // connection state changes.
  100. func (t *DTLSTransport) OnStateChange(f func(DTLSTransportState)) {
  101. t.lock.Lock()
  102. defer t.lock.Unlock()
  103. t.onStateChangeHandler = f
  104. }
  105. // State returns the current dtls transport state.
  106. func (t *DTLSTransport) State() DTLSTransportState {
  107. t.lock.RLock()
  108. defer t.lock.RUnlock()
  109. return t.state
  110. }
  111. // WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the
  112. // packet is discarded.
  113. func (t *DTLSTransport) WriteRTCP(pkts []rtcp.Packet) (int, error) {
  114. raw, err := rtcp.Marshal(pkts)
  115. if err != nil {
  116. return 0, err
  117. }
  118. srtcpSession, err := t.getSRTCPSession()
  119. if err != nil {
  120. return 0, err
  121. }
  122. writeStream, err := srtcpSession.OpenWriteStream()
  123. if err != nil {
  124. // nolint
  125. return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err)
  126. }
  127. return writeStream.Write(raw)
  128. }
  129. // GetLocalParameters returns the DTLS parameters of the local DTLSTransport upon construction.
  130. func (t *DTLSTransport) GetLocalParameters() (DTLSParameters, error) {
  131. fingerprints := []DTLSFingerprint{}
  132. for _, c := range t.certificates {
  133. prints, err := c.GetFingerprints()
  134. if err != nil {
  135. return DTLSParameters{}, err
  136. }
  137. fingerprints = append(fingerprints, prints...)
  138. }
  139. return DTLSParameters{
  140. Role: DTLSRoleAuto, // always returns the default role
  141. Fingerprints: fingerprints,
  142. }, nil
  143. }
  144. // GetRemoteCertificate returns the certificate chain in use by the remote side
  145. // returns an empty list prior to selection of the remote certificate
  146. func (t *DTLSTransport) GetRemoteCertificate() []byte {
  147. t.lock.RLock()
  148. defer t.lock.RUnlock()
  149. return t.remoteCertificate
  150. }
  151. func (t *DTLSTransport) startSRTP() error {
  152. srtpConfig := &srtp.Config{
  153. Profile: t.srtpProtectionProfile,
  154. BufferFactory: t.api.settingEngine.BufferFactory,
  155. LoggerFactory: t.api.settingEngine.LoggerFactory,
  156. }
  157. if t.api.settingEngine.replayProtection.SRTP != nil {
  158. srtpConfig.RemoteOptions = append(
  159. srtpConfig.RemoteOptions,
  160. srtp.SRTPReplayProtection(*t.api.settingEngine.replayProtection.SRTP),
  161. )
  162. }
  163. if t.api.settingEngine.disableSRTPReplayProtection {
  164. srtpConfig.RemoteOptions = append(
  165. srtpConfig.RemoteOptions,
  166. srtp.SRTPNoReplayProtection(),
  167. )
  168. }
  169. if t.api.settingEngine.replayProtection.SRTCP != nil {
  170. srtpConfig.RemoteOptions = append(
  171. srtpConfig.RemoteOptions,
  172. srtp.SRTCPReplayProtection(*t.api.settingEngine.replayProtection.SRTCP),
  173. )
  174. }
  175. if t.api.settingEngine.disableSRTCPReplayProtection {
  176. srtpConfig.RemoteOptions = append(
  177. srtpConfig.RemoteOptions,
  178. srtp.SRTCPNoReplayProtection(),
  179. )
  180. }
  181. connState := t.conn.ConnectionState()
  182. err := srtpConfig.ExtractSessionKeysFromDTLS(&connState, t.role() == DTLSRoleClient)
  183. if err != nil {
  184. // nolint
  185. return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
  186. }
  187. srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
  188. if err != nil {
  189. // nolint
  190. return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
  191. }
  192. srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig)
  193. if err != nil {
  194. // nolint
  195. return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err)
  196. }
  197. t.srtpSession.Store(srtpSession)
  198. t.srtcpSession.Store(srtcpSession)
  199. close(t.srtpReady)
  200. return nil
  201. }
  202. func (t *DTLSTransport) getSRTPSession() (*srtp.SessionSRTP, error) {
  203. if value, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok {
  204. return value, nil
  205. }
  206. return nil, errDtlsTransportNotStarted
  207. }
  208. func (t *DTLSTransport) getSRTCPSession() (*srtp.SessionSRTCP, error) {
  209. if value, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok {
  210. return value, nil
  211. }
  212. return nil, errDtlsTransportNotStarted
  213. }
  214. func (t *DTLSTransport) role() DTLSRole {
  215. // If remote has an explicit role use the inverse
  216. switch t.remoteParameters.Role {
  217. case DTLSRoleClient:
  218. return DTLSRoleServer
  219. case DTLSRoleServer:
  220. return DTLSRoleClient
  221. default:
  222. }
  223. // If SettingEngine has an explicit role
  224. switch t.api.settingEngine.answeringDTLSRole {
  225. case DTLSRoleServer:
  226. return DTLSRoleServer
  227. case DTLSRoleClient:
  228. return DTLSRoleClient
  229. default:
  230. }
  231. // Remote was auto and no explicit role was configured via SettingEngine
  232. if t.iceTransport.Role() == ICERoleControlling {
  233. return DTLSRoleServer
  234. }
  235. return defaultDtlsRoleAnswer
  236. }
  237. // Start DTLS transport negotiation with the parameters of the remote DTLS transport
  238. func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
  239. // Take lock and prepare connection, we must not hold the lock
  240. // when connecting
  241. prepareTransport := func() (DTLSRole, *dtls.Config, error) {
  242. t.lock.Lock()
  243. defer t.lock.Unlock()
  244. if err := t.ensureICEConn(); err != nil {
  245. return DTLSRole(0), nil, err
  246. }
  247. if t.state != DTLSTransportStateNew {
  248. return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)}
  249. }
  250. t.srtpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTP)
  251. t.srtcpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTCP)
  252. t.remoteParameters = remoteParameters
  253. cert := t.certificates[0]
  254. t.onStateChange(DTLSTransportStateConnecting)
  255. return t.role(), &dtls.Config{
  256. Certificates: []tls.Certificate{
  257. {
  258. Certificate: [][]byte{cert.x509Cert.Raw},
  259. PrivateKey: cert.privateKey,
  260. },
  261. },
  262. SRTPProtectionProfiles: func() []dtls.SRTPProtectionProfile {
  263. if len(t.api.settingEngine.srtpProtectionProfiles) > 0 {
  264. return t.api.settingEngine.srtpProtectionProfiles
  265. }
  266. return defaultSrtpProtectionProfiles()
  267. }(),
  268. ClientAuth: dtls.RequireAnyClientCert,
  269. LoggerFactory: t.api.settingEngine.LoggerFactory,
  270. InsecureSkipVerify: !t.api.settingEngine.dtls.disableInsecureSkipVerify,
  271. }, nil
  272. }
  273. var dtlsConn *dtls.Conn
  274. dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
  275. role, dtlsConfig, err := prepareTransport()
  276. if err != nil {
  277. return err
  278. }
  279. if t.api.settingEngine.replayProtection.DTLS != nil {
  280. dtlsConfig.ReplayProtectionWindow = int(*t.api.settingEngine.replayProtection.DTLS)
  281. }
  282. if t.api.settingEngine.dtls.clientAuth != nil {
  283. dtlsConfig.ClientAuth = *t.api.settingEngine.dtls.clientAuth
  284. }
  285. dtlsConfig.FlightInterval = t.api.settingEngine.dtls.retransmissionInterval
  286. dtlsConfig.InsecureSkipVerifyHello = t.api.settingEngine.dtls.insecureSkipHelloVerify
  287. dtlsConfig.EllipticCurves = t.api.settingEngine.dtls.ellipticCurves
  288. dtlsConfig.ConnectContextMaker = t.api.settingEngine.dtls.connectContextMaker
  289. dtlsConfig.ExtendedMasterSecret = t.api.settingEngine.dtls.extendedMasterSecret
  290. dtlsConfig.ClientCAs = t.api.settingEngine.dtls.clientCAs
  291. dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs
  292. dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter
  293. // Connect as DTLS Client/Server, function is blocking and we
  294. // must not hold the DTLSTransport lock
  295. if role == DTLSRoleClient {
  296. dtlsConn, err = dtls.Client(dtlsEndpoint, dtlsConfig)
  297. } else {
  298. dtlsConn, err = dtls.Server(dtlsEndpoint, dtlsConfig)
  299. }
  300. // Re-take the lock, nothing beyond here is blocking
  301. t.lock.Lock()
  302. defer t.lock.Unlock()
  303. if err != nil {
  304. t.onStateChange(DTLSTransportStateFailed)
  305. return err
  306. }
  307. srtpProfile, ok := dtlsConn.SelectedSRTPProtectionProfile()
  308. if !ok {
  309. t.onStateChange(DTLSTransportStateFailed)
  310. return ErrNoSRTPProtectionProfile
  311. }
  312. switch srtpProfile {
  313. case dtls.SRTP_AEAD_AES_128_GCM:
  314. t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes128Gcm
  315. case dtls.SRTP_AEAD_AES_256_GCM:
  316. t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes256Gcm
  317. case dtls.SRTP_AES128_CM_HMAC_SHA1_80:
  318. t.srtpProtectionProfile = srtp.ProtectionProfileAes128CmHmacSha1_80
  319. default:
  320. t.onStateChange(DTLSTransportStateFailed)
  321. return ErrNoSRTPProtectionProfile
  322. }
  323. // Check the fingerprint if a certificate was exchanged
  324. remoteCerts := dtlsConn.ConnectionState().PeerCertificates
  325. if len(remoteCerts) == 0 {
  326. t.onStateChange(DTLSTransportStateFailed)
  327. return errNoRemoteCertificate
  328. }
  329. t.remoteCertificate = remoteCerts[0]
  330. if !t.api.settingEngine.disableCertificateFingerprintVerification {
  331. parsedRemoteCert, err := x509.ParseCertificate(t.remoteCertificate)
  332. if err != nil {
  333. if closeErr := dtlsConn.Close(); closeErr != nil {
  334. t.log.Error(err.Error())
  335. }
  336. t.onStateChange(DTLSTransportStateFailed)
  337. return err
  338. }
  339. if err = t.validateFingerPrint(parsedRemoteCert); err != nil {
  340. if closeErr := dtlsConn.Close(); closeErr != nil {
  341. t.log.Error(err.Error())
  342. }
  343. t.onStateChange(DTLSTransportStateFailed)
  344. return err
  345. }
  346. }
  347. t.conn = dtlsConn
  348. t.onStateChange(DTLSTransportStateConnected)
  349. return t.startSRTP()
  350. }
  351. // Stop stops and closes the DTLSTransport object.
  352. func (t *DTLSTransport) Stop() error {
  353. t.lock.Lock()
  354. defer t.lock.Unlock()
  355. // Try closing everything and collect the errors
  356. var closeErrs []error
  357. if srtpSession, err := t.getSRTPSession(); err == nil && srtpSession != nil {
  358. closeErrs = append(closeErrs, srtpSession.Close())
  359. }
  360. if srtcpSession, err := t.getSRTCPSession(); err == nil && srtcpSession != nil {
  361. closeErrs = append(closeErrs, srtcpSession.Close())
  362. }
  363. for i := range t.simulcastStreams {
  364. closeErrs = append(closeErrs, t.simulcastStreams[i].Close())
  365. }
  366. if t.conn != nil {
  367. // dtls connection may be closed on sctp close.
  368. if err := t.conn.Close(); err != nil && !errors.Is(err, dtls.ErrConnClosed) {
  369. closeErrs = append(closeErrs, err)
  370. }
  371. }
  372. t.onStateChange(DTLSTransportStateClosed)
  373. return util.FlattenErrs(closeErrs)
  374. }
  375. func (t *DTLSTransport) validateFingerPrint(remoteCert *x509.Certificate) error {
  376. for _, fp := range t.remoteParameters.Fingerprints {
  377. hashAlgo, err := fingerprint.HashFromString(fp.Algorithm)
  378. if err != nil {
  379. return err
  380. }
  381. remoteValue, err := fingerprint.Fingerprint(remoteCert, hashAlgo)
  382. if err != nil {
  383. return err
  384. }
  385. if strings.EqualFold(remoteValue, fp.Value) {
  386. return nil
  387. }
  388. }
  389. return errNoMatchingCertificateFingerprint
  390. }
  391. func (t *DTLSTransport) ensureICEConn() error {
  392. if t.iceTransport == nil {
  393. return errICEConnectionNotStarted
  394. }
  395. return nil
  396. }
  397. func (t *DTLSTransport) storeSimulcastStream(s *srtp.ReadStreamSRTP) {
  398. t.lock.Lock()
  399. defer t.lock.Unlock()
  400. t.simulcastStreams = append(t.simulcastStreams, s)
  401. }
  402. func (t *DTLSTransport) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) {
  403. srtpSession, err := t.getSRTPSession()
  404. if err != nil {
  405. return nil, nil, nil, nil, err
  406. }
  407. rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc))
  408. if err != nil {
  409. return nil, nil, nil, nil, err
  410. }
  411. rtpInterceptor := t.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
  412. n, err = rtpReadStream.Read(in)
  413. return n, a, err
  414. }))
  415. srtcpSession, err := t.getSRTCPSession()
  416. if err != nil {
  417. return nil, nil, nil, nil, err
  418. }
  419. rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc))
  420. if err != nil {
  421. return nil, nil, nil, nil, err
  422. }
  423. rtcpInterceptor := t.api.interceptor.BindRTCPReader(interceptor.RTCPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
  424. n, err = rtcpReadStream.Read(in)
  425. return n, a, err
  426. }))
  427. return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil
  428. }