vlaextension.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package rtp
  4. import (
  5. "encoding/binary"
  6. "errors"
  7. "fmt"
  8. "strings"
  9. "github.com/pion/rtp/codecs/av1/obu"
  10. )
  // Sentinel errors reported while marshaling or unmarshaling a VLA extension.
  // Callers are expected to match them with errors.Is, since most call sites
  // wrap them with additional context.
  var (
  	// ErrVLATooShort is returned when payload is too short.
  	ErrVLATooShort = errors.New("VLA payload too short")

  	// ErrVLAInvalidStreamCount is returned when RTP stream count is invalid.
  	ErrVLAInvalidStreamCount = errors.New("invalid RTP stream count in VLA")

  	// ErrVLAInvalidStreamID is returned when RTP stream ID is invalid.
  	ErrVLAInvalidStreamID = errors.New("invalid RTP stream ID in VLA")

  	// ErrVLAInvalidSpatialID is returned when spatial ID is invalid.
  	ErrVLAInvalidSpatialID = errors.New("invalid spatial ID in VLA")

  	// ErrVLADuplicateSpatialID is returned when the same spatial ID appears
  	// more than once for one RTP stream.
  	ErrVLADuplicateSpatialID = errors.New("duplicate spatial ID in VLA")

  	// ErrVLAInvalidTemporalLayer is returned when temporal layer is invalid.
  	ErrVLAInvalidTemporalLayer = errors.New("invalid temporal layer in VLA")
  )
  // SpatialLayer is a spatial layer in VLA.
  type SpatialLayer struct {
  	// RTPStreamID is the RTP stream this layer belongs to.
  	RTPStreamID int
  	// SpatialID identifies the layer within its RTP stream.
  	SpatialID int
  	// TargetBitrates holds the target bitrates per temporal layer.
  	TargetBitrates []int

  	// Width, Height and Framerate are valid only when
  	// HasResolutionAndFramerate is true.
  	Width     int
  	Height    int
  	Framerate int
  }
  29. // VLA is a Video Layer Allocation (VLA) extension.
  30. // See https://webrtc.googlesource.com/src/+/refs/heads/main/docs/native-code/rtp-hdrext/video-layers-allocation00
  31. type VLA struct {
  32. RTPStreamID int // 0-origin RTP stream ID (RID) this allocation is sent on (0..3)
  33. RTPStreamCount int // Number of RTP streams (1..4)
  34. ActiveSpatialLayer []SpatialLayer
  35. HasResolutionAndFramerate bool
  36. }
  37. type vlaMarshalingContext struct {
  38. slMBs [4]uint8
  39. sls [4][4]*SpatialLayer
  40. commonSLBM uint8
  41. encodedTargetBitrates [][]byte
  42. requiredLen int
  43. }
  44. func (v VLA) preprocessForMashaling(ctx *vlaMarshalingContext) error {
  45. for i := 0; i < len(v.ActiveSpatialLayer); i++ {
  46. sl := v.ActiveSpatialLayer[i]
  47. if sl.RTPStreamID < 0 || sl.RTPStreamID >= v.RTPStreamCount {
  48. return fmt.Errorf("invalid RTP streamID %d:%w", sl.RTPStreamID, ErrVLAInvalidStreamID)
  49. }
  50. if sl.SpatialID < 0 || sl.SpatialID >= 4 {
  51. return fmt.Errorf("invalid spatial ID %d: %w", sl.SpatialID, ErrVLAInvalidSpatialID)
  52. }
  53. if len(sl.TargetBitrates) == 0 || len(sl.TargetBitrates) > 4 {
  54. return fmt.Errorf("invalid temporal layer count %d: %w", len(sl.TargetBitrates), ErrVLAInvalidTemporalLayer)
  55. }
  56. ctx.slMBs[sl.RTPStreamID] |= 1 << sl.SpatialID
  57. if ctx.sls[sl.RTPStreamID][sl.SpatialID] != nil {
  58. return fmt.Errorf("duplicate spatial layer: %w", ErrVLADuplicateSpatialID)
  59. }
  60. ctx.sls[sl.RTPStreamID][sl.SpatialID] = &sl
  61. }
  62. return nil
  63. }
  64. func (v VLA) encodeTargetBitrates(ctx *vlaMarshalingContext) {
  65. for rtpStreamID := 0; rtpStreamID < v.RTPStreamCount; rtpStreamID++ {
  66. for spatialID := 0; spatialID < 4; spatialID++ {
  67. if sl := ctx.sls[rtpStreamID][spatialID]; sl != nil {
  68. for _, kbps := range sl.TargetBitrates {
  69. leb128 := obu.WriteToLeb128(uint(kbps))
  70. ctx.encodedTargetBitrates = append(ctx.encodedTargetBitrates, leb128)
  71. ctx.requiredLen += len(leb128)
  72. }
  73. }
  74. }
  75. }
  76. }
  77. func (v VLA) analyzeVLAForMarshaling() (*vlaMarshalingContext, error) {
  78. // Validate RTPStreamCount
  79. if v.RTPStreamCount <= 0 || v.RTPStreamCount > 4 {
  80. return nil, ErrVLAInvalidStreamCount
  81. }
  82. // Validate RTPStreamID
  83. if v.RTPStreamID < 0 || v.RTPStreamID >= v.RTPStreamCount {
  84. return nil, ErrVLAInvalidStreamID
  85. }
  86. ctx := &vlaMarshalingContext{}
  87. err := v.preprocessForMashaling(ctx)
  88. if err != nil {
  89. return nil, err
  90. }
  91. ctx.commonSLBM = commonSLBMValues(ctx.slMBs[:])
  92. // RID, NS, sl_bm fields
  93. if ctx.commonSLBM != 0 {
  94. ctx.requiredLen = 1
  95. } else {
  96. ctx.requiredLen = 3
  97. }
  98. // #tl fields
  99. ctx.requiredLen += (len(v.ActiveSpatialLayer)-1)/4 + 1
  100. v.encodeTargetBitrates(ctx)
  101. if v.HasResolutionAndFramerate {
  102. ctx.requiredLen += len(v.ActiveSpatialLayer) * 5
  103. }
  104. return ctx, nil
  105. }
  106. // Marshal encodes VLA into a byte slice.
  107. func (v VLA) Marshal() ([]byte, error) {
  108. ctx, err := v.analyzeVLAForMarshaling()
  109. if err != nil {
  110. return nil, err
  111. }
  112. payload := make([]byte, ctx.requiredLen)
  113. offset := 0
  114. // RID, NS, sl_bm fields
  115. payload[offset] = byte(v.RTPStreamID<<6) | byte(v.RTPStreamCount-1)<<4 | ctx.commonSLBM
  116. if ctx.commonSLBM == 0 {
  117. offset++
  118. for streamID := 0; streamID < v.RTPStreamCount; streamID++ {
  119. if streamID%2 == 0 {
  120. payload[offset+streamID/2] |= ctx.slMBs[streamID] << 4
  121. } else {
  122. payload[offset+streamID/2] |= ctx.slMBs[streamID]
  123. }
  124. }
  125. offset += (v.RTPStreamCount - 1) / 2
  126. }
  127. // #tl fields
  128. offset++
  129. var temporalLayerIndex int
  130. for rtpStreamID := 0; rtpStreamID < v.RTPStreamCount; rtpStreamID++ {
  131. for spatialID := 0; spatialID < 4; spatialID++ {
  132. if sl := ctx.sls[rtpStreamID][spatialID]; sl != nil {
  133. if temporalLayerIndex >= 4 {
  134. temporalLayerIndex = 0
  135. offset++
  136. }
  137. payload[offset] |= byte(len(sl.TargetBitrates)-1) << (2 * (3 - temporalLayerIndex))
  138. temporalLayerIndex++
  139. }
  140. }
  141. }
  142. // Target bitrate fields
  143. offset++
  144. for _, encodedKbps := range ctx.encodedTargetBitrates {
  145. encodedSize := len(encodedKbps)
  146. copy(payload[offset:], encodedKbps)
  147. offset += encodedSize
  148. }
  149. // Resolution & framerate fields
  150. if v.HasResolutionAndFramerate {
  151. for _, sl := range v.ActiveSpatialLayer {
  152. binary.BigEndian.PutUint16(payload[offset+0:], uint16(sl.Width-1))
  153. binary.BigEndian.PutUint16(payload[offset+2:], uint16(sl.Height-1))
  154. payload[offset+4] = byte(sl.Framerate)
  155. offset += 5
  156. }
  157. }
  158. return payload, nil
  159. }
  // commonSLBMValues returns the single non-zero bitmask shared by all entries
  // of slMBs, ignoring zero entries. It returns 0 when two non-zero entries
  // differ or when every entry is zero.
  func commonSLBMValues(slMBs []uint8) uint8 {
  	var shared uint8
  	for _, bm := range slMBs {
  		switch {
  		case bm == 0:
  			// Streams without layers do not take part in the comparison.
  		case shared == 0:
  			shared = bm
  		case bm != shared:
  			return 0
  		}
  	}

  	return shared
  }
  // vlaUnmarshalingContext tracks the parsing position and the spatial-layer
  // bitmasks while a VLA payload is being decoded.
  type vlaUnmarshalingContext struct {
  	// payload is the raw extension data being decoded.
  	payload []byte
  	// offset is the index of the next unread byte in payload.
  	offset int
  	// slBMField is the 4-bit sl_bm value taken from the header byte.
  	slBMField uint8
  	// slBMs holds the spatial-layer bitmask of each RTP stream.
  	slBMs [4]uint8
  }
  182. func (ctx *vlaUnmarshalingContext) checkRemainingLen(requiredLen int) bool {
  183. return len(ctx.payload)-ctx.offset >= requiredLen
  184. }
  185. func (v *VLA) unmarshalSpatialLayers(ctx *vlaUnmarshalingContext) error {
  186. if !ctx.checkRemainingLen(1) {
  187. return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
  188. }
  189. v.RTPStreamID = int(ctx.payload[ctx.offset] >> 6 & 0b11)
  190. v.RTPStreamCount = int(ctx.payload[ctx.offset]>>4&0b11) + 1
  191. // sl_bm fields
  192. ctx.slBMField = ctx.payload[ctx.offset] & 0b1111
  193. ctx.offset++
  194. if ctx.slBMField != 0 {
  195. for streamID := 0; streamID < v.RTPStreamCount; streamID++ {
  196. ctx.slBMs[streamID] = ctx.slBMField
  197. }
  198. } else {
  199. if !ctx.checkRemainingLen((v.RTPStreamCount-1)/2 + 1) {
  200. return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
  201. }
  202. // slX_bm fields
  203. for streamID := 0; streamID < v.RTPStreamCount; streamID++ {
  204. var bm uint8
  205. if streamID%2 == 0 {
  206. bm = ctx.payload[ctx.offset+streamID/2] >> 4 & 0b1111
  207. } else {
  208. bm = ctx.payload[ctx.offset+streamID/2] & 0b1111
  209. }
  210. ctx.slBMs[streamID] = bm
  211. }
  212. ctx.offset += 1 + (v.RTPStreamCount-1)/2
  213. }
  214. return nil
  215. }
  216. func (v *VLA) unmarshalTemporalLayers(ctx *vlaUnmarshalingContext) error {
  217. if !ctx.checkRemainingLen(1) {
  218. return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
  219. }
  220. var temporalLayerIndex int
  221. for streamID := 0; streamID < v.RTPStreamCount; streamID++ {
  222. for spatialID := 0; spatialID < 4; spatialID++ {
  223. if ctx.slBMs[streamID]&(1<<spatialID) == 0 {
  224. continue
  225. }
  226. if temporalLayerIndex >= 4 {
  227. temporalLayerIndex = 0
  228. ctx.offset++
  229. if !ctx.checkRemainingLen(1) {
  230. return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
  231. }
  232. }
  233. tlCount := int(ctx.payload[ctx.offset]>>(2*(3-temporalLayerIndex))&0b11) + 1
  234. temporalLayerIndex++
  235. sl := SpatialLayer{
  236. RTPStreamID: streamID,
  237. SpatialID: spatialID,
  238. TargetBitrates: make([]int, tlCount),
  239. }
  240. v.ActiveSpatialLayer = append(v.ActiveSpatialLayer, sl)
  241. }
  242. }
  243. ctx.offset++
  244. // target bitrates
  245. for i, sl := range v.ActiveSpatialLayer {
  246. for j := range sl.TargetBitrates {
  247. kbps, n, err := obu.ReadLeb128(ctx.payload[ctx.offset:])
  248. if err != nil {
  249. return err
  250. }
  251. if !ctx.checkRemainingLen(int(n)) {
  252. return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
  253. }
  254. v.ActiveSpatialLayer[i].TargetBitrates[j] = int(kbps)
  255. ctx.offset += int(n)
  256. }
  257. }
  258. return nil
  259. }
  260. func (v *VLA) unmarshalResolutionAndFramerate(ctx *vlaUnmarshalingContext) error {
  261. if !ctx.checkRemainingLen(len(v.ActiveSpatialLayer) * 5) {
  262. return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
  263. }
  264. v.HasResolutionAndFramerate = true
  265. for i := range v.ActiveSpatialLayer {
  266. v.ActiveSpatialLayer[i].Width = int(binary.BigEndian.Uint16(ctx.payload[ctx.offset+0:])) + 1
  267. v.ActiveSpatialLayer[i].Height = int(binary.BigEndian.Uint16(ctx.payload[ctx.offset+2:])) + 1
  268. v.ActiveSpatialLayer[i].Framerate = int(ctx.payload[ctx.offset+4])
  269. ctx.offset += 5
  270. }
  271. return nil
  272. }
  273. // Unmarshal decodes VLA from a byte slice.
  274. func (v *VLA) Unmarshal(payload []byte) (int, error) {
  275. ctx := &vlaUnmarshalingContext{
  276. payload: payload,
  277. }
  278. err := v.unmarshalSpatialLayers(ctx)
  279. if err != nil {
  280. return ctx.offset, err
  281. }
  282. // #tl fields (build the list ActiveSpatialLayer at the same time)
  283. err = v.unmarshalTemporalLayers(ctx)
  284. if err != nil {
  285. return ctx.offset, err
  286. }
  287. if len(ctx.payload) == ctx.offset {
  288. return ctx.offset, nil
  289. }
  290. // resolution & framerate (optional)
  291. err = v.unmarshalResolutionAndFramerate(ctx)
  292. if err != nil {
  293. return ctx.offset, err
  294. }
  295. return ctx.offset, nil
  296. }
  297. // String makes VLA printable.
  298. func (v VLA) String() string {
  299. out := fmt.Sprintf("RID:%d,RTPStreamCount:%d", v.RTPStreamID, v.RTPStreamCount)
  300. var slOut []string
  301. for _, sl := range v.ActiveSpatialLayer {
  302. out2 := fmt.Sprintf("RTPStreamID:%d", sl.RTPStreamID)
  303. out2 += fmt.Sprintf(",TargetBitrates:%v", sl.TargetBitrates)
  304. if v.HasResolutionAndFramerate {
  305. out2 += fmt.Sprintf(",Resolution:(%d,%d)", sl.Width, sl.Height)
  306. out2 += fmt.Sprintf(",Framerate:%d", sl.Framerate)
  307. }
  308. slOut = append(slOut, out2)
  309. }
  310. out += fmt.Sprintf(",ActiveSpatialLayers:{%s}", strings.Join(slOut, ","))
  311. return out
  312. }