Skip to content

Commit

Permalink
Add support for TWCC
Browse files Browse the repository at this point in the history
  • Loading branch information
mengelbart committed Jan 17, 2025
1 parent 1fb32cc commit a71cb0a
Show file tree
Hide file tree
Showing 4 changed files with 297 additions and 20 deletions.
29 changes: 19 additions & 10 deletions pkg/ccfb/history.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ccfb
import (
"container/list"
"errors"
"sync"
"time"

"github.com/pion/interceptor/internal/sequencenumber"
Expand Down Expand Up @@ -30,6 +31,7 @@ type sentPacket struct {
}

type history struct {
lock sync.Mutex
size int
evictList *list.List
seqNrToPacket map[int64]*list.Element
Expand All @@ -39,6 +41,7 @@ type history struct {

func newHistory(size int) *history {
return &history{
lock: sync.Mutex{},
size: size,
evictList: list.New(),
seqNrToPacket: make(map[int64]*list.Element),
Expand All @@ -48,6 +51,9 @@ func newHistory(size int) *history {
}

func (h *history) add(seqNr uint16, size uint16, departure time.Time) error {
h.lock.Lock()
defer h.lock.Unlock()

sn := h.sentSeqNr.Unwrap(seqNr)
last := h.evictList.Back()
if last != nil {
Expand All @@ -65,11 +71,23 @@ func (h *history) add(seqNr uint16, size uint16, departure time.Time) error {
if h.evictList.Len() > h.size {
h.removeOldest()
}

return nil
}

// Must be called while holding the lock
func (h *history) removeOldest() {
if ent := h.evictList.Front(); ent != nil {
v := h.evictList.Remove(ent)
if sp, ok := v.(sentPacket); ok {
delete(h.seqNrToPacket, sp.seqNr)
}
}
}

func (h *history) getReportForAck(al acknowledgementList) PacketReportList {
h.lock.Lock()
defer h.lock.Unlock()

var reports []PacketReport
for _, pr := range al.acks {
sn := h.ackedSeqNr.Unwrap(pr.seqNr)
Expand All @@ -95,12 +113,3 @@ func (h *history) getReportForAck(al acknowledgementList) PacketReportList {
Reports: reports,
}
}

func (h *history) removeOldest() {
if ent := h.evictList.Front(); ent != nil {
v := h.evictList.Remove(ent)
if sp, ok := v.(sentPacket); ok {
delete(h.seqNrToPacket, sp.seqNr)
}
}
}
53 changes: 43 additions & 10 deletions pkg/ccfb/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"github.com/pion/rtp"
)

const transportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01"

type ccfbAttributesKeyType uint32

const CCFBAttributesKey ccfbAttributesKeyType = iota

Check failure on line 16 in pkg/ccfb/interceptor.go

View workflow job for this annotation

GitHub Actions / lint / Go

exported: exported const CCFBAttributesKey should have comment or be unexported (revive)
Expand Down Expand Up @@ -48,14 +50,42 @@ type Interceptor struct {

// BindLocalStream implements interceptor.Interceptor.
func (i *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
var twccHdrExtID uint8
var useTWCC bool
for _, e := range info.RTPHeaderExtensions {
if e.URI == transportCCURI {
twccHdrExtID = uint8(e.ID)
useTWCC = true
break
}
}

i.lock.Lock()
defer i.lock.Unlock()
i.ssrcToHistory[info.SSRC] = newHistory(200)

ssrc := info.SSRC
if useTWCC {
ssrc = 0
}
i.ssrcToHistory[ssrc] = newHistory(200)

return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
i.lock.Lock()
defer i.lock.Unlock()
i.ssrcToHistory[header.SSRC].add(header.SequenceNumber, uint16(header.MarshalSize()+len(payload)), i.timestamp())

// If we are using TWCC, we use the sequence number from the TWCC header
// extension and save all TWCC sequence numbers with the same SSRC (0).
// If we are not using TWCC, we save a history per SSRC and use the
// normal RTP sequence numbers.
ssrc := header.SSRC
seqNr := header.SequenceNumber
if useTWCC {
ssrc = 0
var twccHdrExt rtp.TransportCCExtension
twccHdrExt.Unmarshal(header.GetExtension(twccHdrExtID))

Check failure on line 85 in pkg/ccfb/interceptor.go

View workflow job for this annotation

GitHub Actions / lint / Go

Error return value of `twccHdrExt.Unmarshal` is not checked (errcheck)
seqNr = twccHdrExt.TransportSequence
}
i.ssrcToHistory[ssrc].add(seqNr, uint16(header.MarshalSize()+len(payload)), i.timestamp())

Check failure on line 88 in pkg/ccfb/interceptor.go

View workflow job for this annotation

GitHub Actions / lint / Go

Error return value of `(*github.com/pion/interceptor/pkg/ccfb.history).add` is not checked (errcheck)
return writer.Write(header, payload, attributes)
})
}
Expand All @@ -80,16 +110,19 @@ func (i *Interceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.

pkts, err := attr.GetRTCPPackets(buf)
for _, pkt := range pkts {
var reportLists map[uint32]acknowledgementList
switch fb := pkt.(type) {
case *rtcp.CCFeedbackReport:
reportLists := convertCCFB(now, fb)
for ssrc, reportList := range reportLists {
prl := i.ssrcToHistory[ssrc].getReportForAck(reportList)
if l, ok := pktReportLists[ssrc]; !ok {
pktReportLists[ssrc] = &prl
} else {
l.Reports = append(l.Reports, prl.Reports...)
}
reportLists = convertCCFB(now, fb)
case *rtcp.TransportLayerCC:
reportLists = convertTWCC(now, fb)
}
for ssrc, reportList := range reportLists {
prl := i.ssrcToHistory[ssrc].getReportForAck(reportList)
if l, ok := pktReportLists[ssrc]; !ok {
pktReportLists[ssrc] = &prl
} else {
l.Reports = append(l.Reports, prl.Reports...)
}
}
}
Expand Down
94 changes: 94 additions & 0 deletions pkg/ccfb/twcc_receiver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package ccfb

import (
"log"
"time"

"github.com/pion/rtcp"
)

func convertTWCC(ts time.Time, feedback *rtcp.TransportLayerCC) map[uint32]acknowledgementList {
log.Printf("got twcc report: %v", feedback)

Check failure on line 11 in pkg/ccfb/twcc_receiver.go

View workflow job for this annotation

GitHub Actions / lint / Go

use of `log.Printf` forbidden by pattern `^log.(Panic|Fatal|Print)(f|ln)?$` (forbidigo)
if feedback == nil {
return nil
}
var acks []acknowledgement

nextTimestamp := time.Time{}.Add(time.Duration(feedback.ReferenceTime) * 64 * time.Millisecond)
recvDeltaIndex := 0

offset := 0
for _, pc := range feedback.PacketChunks {
switch chunk := pc.(type) {
case *rtcp.RunLengthChunk:
for i := uint16(0); i < chunk.RunLength; i++ {
seqNr := feedback.BaseSequenceNumber + uint16(offset)
offset++
switch chunk.PacketStatusSymbol {
case rtcp.TypeTCCPacketNotReceived:
acks = append(acks, acknowledgement{
seqNr: seqNr,
arrived: false,
arrival: time.Time{},
ecn: 0,
})
case rtcp.TypeTCCPacketReceivedSmallDelta, rtcp.TypeTCCPacketReceivedLargeDelta:
delta := feedback.RecvDeltas[recvDeltaIndex]
nextTimestamp = nextTimestamp.Add(time.Duration(delta.Delta) * time.Microsecond)
recvDeltaIndex++
acks = append(acks, acknowledgement{
seqNr: seqNr,
arrived: true,
arrival: nextTimestamp,
ecn: 0,
})
case rtcp.TypeTCCPacketReceivedWithoutDelta:
acks = append(acks, acknowledgement{
seqNr: seqNr,
arrived: true,
arrival: time.Time{},
ecn: 0,
})
}
}
case *rtcp.StatusVectorChunk:
for _, s := range chunk.SymbolList {
seqNr := feedback.BaseSequenceNumber + uint16(offset)
offset++
switch s {
case rtcp.TypeTCCPacketNotReceived:
acks = append(acks, acknowledgement{
seqNr: seqNr,
arrived: false,
arrival: time.Time{},
ecn: 0,
})
case rtcp.TypeTCCPacketReceivedSmallDelta, rtcp.TypeTCCPacketReceivedLargeDelta:
delta := feedback.RecvDeltas[recvDeltaIndex]
nextTimestamp = nextTimestamp.Add(time.Duration(delta.Delta) * time.Microsecond)
recvDeltaIndex++
acks = append(acks, acknowledgement{
seqNr: seqNr,
arrived: true,
arrival: nextTimestamp,
ecn: 0,
})
case rtcp.TypeTCCPacketReceivedWithoutDelta:
acks = append(acks, acknowledgement{
seqNr: seqNr,
arrived: true,
arrival: time.Time{},
ecn: 0,
})
}
}
}
}

return map[uint32]acknowledgementList{
0: {
ts: ts,
acks: acks,
},
}
}
141 changes: 141 additions & 0 deletions pkg/ccfb/twcc_receiver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package ccfb

import (
"fmt"
"testing"
"time"

"github.com/pion/rtcp"
"github.com/stretchr/testify/assert"
)

func TestConvertTWCC(t *testing.T) {
timeZero := time.Now()
cases := []struct {
ts time.Time
feedback *rtcp.TransportLayerCC
expect map[uint32]acknowledgementList
}{
{},
{
ts: timeZero.Add(2 * time.Second),
feedback: &rtcp.TransportLayerCC{
SenderSSRC: 1,
MediaSSRC: 2,
BaseSequenceNumber: 178,
PacketStatusCount: 0,
ReferenceTime: 0,
FbPktCount: 0,
PacketChunks: []rtcp.PacketStatusChunk{},
RecvDeltas: []*rtcp.RecvDelta{},
},
expect: map[uint32]acknowledgementList{
2: {
ts: timeZero.Add(2 * time.Second),
acks: []acknowledgement{},
},
},
},
{
ts: timeZero.Add(2 * time.Second),
feedback: &rtcp.TransportLayerCC{
SenderSSRC: 1,
MediaSSRC: 2,
BaseSequenceNumber: 178,
PacketStatusCount: 3,
ReferenceTime: 0,
FbPktCount: 0,
PacketChunks: []rtcp.PacketStatusChunk{
&rtcp.RunLengthChunk{
PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta,
RunLength: 3,
},
&rtcp.StatusVectorChunk{
SymbolSize: rtcp.TypeTCCSymbolSizeOneBit,
SymbolList: []uint16{
rtcp.TypeTCCPacketReceivedSmallDelta,
rtcp.TypeTCCPacketReceivedSmallDelta,
rtcp.TypeTCCPacketReceivedSmallDelta,
rtcp.TypeTCCPacketNotReceived,
rtcp.TypeTCCPacketNotReceived,
rtcp.TypeTCCPacketNotReceived,
rtcp.TypeTCCPacketNotReceived,
rtcp.TypeTCCPacketNotReceived,
},
},
&rtcp.StatusVectorChunk{
SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit,
SymbolList: []uint16{
rtcp.TypeTCCPacketReceivedLargeDelta,
rtcp.TypeTCCPacketReceivedLargeDelta,
rtcp.TypeTCCPacketNotReceived,
rtcp.TypeTCCPacketNotReceived,
rtcp.TypeTCCPacketNotReceived,
rtcp.TypeTCCPacketNotReceived,
rtcp.TypeTCCPacketNotReceived,
},
},
},
RecvDeltas: []*rtcp.RecvDelta{
{Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0},
{Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0},
{Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0},
{Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0},
{Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0},
{Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0},
{Type: rtcp.TypeTCCPacketReceivedLargeDelta, Delta: 0},
{Type: rtcp.TypeTCCPacketReceivedLargeDelta, Delta: 0},
},
},
expect: map[uint32]acknowledgementList{
2: {
ts: timeZero.Add(2 * time.Second),
acks: []acknowledgement{
// first run length chunk
{seqNr: 178, arrived: true, arrival: time.Time{}, ecn: 0},
{seqNr: 179, arrived: true, arrival: time.Time{}, ecn: 0},
{seqNr: 180, arrived: true, arrival: time.Time{}, ecn: 0},

// first status vector chunk
{seqNr: 181, arrived: true, arrival: time.Time{}, ecn: 0},
{seqNr: 182, arrived: true, arrival: time.Time{}, ecn: 0},
{seqNr: 183, arrived: true, arrival: time.Time{}, ecn: 0},
{seqNr: 184, arrived: false, arrival: time.Time{}, ecn: 0},
{seqNr: 185, arrived: false, arrival: time.Time{}, ecn: 0},
{seqNr: 186, arrived: false, arrival: time.Time{}, ecn: 0},
{seqNr: 187, arrived: false, arrival: time.Time{}, ecn: 0},
{seqNr: 188, arrived: false, arrival: time.Time{}, ecn: 0},

// second status vector chunk
{seqNr: 189, arrived: true, arrival: time.Time{}, ecn: 0},
{seqNr: 190, arrived: true, arrival: time.Time{}, ecn: 0},
{seqNr: 191, arrived: false, arrival: time.Time{}, ecn: 0},
{seqNr: 192, arrived: false, arrival: time.Time{}, ecn: 0},
{seqNr: 193, arrived: false, arrival: time.Time{}, ecn: 0},
{seqNr: 194, arrived: false, arrival: time.Time{}, ecn: 0},
{seqNr: 195, arrived: false, arrival: time.Time{}, ecn: 0},
},
},
},
},
}
for i, tc := range cases {
t.Run(fmt.Sprintf("%v", i), func(t *testing.T) {
res := convertTWCC(tc.ts, tc.feedback)

// Can't directly check equality since arrival timestamp conversions
// may be slightly off due to ntp conversions.
assert.Equal(t, len(tc.expect), len(res))
for i, ee := range tc.expect {
assert.Equal(t, ee.ts, res[i].ts)
for j, ack := range ee.acks {
assert.Equal(t, ack.seqNr, res[i].acks[j].seqNr)
assert.Equal(t, ack.arrived, res[i].acks[j].arrived)
assert.Equal(t, ack.ecn, res[i].acks[j].ecn)
assert.InDelta(t, ack.arrival.UnixNano(), res[i].acks[j].arrival.UnixNano(), float64(time.Millisecond.Nanoseconds()))
}
}
})
}

}

0 comments on commit a71cb0a

Please sign in to comment.