From a50e9c6771daa4fb1380c8227b5196dcd31c6496 Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Sat, 3 Oct 2020 12:54:17 +0200 Subject: [PATCH] Buffer last keyframe. --- packetcache/packetcache.go | 80 ++++++++++++++++++++++++++----- packetcache/packetcache_test.go | 85 ++++++++++++++++++++++++++++----- rtpconn/rtpreader.go | 23 ++++++++- rtpconn/rtpwriter.go | 31 +++++++++++- 4 files changed, 191 insertions(+), 28 deletions(-) diff --git a/packetcache/packetcache.go b/packetcache/packetcache.go index 2d8a6ea..ab9ecf5 100644 --- a/packetcache/packetcache.go +++ b/packetcache/packetcache.go @@ -5,6 +5,7 @@ import ( ) const BufSize = 1500 +const maxKeyframe = 1024 type entry struct { seqno uint16 @@ -24,6 +25,9 @@ type Cache struct { // bitmap first uint16 bitmap uint32 + // buffered keyframe + kfTimestamp uint32 + kfEntries []entry // packet cache tail uint16 entries []entry @@ -75,7 +79,7 @@ func (cache *Cache) set(seqno uint16) { } // Store a packet, setting bitmap at the same time -func (cache *Cache) Store(seqno uint16, buf []byte) (uint16, uint16) { +func (cache *Cache) Store(seqno uint16, timestamp uint32, keyframe bool, buf []byte) (uint16, uint16) { cache.mu.Lock() defer cache.mu.Unlock() @@ -97,9 +101,39 @@ func (cache *Cache) Store(seqno uint16, buf []byte) (uint16, uint16) { } } } - cache.set(seqno) + doit := false + if keyframe { + if cache.kfTimestamp != timestamp { + cache.kfTimestamp = timestamp + cache.kfEntries = cache.kfEntries[:0] + } + doit = true + } else if len(cache.kfEntries) > 0 { + doit = cache.kfTimestamp == timestamp + } + if doit { + i := 0 + for i < len(cache.kfEntries) { + if cache.kfEntries[i].seqno >= seqno { + break + } + i++ + } + + if i >= len(cache.kfEntries) || cache.kfEntries[i].seqno != seqno { + if len(cache.kfEntries) >= maxKeyframe { + cache.kfEntries = cache.kfEntries[:maxKeyframe-1] + } + cache.kfEntries = append(cache.kfEntries, entry{}) + copy(cache.kfEntries[i+1:], cache.kfEntries[i:]) + } + cache.kfEntries[i].seqno = seqno + cache.kfEntries[i].length = uint16(len(buf)) + copy(cache.kfEntries[i].buf[:], buf) + } + i := cache.tail cache.entries[i].seqno = seqno copy(cache.entries[i].buf[:], buf) @@ -118,20 +152,33 @@ func (cache *Cache) Expect(n int) { cache.expected += uint32(n) } +func get(seqno uint16, entries []entry, result []byte) uint16 { + for i := range entries { + if entries[i].length == 0 || entries[i].seqno != seqno { + continue + } + return uint16(copy( + result[:entries[i].length], + entries[i].buf[:]), + ) + } + return 0 +} + func (cache *Cache) Get(seqno uint16, result []byte) uint16 { cache.mu.Lock() defer cache.mu.Unlock() - for i := range cache.entries { - if cache.entries[i].length == 0 || - cache.entries[i].seqno != seqno { - continue - } - return uint16(copy( - result[:cache.entries[i].length], - cache.entries[i].buf[:]), - ) + n := get(seqno, cache.kfEntries, result) + if n > 0 { + return n } + + n = get(seqno, cache.entries, result) + if n > 0 { + return n + } + return 0 } @@ -151,6 +198,17 @@ func (cache *Cache) GetAt(seqno uint16, index uint16, result []byte) uint16 { ) } +func (cache *Cache) Keyframe() (uint32, []uint16) { + cache.mu.Lock() + defer cache.mu.Unlock() + + seqnos := make([]uint16, len(cache.kfEntries)) + for i := range cache.kfEntries { + seqnos[i] = cache.kfEntries[i].seqno + } + return cache.kfTimestamp, seqnos +} + func (cache *Cache) resize(capacity int) { if len(cache.entries) == capacity { return diff --git a/packetcache/packetcache_test.go b/packetcache/packetcache_test.go index 5188e2b..ac9f2d2 100644 --- a/packetcache/packetcache_test.go +++ b/packetcache/packetcache_test.go @@ -20,8 +20,8 @@ func TestCache(t *testing.T) { buf1 := randomBuf() buf2 := randomBuf() cache := New(16) - _, i1 := cache.Store(13, buf1) - _, i2 := cache.Store(17, buf2) + _, i1 := cache.Store(13, 0, false, buf1) + _, i2 := cache.Store(17, 0, false, buf2) buf := make([]byte, BufSize) @@ -62,7 +62,7 @@ func TestCacheOverflow(t *testing.T) { cache := New(16) for i := 0; i < 32; i++ { - cache.Store(uint16(i), []byte{uint8(i)}) + cache.Store(uint16(i), 0, false, []byte{uint8(i)}) } for i := 0; i < 32; i++ { @@ -84,7 +84,7 @@ func TestCacheGrow(t *testing.T) { cache := New(16) for i := 0; i < 24; i++ { - cache.Store(uint16(i), []byte{uint8(i)}) + cache.Store(uint16(i), 0, false, []byte{uint8(i)}) } cache.Resize(32) @@ -107,7 +107,7 @@ func TestCacheShrink(t *testing.T) { cache := New(16) for i := 0; i < 24; i++ { - cache.Store(uint16(i), []byte{uint8(i)}) + cache.Store(uint16(i), 0, false, []byte{uint8(i)}) } cache.Resize(12) @@ -150,6 +150,65 @@ func TestCacheGrowCond(t *testing.T) { } } +func TestKeyframe(t *testing.T) { + cache := New(16) + packet := make([]byte, 1) + buf := make([]byte, BufSize) + + cache.Store(7, 57, true, packet) + cache.Store(8, 57, true, packet) + + ts, kf := cache.Keyframe() + if ts != 57 || len(kf) != 2 { + t.Errorf("Got %v %v, expected %v %v", ts, len(kf), 57, 2) + } + for _, i := range kf { + l := cache.Get(i, buf) + if int(l) != len(packet) { + t.Errorf("Couldn't get %v", i) + } + } + + for i := 0; i < 32; i++ { + cache.Store(uint16(9 + i), uint32(58 + i), false, packet) + } + + ts, kf = cache.Keyframe() + if ts != 57 || len(kf) != 2 { + t.Errorf("Got %v %v, expected %v %v", ts, len(kf), 57, 2) + } + for _, i := range kf { + l := cache.Get(i, buf) + if int(l) != len(packet) { + t.Errorf("Couldn't get %v", i) + } + } +} + +func TestKeyframeUnsorted(t *testing.T) { + cache := New(16) + packet := make([]byte, 1) + + cache.Store(7, 57, true, packet) + cache.Store(9, 57, true, packet) + cache.Store(8, 57, true, packet) + cache.Store(10, 57, true, packet) + cache.Store(6, 57, true, packet) + cache.Store(8, 57, true, packet) + + _, kf := cache.Keyframe() + if len(kf) != 5 { + t.Errorf("Got length %v, expected 5", len(kf)) + } + for i, v := range kf { + if v != uint16(i + 6) { + t.Errorf("Position %v, expected %v, got %v\n", + i, i + 6, v) + } + } +} + + func TestBitmap(t *testing.T) { value := uint64(0xcdd58f1e035379c0) packet := make([]byte, 1) @@ -159,7 +218,7 @@ func TestBitmap(t *testing.T) { var first uint16 for i := 0; i < 64; i++ { if (value & (1 << i)) != 0 { - first, _ = cache.Store(uint16(42+i), packet) + first, _ = cache.Store(uint16(42+i), 0, false, packet) } } @@ -175,13 +234,13 @@ func TestBitmapWrap(t *testing.T) { cache := New(16) - cache.Store(0x7000, packet) - cache.Store(0xA000, packet) + cache.Store(0x7000, 0, false, packet) + cache.Store(0xA000, 0, false, packet) var first uint16 for i := 0; i < 64; i++ { if (value & (1 << i)) != 0 { - first, _ = cache.Store(uint16(42+i), packet) + first, _ = cache.Store(uint16(42+i), 0, false, packet) } } @@ -199,7 +258,7 @@ func TestBitmapGet(t *testing.T) { for i := 0; i < 64; i++ { if (value & (1 << i)) != 0 { - cache.Store(uint16(42+i), packet) + cache.Store(uint16(42+i), 0, false, packet) } } @@ -241,7 +300,7 @@ func TestBitmapPacket(t *testing.T) { for i := 0; i < 64; i++ { if (value & (1 << i)) != 0 { - cache.Store(uint16(42+i), packet) + cache.Store(uint16(42+i), 0, false, packet) } } @@ -299,7 +358,7 @@ func BenchmarkCachePutGet(b *testing.B) { for i := 0; i < b.N; i++ { seqno := uint16(i) - cache.Store(seqno, buf) + cache.Store(seqno, 0, false, buf) for _, ch := range chans { ch <- seqno } @@ -350,7 +409,7 @@ func BenchmarkCachePutGetAt(b *testing.B) { for i := 0; i < b.N; i++ { seqno := uint16(i) - _, index := cache.Store(seqno, buf) + _, index := cache.Store(seqno, 0, false, buf) for _, ch := range chans { ch <- is{index, seqno} } diff --git a/rtpconn/rtpreader.go b/rtpconn/rtpreader.go index bdb56f9..bf5216e 100644 --- a/rtpconn/rtpreader.go +++ b/rtpconn/rtpreader.go @@ -5,12 +5,24 @@ import ( "log" "github.com/pion/rtp" + "github.com/pion/rtp/codecs" "github.com/pion/webrtc/v3" "sfu/packetcache" "sfu/rtptime" ) +func isVP8Keyframe(packet *rtp.Packet) bool { + var vp8 codecs.VP8Packet + _, err := vp8.Unmarshal(packet.Payload) + if err != nil { + return false + } + + return vp8.S != 0 && vp8.PID == 0 && + len(vp8.Payload) > 0 && (vp8.Payload[0]&0x1) == 0 +} + func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { writers := rtpWriterPool{conn: conn, track: track} defer func() { @@ -19,6 +31,7 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { }() isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo + codec := track.track.Codec().Name buf := make([]byte, packetcache.BufSize) var packet rtp.Packet for { @@ -39,8 +52,14 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { track.jitter.Accumulate(packet.Timestamp) - first, index := - track.cache.Store(packet.SequenceNumber, buf[:bytes]) + kf := false + if isvideo && codec == webrtc.VP8 { + kf = isVP8Keyframe(&packet) + } + + first, index := track.cache.Store( + packet.SequenceNumber, packet.Timestamp, kf, buf[:bytes], + ) if packet.SequenceNumber-first > 24 { found, first, bitmap := track.cache.BitmapGet() if found { diff --git a/rtpconn/rtpwriter.go b/rtpconn/rtpwriter.go index 2150c43..b0bc040 100644 --- a/rtpconn/rtpwriter.go +++ b/rtpconn/rtpwriter.go @@ -138,7 +138,7 @@ func (wp *rtpWriterPool) write(seqno uint16, index uint16, delay uint32, isvideo continue } // audio, try again with a delay - d := delay/uint32(2*len(wp.writers)) + d := delay / uint32(2*len(wp.writers)) timer := time.NewTimer(rtptime.ToDuration( uint64(d), rtptime.JiffiesPerSec, )) @@ -208,6 +208,31 @@ func (writer *rtpWriter) add(track conn.DownTrack, add bool, max int) error { } } +func sendKeyframe(track conn.DownTrack, cache *packetcache.Cache) { + _, kf := cache.Keyframe() + if len(kf) == 0 { + return + } + + buf := make([]byte, packetcache.BufSize) + var packet rtp.Packet + for _, seqno := range kf { + bytes := cache.Get(seqno, buf) + if(bytes == 0) { + return + } + err := packet.Unmarshal(buf[:bytes]) + if err != nil { + return + } + err = track.WriteRTP(&packet) + if err != nil && err != conn.ErrKeyframeNeeded { + return + } + track.Accumulate(uint32(bytes)) + } +} + // rtpWriterLoop is the main loop of an rtpWriter. func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { defer close(writer.done) @@ -245,6 +270,7 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { if cname != "" { action.track.SetCname(cname) } + go sendKeyframe(action.track, track.cache) } else { found := false for i, t := range local { @@ -286,8 +312,9 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { if err != nil { if err == conn.ErrKeyframeNeeded { kfNeeded = true + } else { + continue } - continue } l.Accumulate(uint32(bytes)) }