| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- //go:build !js
- // +build !js
- package webrtc
- import (
- "errors"
- "io"
- "math"
- "sync"
- "time"
- "github.com/pion/datachannel"
- "github.com/pion/logging"
- "github.com/pion/sctp"
- "github.com/pion/webrtc/v3/pkg/rtcerr"
- )
// sctpMaxChannels is the default upper bound on the number of data channels an
// SCTPTransport allows; it is also the bound used when allocating data channel
// stream IDs. The value is the largest number a uint16 can hold.
const sctpMaxChannels uint16 = 1<<16 - 1
// SCTPTransport provides details about the SCTP transport.
type SCTPTransport struct {
	// lock guards the mutable fields below; the methods of SCTPTransport
	// acquire it around most reads and writes of them.
	lock sync.RWMutex

	// dtlsTransport is the DTLS transport the SCTP association is created on.
	dtlsTransport *DTLSTransport

	// state represents the current state of the SCTP transport.
	state SCTPTransportState

	// isStarted is set by Start so that further calls return early.
	// SCTPTransportState has no value that distinguishes New from Connecting,
	// so a dedicated flag is needed.
	isStarted bool

	// maxMessageSize represents the maximum size of data that can be passed to
	// DataChannel's send() method. It is computed by updateMessageSize.
	maxMessageSize float64

	// maxChannels represents the maximum number of DataChannels that can be
	// used simultaneously. It is set by updateMaxChannels; MaxChannels falls
	// back to sctpMaxChannels when it is nil.
	maxChannels *uint16

	// NOTE(review): no OnStateChange handler is implemented for this type.

	// onErrorHandler is registered via OnError and called by onError.
	onErrorHandler func(error)

	// sctpAssociation is the association created by Start; nil before Start
	// succeeds and after Stop has closed it.
	sctpAssociation *sctp.Association

	// onDataChannelHandler is registered via OnDataChannel.
	onDataChannelHandler func(*DataChannel)

	// onDataChannelOpenedHandler is registered via OnDataChannelOpened.
	onDataChannelOpenedHandler func(*DataChannel)

	// dataChannels holds the data channels associated with this transport
	// (appended to by onDataChannel).
	dataChannels []*DataChannel

	// dataChannelsOpened counts channels opened by Start and by the accept loop.
	dataChannelsOpened uint32

	// dataChannelsRequested is not updated in this file; presumably it is
	// incremented elsewhere in the package — TODO confirm.
	dataChannelsRequested uint32

	// dataChannelsAccepted counts channels registered through onDataChannel.
	dataChannelsAccepted uint32

	api *API
	log logging.LeveledLogger
}
- // NewSCTPTransport creates a new SCTPTransport.
- // This constructor is part of the ORTC API. It is not
- // meant to be used together with the basic WebRTC API.
- func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
- res := &SCTPTransport{
- dtlsTransport: dtls,
- state: SCTPTransportStateConnecting,
- api: api,
- log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
- }
- res.updateMessageSize()
- res.updateMaxChannels()
- return res
- }
- // Transport returns the DTLSTransport instance the SCTPTransport is sending over.
- func (r *SCTPTransport) Transport() *DTLSTransport {
- r.lock.RLock()
- defer r.lock.RUnlock()
- return r.dtlsTransport
- }
- // GetCapabilities returns the SCTPCapabilities of the SCTPTransport.
- func (r *SCTPTransport) GetCapabilities() SCTPCapabilities {
- return SCTPCapabilities{
- MaxMessageSize: 0,
- }
- }
- // Start the SCTPTransport. Since both local and remote parties must mutually
- // create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish
- // a connection over SCTP.
- func (r *SCTPTransport) Start(SCTPCapabilities) error {
- if r.isStarted {
- return nil
- }
- r.isStarted = true
- dtlsTransport := r.Transport()
- if dtlsTransport == nil || dtlsTransport.conn == nil {
- return errSCTPTransportDTLS
- }
- sctpAssociation, err := sctp.Client(sctp.Config{
- NetConn: dtlsTransport.conn,
- MaxReceiveBufferSize: r.api.settingEngine.sctp.maxReceiveBufferSize,
- EnableZeroChecksum: r.api.settingEngine.sctp.enableZeroChecksum,
- LoggerFactory: r.api.settingEngine.LoggerFactory,
- })
- if err != nil {
- return err
- }
- r.lock.Lock()
- r.sctpAssociation = sctpAssociation
- r.state = SCTPTransportStateConnected
- dataChannels := append([]*DataChannel{}, r.dataChannels...)
- r.lock.Unlock()
- var openedDCCount uint32
- for _, d := range dataChannels {
- if d.ReadyState() == DataChannelStateConnecting {
- err := d.open(r)
- if err != nil {
- r.log.Warnf("failed to open data channel: %s", err)
- continue
- }
- openedDCCount++
- }
- }
- r.lock.Lock()
- r.dataChannelsOpened += openedDCCount
- r.lock.Unlock()
- go r.acceptDataChannels(sctpAssociation)
- return nil
- }
- // Stop stops the SCTPTransport
- func (r *SCTPTransport) Stop() error {
- r.lock.Lock()
- defer r.lock.Unlock()
- if r.sctpAssociation == nil {
- return nil
- }
- err := r.sctpAssociation.Close()
- if err != nil {
- return err
- }
- r.sctpAssociation = nil
- r.state = SCTPTransportStateClosed
- return nil
- }
- func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
- r.lock.RLock()
- dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels))
- for _, dc := range r.dataChannels {
- dc.mu.Lock()
- isNil := dc.dataChannel == nil
- dc.mu.Unlock()
- if isNil {
- continue
- }
- dataChannels = append(dataChannels, dc.dataChannel)
- }
- r.lock.RUnlock()
- ACCEPT:
- for {
- dc, err := datachannel.Accept(a, &datachannel.Config{
- LoggerFactory: r.api.settingEngine.LoggerFactory,
- }, dataChannels...)
- if err != nil {
- if !errors.Is(err, io.EOF) {
- r.log.Errorf("Failed to accept data channel: %v", err)
- r.onError(err)
- }
- return
- }
- for _, ch := range dataChannels {
- if ch.StreamIdentifier() == dc.StreamIdentifier() {
- continue ACCEPT
- }
- }
- var (
- maxRetransmits *uint16
- maxPacketLifeTime *uint16
- )
- val := uint16(dc.Config.ReliabilityParameter)
- ordered := true
- switch dc.Config.ChannelType {
- case datachannel.ChannelTypeReliable:
- ordered = true
- case datachannel.ChannelTypeReliableUnordered:
- ordered = false
- case datachannel.ChannelTypePartialReliableRexmit:
- ordered = true
- maxRetransmits = &val
- case datachannel.ChannelTypePartialReliableRexmitUnordered:
- ordered = false
- maxRetransmits = &val
- case datachannel.ChannelTypePartialReliableTimed:
- ordered = true
- maxPacketLifeTime = &val
- case datachannel.ChannelTypePartialReliableTimedUnordered:
- ordered = false
- maxPacketLifeTime = &val
- default:
- }
- sid := dc.StreamIdentifier()
- rtcDC, err := r.api.newDataChannel(&DataChannelParameters{
- ID: &sid,
- Label: dc.Config.Label,
- Protocol: dc.Config.Protocol,
- Negotiated: dc.Config.Negotiated,
- Ordered: ordered,
- MaxPacketLifeTime: maxPacketLifeTime,
- MaxRetransmits: maxRetransmits,
- }, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
- if err != nil {
- r.log.Errorf("Failed to accept data channel: %v", err)
- r.onError(err)
- return
- }
- <-r.onDataChannel(rtcDC)
- rtcDC.handleOpen(dc, true, dc.Config.Negotiated)
- r.lock.Lock()
- r.dataChannelsOpened++
- handler := r.onDataChannelOpenedHandler
- r.lock.Unlock()
- if handler != nil {
- handler(rtcDC)
- }
- }
- }
- // OnError sets an event handler which is invoked when
- // the SCTP connection error occurs.
- func (r *SCTPTransport) OnError(f func(err error)) {
- r.lock.Lock()
- defer r.lock.Unlock()
- r.onErrorHandler = f
- }
- func (r *SCTPTransport) onError(err error) {
- r.lock.RLock()
- handler := r.onErrorHandler
- r.lock.RUnlock()
- if handler != nil {
- go handler(err)
- }
- }
- // OnDataChannel sets an event handler which is invoked when a data
- // channel message arrives from a remote peer.
- func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {
- r.lock.Lock()
- defer r.lock.Unlock()
- r.onDataChannelHandler = f
- }
- // OnDataChannelOpened sets an event handler which is invoked when a data
- // channel is opened
- func (r *SCTPTransport) OnDataChannelOpened(f func(*DataChannel)) {
- r.lock.Lock()
- defer r.lock.Unlock()
- r.onDataChannelOpenedHandler = f
- }
- func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
- r.lock.Lock()
- r.dataChannels = append(r.dataChannels, dc)
- r.dataChannelsAccepted++
- handler := r.onDataChannelHandler
- r.lock.Unlock()
- done = make(chan struct{})
- if handler == nil || dc == nil {
- close(done)
- return
- }
- // Run this synchronously to allow setup done in onDataChannelFn()
- // to complete before datachannel event handlers might be called.
- go func() {
- handler(dc)
- close(done)
- }()
- return
- }
- func (r *SCTPTransport) updateMessageSize() {
- r.lock.Lock()
- defer r.lock.Unlock()
- var remoteMaxMessageSize float64 = 65536 // pion/webrtc#758
- var canSendSize float64 = 65536 // pion/webrtc#758
- r.maxMessageSize = r.calcMessageSize(remoteMaxMessageSize, canSendSize)
- }
- func (r *SCTPTransport) calcMessageSize(remoteMaxMessageSize, canSendSize float64) float64 {
- switch {
- case remoteMaxMessageSize == 0 &&
- canSendSize == 0:
- return math.Inf(1)
- case remoteMaxMessageSize == 0:
- return canSendSize
- case canSendSize == 0:
- return remoteMaxMessageSize
- case canSendSize > remoteMaxMessageSize:
- return remoteMaxMessageSize
- default:
- return canSendSize
- }
- }
- func (r *SCTPTransport) updateMaxChannels() {
- val := sctpMaxChannels
- r.maxChannels = &val
- }
- // MaxChannels is the maximum number of RTCDataChannels that can be open simultaneously.
- func (r *SCTPTransport) MaxChannels() uint16 {
- r.lock.Lock()
- defer r.lock.Unlock()
- if r.maxChannels == nil {
- return sctpMaxChannels
- }
- return *r.maxChannels
- }
- // State returns the current state of the SCTPTransport
- func (r *SCTPTransport) State() SCTPTransportState {
- r.lock.RLock()
- defer r.lock.RUnlock()
- return r.state
- }
- func (r *SCTPTransport) collectStats(collector *statsReportCollector) {
- collector.Collecting()
- stats := SCTPTransportStats{
- Timestamp: statsTimestampFrom(time.Now()),
- Type: StatsTypeSCTPTransport,
- ID: "sctpTransport",
- }
- association := r.association()
- if association != nil {
- stats.BytesSent = association.BytesSent()
- stats.BytesReceived = association.BytesReceived()
- stats.SmoothedRoundTripTime = association.SRTT() * 0.001 // convert milliseconds to seconds
- stats.CongestionWindow = association.CWND()
- stats.ReceiverWindow = association.RWND()
- stats.MTU = association.MTU()
- }
- collector.Collect(stats.ID, stats)
- }
- func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **uint16) error {
- var id uint16
- if dtlsRole != DTLSRoleClient {
- id++
- }
- max := r.MaxChannels()
- r.lock.Lock()
- defer r.lock.Unlock()
- // Create map of ids so we can compare without double-looping each time.
- idsMap := make(map[uint16]struct{}, len(r.dataChannels))
- for _, dc := range r.dataChannels {
- if dc.ID() == nil {
- continue
- }
- idsMap[*dc.ID()] = struct{}{}
- }
- for ; id < max-1; id += 2 {
- if _, ok := idsMap[id]; ok {
- continue
- }
- *idOut = &id
- return nil
- }
- return &rtcerr.OperationError{Err: ErrMaxDataChannelID}
- }
- func (r *SCTPTransport) association() *sctp.Association {
- if r == nil {
- return nil
- }
- r.lock.RLock()
- association := r.sctpAssociation
- r.lock.RUnlock()
- return association
- }
|