| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- package rtp
- import (
- "encoding/binary"
- "errors"
- "fmt"
- "strings"
- "github.com/pion/rtp/codecs/av1/obu"
- )
// Sentinel errors reported by VLA marshaling and unmarshaling. Some of them are
// returned wrapped with extra context, so match them with errors.Is.
var (
	// ErrVLATooShort is returned when the payload ends before all the fields
	// announced by its header have been read.
	ErrVLATooShort = errors.New("VLA payload too short")
	// ErrVLAInvalidStreamCount is returned when the RTP stream count is invalid.
	ErrVLAInvalidStreamCount = errors.New("invalid RTP stream count in VLA")
	// ErrVLAInvalidStreamID is returned when an RTP stream ID is invalid.
	ErrVLAInvalidStreamID = errors.New("invalid RTP stream ID in VLA")
	// ErrVLAInvalidSpatialID is returned when a spatial ID is invalid.
	ErrVLAInvalidSpatialID = errors.New("invalid spatial ID in VLA")
	// ErrVLADuplicateSpatialID is returned when the same spatial ID is listed
	// more than once for the same RTP stream.
	ErrVLADuplicateSpatialID = errors.New("duplicate spatial ID in VLA")
	// ErrVLAInvalidTemporalLayer is returned when a spatial layer has an invalid
	// number of temporal layers (target bitrates).
	ErrVLAInvalidTemporalLayer = errors.New("invalid temporal layer in VLA")
)
// SpatialLayer describes one active spatial layer announced in a VLA extension.
type SpatialLayer struct {
	// RTPStreamID is the 0-origin index of the RTP stream carrying this layer;
	// Marshal requires it to be less than the owning VLA's RTPStreamCount.
	RTPStreamID int

	// SpatialID identifies the layer inside its RTP stream (0..3).
	SpatialID int

	// TargetBitrates holds one target bitrate (kbps) per temporal layer;
	// Marshal accepts 1..4 entries.
	TargetBitrates []int

	// Width, Height and Framerate are meaningful only when the owning VLA has
	// HasResolutionAndFramerate set.
	Width     int
	Height    int
	Framerate int
}

// VLA is a Video Layer Allocation (VLA) RTP header extension.
// See https://webrtc.googlesource.com/src/+/refs/heads/main/docs/native-code/rtp-hdrext/video-layers-allocation00
type VLA struct {
	// RTPStreamID is the 0-origin RTP stream ID (RID) this allocation is sent
	// on (0..RTPStreamCount-1).
	RTPStreamID int

	// RTPStreamCount is the number of RTP streams (1..4).
	RTPStreamCount int

	// ActiveSpatialLayer lists the active spatial layers of all RTP streams.
	ActiveSpatialLayer []SpatialLayer

	// HasResolutionAndFramerate reports whether the Width, Height and
	// Framerate of each SpatialLayer are present.
	HasResolutionAndFramerate bool
}
- type vlaMarshalingContext struct {
- slMBs [4]uint8
- sls [4][4]*SpatialLayer
- commonSLBM uint8
- encodedTargetBitrates [][]byte
- requiredLen int
- }
- func (v VLA) preprocessForMashaling(ctx *vlaMarshalingContext) error {
- for i := 0; i < len(v.ActiveSpatialLayer); i++ {
- sl := v.ActiveSpatialLayer[i]
- if sl.RTPStreamID < 0 || sl.RTPStreamID >= v.RTPStreamCount {
- return fmt.Errorf("invalid RTP streamID %d:%w", sl.RTPStreamID, ErrVLAInvalidStreamID)
- }
- if sl.SpatialID < 0 || sl.SpatialID >= 4 {
- return fmt.Errorf("invalid spatial ID %d: %w", sl.SpatialID, ErrVLAInvalidSpatialID)
- }
- if len(sl.TargetBitrates) == 0 || len(sl.TargetBitrates) > 4 {
- return fmt.Errorf("invalid temporal layer count %d: %w", len(sl.TargetBitrates), ErrVLAInvalidTemporalLayer)
- }
- ctx.slMBs[sl.RTPStreamID] |= 1 << sl.SpatialID
- if ctx.sls[sl.RTPStreamID][sl.SpatialID] != nil {
- return fmt.Errorf("duplicate spatial layer: %w", ErrVLADuplicateSpatialID)
- }
- ctx.sls[sl.RTPStreamID][sl.SpatialID] = &sl
- }
- return nil
- }
- func (v VLA) encodeTargetBitrates(ctx *vlaMarshalingContext) {
- for rtpStreamID := 0; rtpStreamID < v.RTPStreamCount; rtpStreamID++ {
- for spatialID := 0; spatialID < 4; spatialID++ {
- if sl := ctx.sls[rtpStreamID][spatialID]; sl != nil {
- for _, kbps := range sl.TargetBitrates {
- leb128 := obu.WriteToLeb128(uint(kbps))
- ctx.encodedTargetBitrates = append(ctx.encodedTargetBitrates, leb128)
- ctx.requiredLen += len(leb128)
- }
- }
- }
- }
- }
- func (v VLA) analyzeVLAForMarshaling() (*vlaMarshalingContext, error) {
- // Validate RTPStreamCount
- if v.RTPStreamCount <= 0 || v.RTPStreamCount > 4 {
- return nil, ErrVLAInvalidStreamCount
- }
- // Validate RTPStreamID
- if v.RTPStreamID < 0 || v.RTPStreamID >= v.RTPStreamCount {
- return nil, ErrVLAInvalidStreamID
- }
- ctx := &vlaMarshalingContext{}
- err := v.preprocessForMashaling(ctx)
- if err != nil {
- return nil, err
- }
- ctx.commonSLBM = commonSLBMValues(ctx.slMBs[:])
- // RID, NS, sl_bm fields
- if ctx.commonSLBM != 0 {
- ctx.requiredLen = 1
- } else {
- ctx.requiredLen = 3
- }
- // #tl fields
- ctx.requiredLen += (len(v.ActiveSpatialLayer)-1)/4 + 1
- v.encodeTargetBitrates(ctx)
- if v.HasResolutionAndFramerate {
- ctx.requiredLen += len(v.ActiveSpatialLayer) * 5
- }
- return ctx, nil
- }
- // Marshal encodes VLA into a byte slice.
- func (v VLA) Marshal() ([]byte, error) {
- ctx, err := v.analyzeVLAForMarshaling()
- if err != nil {
- return nil, err
- }
- payload := make([]byte, ctx.requiredLen)
- offset := 0
- // RID, NS, sl_bm fields
- payload[offset] = byte(v.RTPStreamID<<6) | byte(v.RTPStreamCount-1)<<4 | ctx.commonSLBM
- if ctx.commonSLBM == 0 {
- offset++
- for streamID := 0; streamID < v.RTPStreamCount; streamID++ {
- if streamID%2 == 0 {
- payload[offset+streamID/2] |= ctx.slMBs[streamID] << 4
- } else {
- payload[offset+streamID/2] |= ctx.slMBs[streamID]
- }
- }
- offset += (v.RTPStreamCount - 1) / 2
- }
- // #tl fields
- offset++
- var temporalLayerIndex int
- for rtpStreamID := 0; rtpStreamID < v.RTPStreamCount; rtpStreamID++ {
- for spatialID := 0; spatialID < 4; spatialID++ {
- if sl := ctx.sls[rtpStreamID][spatialID]; sl != nil {
- if temporalLayerIndex >= 4 {
- temporalLayerIndex = 0
- offset++
- }
- payload[offset] |= byte(len(sl.TargetBitrates)-1) << (2 * (3 - temporalLayerIndex))
- temporalLayerIndex++
- }
- }
- }
- // Target bitrate fields
- offset++
- for _, encodedKbps := range ctx.encodedTargetBitrates {
- encodedSize := len(encodedKbps)
- copy(payload[offset:], encodedKbps)
- offset += encodedSize
- }
- // Resolution & framerate fields
- if v.HasResolutionAndFramerate {
- for _, sl := range v.ActiveSpatialLayer {
- binary.BigEndian.PutUint16(payload[offset+0:], uint16(sl.Width-1))
- binary.BigEndian.PutUint16(payload[offset+2:], uint16(sl.Height-1))
- payload[offset+4] = byte(sl.Framerate)
- offset += 5
- }
- }
- return payload, nil
- }
// commonSLBMValues returns the single nonzero bitmask shared by every nonzero
// entry of slMBs. Zero entries are ignored. It returns 0 when two nonzero
// entries differ or when there is no nonzero entry at all.
func commonSLBMValues(slMBs []uint8) uint8 {
	var shared uint8
	for _, bm := range slMBs {
		switch {
		case bm == 0:
			// Zero masks do not take part in the comparison.
		case shared == 0:
			shared = bm
		case bm != shared:
			return 0
		}
	}
	return shared
}
// vlaUnmarshalingContext tracks the decoder state while unmarshaling a VLA
// payload.
type vlaUnmarshalingContext struct {
	payload   []byte   // whole VLA payload being decoded
	offset    int      // index of the next unread byte in payload
	slBMField uint8    // sl_bm from the header byte; 0 means per-stream masks follow
	slBMs     [4]uint8 // active spatial layer bitmask of each RTP stream
}

// checkRemainingLen reports whether at least requiredLen unread bytes remain
// in ctx.payload starting at ctx.offset.
func (ctx *vlaUnmarshalingContext) checkRemainingLen(requiredLen int) bool {
	remaining := len(ctx.payload) - ctx.offset
	return requiredLen <= remaining
}
- func (v *VLA) unmarshalSpatialLayers(ctx *vlaUnmarshalingContext) error {
- if !ctx.checkRemainingLen(1) {
- return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
- }
- v.RTPStreamID = int(ctx.payload[ctx.offset] >> 6 & 0b11)
- v.RTPStreamCount = int(ctx.payload[ctx.offset]>>4&0b11) + 1
- // sl_bm fields
- ctx.slBMField = ctx.payload[ctx.offset] & 0b1111
- ctx.offset++
- if ctx.slBMField != 0 {
- for streamID := 0; streamID < v.RTPStreamCount; streamID++ {
- ctx.slBMs[streamID] = ctx.slBMField
- }
- } else {
- if !ctx.checkRemainingLen((v.RTPStreamCount-1)/2 + 1) {
- return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
- }
- // slX_bm fields
- for streamID := 0; streamID < v.RTPStreamCount; streamID++ {
- var bm uint8
- if streamID%2 == 0 {
- bm = ctx.payload[ctx.offset+streamID/2] >> 4 & 0b1111
- } else {
- bm = ctx.payload[ctx.offset+streamID/2] & 0b1111
- }
- ctx.slBMs[streamID] = bm
- }
- ctx.offset += 1 + (v.RTPStreamCount-1)/2
- }
- return nil
- }
- func (v *VLA) unmarshalTemporalLayers(ctx *vlaUnmarshalingContext) error {
- if !ctx.checkRemainingLen(1) {
- return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
- }
- var temporalLayerIndex int
- for streamID := 0; streamID < v.RTPStreamCount; streamID++ {
- for spatialID := 0; spatialID < 4; spatialID++ {
- if ctx.slBMs[streamID]&(1<<spatialID) == 0 {
- continue
- }
- if temporalLayerIndex >= 4 {
- temporalLayerIndex = 0
- ctx.offset++
- if !ctx.checkRemainingLen(1) {
- return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
- }
- }
- tlCount := int(ctx.payload[ctx.offset]>>(2*(3-temporalLayerIndex))&0b11) + 1
- temporalLayerIndex++
- sl := SpatialLayer{
- RTPStreamID: streamID,
- SpatialID: spatialID,
- TargetBitrates: make([]int, tlCount),
- }
- v.ActiveSpatialLayer = append(v.ActiveSpatialLayer, sl)
- }
- }
- ctx.offset++
- // target bitrates
- for i, sl := range v.ActiveSpatialLayer {
- for j := range sl.TargetBitrates {
- kbps, n, err := obu.ReadLeb128(ctx.payload[ctx.offset:])
- if err != nil {
- return err
- }
- if !ctx.checkRemainingLen(int(n)) {
- return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
- }
- v.ActiveSpatialLayer[i].TargetBitrates[j] = int(kbps)
- ctx.offset += int(n)
- }
- }
- return nil
- }
- func (v *VLA) unmarshalResolutionAndFramerate(ctx *vlaUnmarshalingContext) error {
- if !ctx.checkRemainingLen(len(v.ActiveSpatialLayer) * 5) {
- return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
- }
- v.HasResolutionAndFramerate = true
- for i := range v.ActiveSpatialLayer {
- v.ActiveSpatialLayer[i].Width = int(binary.BigEndian.Uint16(ctx.payload[ctx.offset+0:])) + 1
- v.ActiveSpatialLayer[i].Height = int(binary.BigEndian.Uint16(ctx.payload[ctx.offset+2:])) + 1
- v.ActiveSpatialLayer[i].Framerate = int(ctx.payload[ctx.offset+4])
- ctx.offset += 5
- }
- return nil
- }
- // Unmarshal decodes VLA from a byte slice.
- func (v *VLA) Unmarshal(payload []byte) (int, error) {
- ctx := &vlaUnmarshalingContext{
- payload: payload,
- }
- err := v.unmarshalSpatialLayers(ctx)
- if err != nil {
- return ctx.offset, err
- }
- // #tl fields (build the list ActiveSpatialLayer at the same time)
- err = v.unmarshalTemporalLayers(ctx)
- if err != nil {
- return ctx.offset, err
- }
- if len(ctx.payload) == ctx.offset {
- return ctx.offset, nil
- }
- // resolution & framerate (optional)
- err = v.unmarshalResolutionAndFramerate(ctx)
- if err != nil {
- return ctx.offset, err
- }
- return ctx.offset, nil
- }
- // String makes VLA printable.
- func (v VLA) String() string {
- out := fmt.Sprintf("RID:%d,RTPStreamCount:%d", v.RTPStreamID, v.RTPStreamCount)
- var slOut []string
- for _, sl := range v.ActiveSpatialLayer {
- out2 := fmt.Sprintf("RTPStreamID:%d", sl.RTPStreamID)
- out2 += fmt.Sprintf(",TargetBitrates:%v", sl.TargetBitrates)
- if v.HasResolutionAndFramerate {
- out2 += fmt.Sprintf(",Resolution:(%d,%d)", sl.Width, sl.Height)
- out2 += fmt.Sprintf(",Framerate:%d", sl.Framerate)
- }
- slOut = append(slOut, out2)
- }
- out += fmt.Sprintf(",ActiveSpatialLayers:{%s}", strings.Join(slOut, ","))
- return out
- }
|