From a71cb0a6d5678c15b41fe4a4c253dfda5bd43785 Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Thu, 16 Jan 2025 22:28:05 +0100 Subject: [PATCH] Add support for TWCC --- pkg/ccfb/history.go | 29 ++++--- pkg/ccfb/interceptor.go | 53 ++++++++++--- pkg/ccfb/twcc_receiver.go | 94 ++++++++++++++++++++++ pkg/ccfb/twcc_receiver_test.go | 141 +++++++++++++++++++++++++++++++++ 4 files changed, 297 insertions(+), 20 deletions(-) create mode 100644 pkg/ccfb/twcc_receiver.go create mode 100644 pkg/ccfb/twcc_receiver_test.go diff --git a/pkg/ccfb/history.go b/pkg/ccfb/history.go index e78e192..bec0f76 100644 --- a/pkg/ccfb/history.go +++ b/pkg/ccfb/history.go @@ -3,6 +3,7 @@ package ccfb import ( "container/list" "errors" + "sync" "time" "github.com/pion/interceptor/internal/sequencenumber" @@ -30,6 +31,7 @@ type sentPacket struct { } type history struct { + lock sync.Mutex size int evictList *list.List seqNrToPacket map[int64]*list.Element @@ -39,6 +41,7 @@ type history struct { func newHistory(size int) *history { return &history{ + lock: sync.Mutex{}, size: size, evictList: list.New(), seqNrToPacket: make(map[int64]*list.Element), @@ -48,6 +51,9 @@ func newHistory(size int) *history { } func (h *history) add(seqNr uint16, size uint16, departure time.Time) error { + h.lock.Lock() + defer h.lock.Unlock() + sn := h.sentSeqNr.Unwrap(seqNr) last := h.evictList.Back() if last != nil { @@ -65,11 +71,23 @@ func (h *history) add(seqNr uint16, size uint16, departure time.Time) error { if h.evictList.Len() > h.size { h.removeOldest() } - return nil } +// Must be called while holding the lock +func (h *history) removeOldest() { + if ent := h.evictList.Front(); ent != nil { + v := h.evictList.Remove(ent) + if sp, ok := v.(sentPacket); ok { + delete(h.seqNrToPacket, sp.seqNr) + } + } +} + func (h *history) getReportForAck(al acknowledgementList) PacketReportList { + h.lock.Lock() + defer h.lock.Unlock() + var reports []PacketReport for _, pr := range al.acks { sn := h.ackedSeqNr.Unwrap(pr.seqNr) @@ -95,12 +113,3 @@ func (h *history) getReportForAck(al acknowledgementList) PacketReportList { Reports: reports, } } - -func (h *history) removeOldest() { - if ent := h.evictList.Front(); ent != nil { - v := h.evictList.Remove(ent) - if sp, ok := v.(sentPacket); ok { - delete(h.seqNrToPacket, sp.seqNr) - } - } -} diff --git a/pkg/ccfb/interceptor.go b/pkg/ccfb/interceptor.go index fba41b7..e8b4afe 100644 --- a/pkg/ccfb/interceptor.go +++ b/pkg/ccfb/interceptor.go @@ -9,6 +9,8 @@ import ( "github.com/pion/rtp" ) +const transportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01" + type ccfbAttributesKeyType uint32 const CCFBAttributesKey ccfbAttributesKeyType = iota @@ -48,14 +50,42 @@ type Interceptor struct { // BindLocalStream implements interceptor.Interceptor. func (i *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + var twccHdrExtID uint8 + var useTWCC bool + for _, e := range info.RTPHeaderExtensions { + if e.URI == transportCCURI { + twccHdrExtID = uint8(e.ID) + useTWCC = true + break + } + } + i.lock.Lock() defer i.lock.Unlock() - i.ssrcToHistory[info.SSRC] = newHistory(200) + + ssrc := info.SSRC + if useTWCC { + ssrc = 0 + } + i.ssrcToHistory[ssrc] = newHistory(200) return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { i.lock.Lock() defer i.lock.Unlock() - i.ssrcToHistory[header.SSRC].add(header.SequenceNumber, uint16(header.MarshalSize()+len(payload)), i.timestamp()) + + // If we are using TWCC, we use the sequence number from the TWCC header + // extension and save all TWCC sequence numbers with the same SSRC (0). + // If we are not using TWCC, we save a history per SSRC and use the + // normal RTP sequence numbers. + ssrc := header.SSRC + seqNr := header.SequenceNumber + if useTWCC { + ssrc = 0 + var twccHdrExt rtp.TransportCCExtension + twccHdrExt.Unmarshal(header.GetExtension(twccHdrExtID)) + seqNr = twccHdrExt.TransportSequence + } + i.ssrcToHistory[ssrc].add(seqNr, uint16(header.MarshalSize()+len(payload)), i.timestamp()) return writer.Write(header, payload, attributes) }) } @@ -80,16 +110,19 @@ func (i *Interceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor. pkts, err := attr.GetRTCPPackets(buf) for _, pkt := range pkts { + var reportLists map[uint32]acknowledgementList switch fb := pkt.(type) { case *rtcp.CCFeedbackReport: - reportLists := convertCCFB(now, fb) - for ssrc, reportList := range reportLists { - prl := i.ssrcToHistory[ssrc].getReportForAck(reportList) - if l, ok := pktReportLists[ssrc]; !ok { - pktReportLists[ssrc] = &prl - } else { - l.Reports = append(l.Reports, prl.Reports...) - } + reportLists = convertCCFB(now, fb) + case *rtcp.TransportLayerCC: + reportLists = convertTWCC(now, fb) + } + for ssrc, reportList := range reportLists { + prl := i.ssrcToHistory[ssrc].getReportForAck(reportList) + if l, ok := pktReportLists[ssrc]; !ok { + pktReportLists[ssrc] = &prl + } else { + l.Reports = append(l.Reports, prl.Reports...) } } } diff --git a/pkg/ccfb/twcc_receiver.go b/pkg/ccfb/twcc_receiver.go new file mode 100644 index 0000000..d5c4c19 --- /dev/null +++ b/pkg/ccfb/twcc_receiver.go @@ -0,0 +1,94 @@ +package ccfb + +import ( + "log" + "time" + + "github.com/pion/rtcp" +) + +func convertTWCC(ts time.Time, feedback *rtcp.TransportLayerCC) map[uint32]acknowledgementList { + log.Printf("got twcc report: %v", feedback) + if feedback == nil { + return nil + } + var acks []acknowledgement + + nextTimestamp := time.Time{}.Add(time.Duration(feedback.ReferenceTime) * 64 * time.Millisecond) + recvDeltaIndex := 0 + + offset := 0 + for _, pc := range feedback.PacketChunks { + switch chunk := pc.(type) { + case *rtcp.RunLengthChunk: + for i := uint16(0); i < chunk.RunLength; i++ { + seqNr := feedback.BaseSequenceNumber + uint16(offset) + offset++ + switch chunk.PacketStatusSymbol { + case rtcp.TypeTCCPacketNotReceived: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedSmallDelta, rtcp.TypeTCCPacketReceivedLargeDelta: + delta := feedback.RecvDeltas[recvDeltaIndex] + nextTimestamp = nextTimestamp.Add(time.Duration(delta.Delta) * time.Microsecond) + recvDeltaIndex++ + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: nextTimestamp, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedWithoutDelta: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: time.Time{}, + ecn: 0, + }) + } + } + case *rtcp.StatusVectorChunk: + for _, s := range chunk.SymbolList { + seqNr := feedback.BaseSequenceNumber + uint16(offset) + offset++ + switch s { + case rtcp.TypeTCCPacketNotReceived: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedSmallDelta, rtcp.TypeTCCPacketReceivedLargeDelta: + delta := feedback.RecvDeltas[recvDeltaIndex] + nextTimestamp = nextTimestamp.Add(time.Duration(delta.Delta) * time.Microsecond) + recvDeltaIndex++ + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: nextTimestamp, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedWithoutDelta: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: time.Time{}, + ecn: 0, + }) + } + } + } + } + + return map[uint32]acknowledgementList{ + 0: { + ts: ts, + acks: acks, + }, + } +} diff --git a/pkg/ccfb/twcc_receiver_test.go b/pkg/ccfb/twcc_receiver_test.go new file mode 100644 index 0000000..d042e95 --- /dev/null +++ b/pkg/ccfb/twcc_receiver_test.go @@ -0,0 +1,141 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/pion/rtcp" + "github.com/stretchr/testify/assert" +) + +func TestConvertTWCC(t *testing.T) { + timeZero := time.Now() + cases := []struct { + ts time.Time + feedback *rtcp.TransportLayerCC + expect map[uint32]acknowledgementList + }{ + {}, + { + ts: timeZero.Add(2 * time.Second), + feedback: &rtcp.TransportLayerCC{ + SenderSSRC: 1, + MediaSSRC: 2, + BaseSequenceNumber: 178, + PacketStatusCount: 0, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{}, + RecvDeltas: []*rtcp.RecvDelta{}, + }, + expect: map[uint32]acknowledgementList{ + 2: { + ts: timeZero.Add(2 * time.Second), + acks: []acknowledgement{}, + }, + }, + }, + { + ts: timeZero.Add(2 * time.Second), + feedback: &rtcp.TransportLayerCC{ + SenderSSRC: 1, + MediaSSRC: 2, + BaseSequenceNumber: 178, + PacketStatusCount: 3, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.RunLengthChunk{ + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 3, + }, + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeOneBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedLargeDelta, Delta: 0}, + {Type: rtcp.TypeTCCPacketReceivedLargeDelta, Delta: 0}, + }, + }, + expect: map[uint32]acknowledgementList{ + 2: { + ts: timeZero.Add(2 * time.Second), + acks: []acknowledgement{ + // first run length chunk + {seqNr: 178, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 179, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 180, arrived: true, arrival: time.Time{}, ecn: 0}, + + // first status vector chunk + {seqNr: 181, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 182, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 183, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 184, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 185, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 186, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 187, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 188, arrived: false, arrival: time.Time{}, ecn: 0}, + + // second status vector chunk + {seqNr: 189, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 190, arrived: true, arrival: time.Time{}, ecn: 0}, + {seqNr: 191, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 192, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 193, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 194, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 195, arrived: false, arrival: time.Time{}, ecn: 0}, + }, + }, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + res := convertTWCC(tc.ts, tc.feedback) + + // Can't directly check equality since arrival timestamp conversions + // may be slightly off due to ntp conversions. + assert.Equal(t, len(tc.expect), len(res)) + for i, ee := range tc.expect { + assert.Equal(t, ee.ts, res[i].ts) + for j, ack := range ee.acks { + assert.Equal(t, ack.seqNr, res[i].acks[j].seqNr) + assert.Equal(t, ack.arrived, res[i].acks[j].arrived) + assert.Equal(t, ack.ecn, res[i].acks[j].ecn) + assert.InDelta(t, ack.arrival.UnixNano(), res[i].acks[j].arrival.UnixNano(), float64(time.Millisecond.Nanoseconds())) + } + } + }) + } + +}