From 22585e9d10041441e63ce541831d0b08946f07a8 Mon Sep 17 00:00:00 2001
From: Juliusz Chroboczek
Date: Mon, 17 May 2021 16:23:07 +0200
Subject: [PATCH] Handle spatial scalability.

Maintain spatial layer information, and drop lower layers when
possible. Yields a 20% saving with VP9.
---
Review notes:
- rtpDownTrack.Write: the spatial-layer switch assigns layer.wantedSid;
  assigning layer.wantedTid there would select the wrong spatial layer.
- NOTE(review): sidsync is read straight from the VP9 P bit. The VP9
  RTP payload format describes up-switching as possible when P is
  unset, so the polarity should be confirmed.

 conn/conn.go             |   2 +-
 diskwriter/diskwriter.go |   4 +-
 rtpconn/codec.go         |  62 +++++++++++-------
 rtpconn/codec_test.go    |  19 +++---
 rtpconn/rtpconn.go       | 134 +++++++++++++++++++++++++--------------
 rtpconn/rtpconn_test.go  |  18 ++++--
 rtpconn/rtpstats.go      |  12 +++-
 static/stats.js          |  11 +++-
 stats/stats.go           |   6 +-
 9 files changed, 173 insertions(+), 95 deletions(-)

diff --git a/conn/conn.go b/conn/conn.go
index 5e74297..c4a4c2c 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -40,5 +40,5 @@ type DownTrack interface {
 	Write(buf []byte) (int, error)
 	SetTimeOffset(ntp uint64, rtp uint32)
 	SetCname(string)
-	GetMaxBitrate() (uint64, int)
+	GetMaxBitrate() (uint64, int, int)
 }
diff --git a/diskwriter/diskwriter.go b/diskwriter/diskwriter.go
index 9d30c97..e2a3096 100644
--- a/diskwriter/diskwriter.go
+++ b/diskwriter/diskwriter.go
@@ -616,6 +616,6 @@ func (conn *diskConn) initWriter(width, height uint32) error {
 	return nil
 }
 
-func (t *diskTrack) GetMaxBitrate() (uint64, int) {
-	return ^uint64(0), -1
+func (t *diskTrack) GetMaxBitrate() (uint64, int, int) {
+	return ^uint64(0), -1, -1
 }
diff --git a/rtpconn/codec.go b/rtpconn/codec.go
index 1ee29a0..93dd465 100644
--- a/rtpconn/codec.go
+++ b/rtpconn/codec.go
@@ -64,7 +64,7 @@ func isKeyframe(codec string, packet *rtp.Packet) (bool, bool) {
 					return nil, offset, offset > 0
 				}
 				l := data[offset]
-				length |= int(l & 0x7f) << (offset * 7)
+				length |= int(l&0x7f) << (offset * 7)
 				offset++
 				if (l & 0x80) == 0 {
 					break
@@ -182,50 +182,66 @@ func isKeyframe(codec string, packet *rtp.Packet) (bool, bool) {
 var errTruncated = errors.New("truncated packet")
 var errUnsupportedCodec = errors.New("unsupported codec")
 
-func packetFlags(codec string, buf []byte) (seqno uint16, start bool, pid uint16, tid uint8, sid uint8, layersync bool, discardable bool, err error) {
+type packetFlags struct {
+	seqno           uint16
+	start           bool
+	pid             uint16 // only if it needs rewriting
+	tid             uint8
+	sid             uint8
+	tidupsync       bool
+	sidsync         bool
+	sidnonreference bool
+	discardable     bool
+}
+
+func getPacketFlags(codec string, buf []byte) (packetFlags, error) {
 	if len(buf) < 12 {
-		err = errTruncated
-		return
+		return packetFlags{}, errTruncated
 	}
 
-	seqno = (uint16(buf[2]) << 8) | uint16(buf[3])
+	var flags packetFlags
+
+	flags.seqno = (uint16(buf[2]) << 8) | uint16(buf[3])
 
 	if strings.EqualFold(codec, "video/vp8") {
 		var packet rtp.Packet
-		err = packet.Unmarshal(buf)
+		err := packet.Unmarshal(buf)
 		if err != nil {
-			return
+			return flags, err
 		}
 		var vp8 codecs.VP8Packet
 		_, err = vp8.Unmarshal(packet.Payload)
 		if err != nil {
-			return
+			return flags, err
 		}
 
-		start = vp8.S == 1 && vp8.PID == 0
-		pid = vp8.PictureID
-		tid = vp8.TID
-		layersync = vp8.Y == 1
-		discardable = vp8.N == 1
-		return
+		flags.start = vp8.S == 1 && vp8.PID == 0
+		flags.pid = vp8.PictureID
+		flags.tid = vp8.TID
+		flags.tidupsync = vp8.Y == 1
+		flags.discardable = vp8.N == 1
+		return flags, nil
 	} else if strings.EqualFold(codec, "video/vp9") {
 		var packet rtp.Packet
-		err = packet.Unmarshal(buf)
+		err := packet.Unmarshal(buf)
 		if err != nil {
-			return
+			return flags, err
 		}
 		var vp9 codecs.VP9Packet
 		_, err = vp9.Unmarshal(packet.Payload)
 		if err != nil {
-			return
+			return flags, err
 		}
-		start = vp9.B
-		tid = vp9.TID
-		sid = vp9.SID
-		layersync = vp9.U
-		return
+		flags.start = vp9.B
+		flags.tid = vp9.TID
+		flags.sid = vp9.SID
+		flags.tidupsync = vp9.U
+		flags.sidsync = vp9.P
+		// not yet in pion/rtp
+		flags.sidnonreference = (packet.Payload[0] & 0x01) != 0
+		return flags, nil
 	}
-	return
+	return flags, nil
 }
 
 func rewritePacket(codec string, data []byte, seqno uint16, delta uint16) error {
diff --git a/rtpconn/codec_test.go b/rtpconn/codec_test.go
index 16a49a8..358f132 100644
--- a/rtpconn/codec_test.go
+++ b/rtpconn/codec_test.go
@@ -16,12 +16,13 @@ var vp8 = []byte{
 
 func TestPacketFlags(t *testing.T) {
 	buf := append([]byte{}, vp8...)
-	seqno, start, pid, tid, sid, layersync, discardable, err :=
-		packetFlags("video/vp8", buf)
-	if seqno != 42 || !start || pid != 57 || sid != 0 || tid != 0 ||
-		layersync || discardable || err != nil {
+	flags, err := getPacketFlags("video/vp8", buf)
+	if flags.seqno != 42 || !flags.start || flags.pid != 57 ||
+		flags.sid != 0 || flags.tid != 0 ||
+		flags.tidupsync || flags.discardable || err != nil {
 		t.Errorf("Got %v, %v, %v, %v, %v, %v (%v)",
-			seqno, start, pid, sid, layersync, discardable, err,
+			flags.seqno, flags.start, flags.pid, flags.sid,
+			flags.tidupsync, flags.discardable, err,
 		)
 	}
 }
@@ -34,10 +35,12 @@ func TestRewrite(t *testing.T) {
 			t.Errorf("rewrite: %v", err)
 			continue
 		}
-		seqno, _, pid, _, _, _, _, err := packetFlags("video/vp8", buf)
-		if err != nil || seqno != i || pid != (57 + i) & 0x7FFF {
+		flags, err := getPacketFlags("video/vp8", buf)
+		if err != nil || flags.seqno != i ||
+			flags.pid != (57 + i) & 0x7FFF {
 			t.Errorf("Expected %v %v, got %v %v (%v)",
-				i, (57 + i) & 0x7FFF, seqno, pid, err)
+				i, (57 + i) & 0x7FFF,
+				flags.seqno, flags.pid, err)
 		}
 	}
 }
diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go
index 8d2842f..0a18da5 100644
--- a/rtpconn/rtpconn.go
+++ b/rtpconn/rtpconn.go
@@ -126,14 +126,31 @@ func (down *rtpDownTrack) SetCname(cname string) {
 	down.cname.Store(cname)
 }
 
-func (down *rtpDownTrack) getLayerInfo() (uint8, uint8, uint8) {
-	info := atomic.LoadUint32(&down.atomics.layerInfo)
-	return uint8(info >> 16), uint8(info >> 8), uint8(info)
+type layerInfo struct {
+	sid, wantedSid, maxSid uint8
+	tid, wantedTid, maxTid uint8
 }
 
-func (down *rtpDownTrack) setLayerInfo(layer, wanted, max uint8) {
+func (down *rtpDownTrack) getLayerInfo() layerInfo {
+	info := atomic.LoadUint32(&down.atomics.layerInfo)
+	return layerInfo{
+		sid:       uint8((info & 0xF)),
+		wantedSid: uint8((info >> 4) & 0xF),
+		maxSid:    uint8((info >> 8) & 0xF),
+		tid:       uint8((info >> 16) & 0xF),
+		wantedTid: uint8((info >> 20) & 0xF),
+		maxTid:    uint8((info >> 24) & 0xF),
+	}
+}
+
+func (down *rtpDownTrack) setLayerInfo(info layerInfo) {
 	atomic.StoreUint32(&down.atomics.layerInfo,
-		(uint32(layer)<<16)|(uint32(wanted)<<8)|uint32(max),
+		uint32(info.sid&0xF)|
+			uint32(info.wantedSid&0xF)<<4|
+			uint32(info.maxSid&0xF)<<8|
+			uint32(info.tid&0xF)<<16|
+			uint32(info.wantedTid&0xF)<<20|
+			uint32(info.maxTid&0xF)<<24,
 	)
 }
 
@@ -195,45 +212,59 @@ var packetBufPool = sync.Pool{
 func (down *rtpDownTrack) Write(buf []byte) (int, error) {
 	codec := down.remote.Codec().MimeType
 
-	seqno, start, pid, tid, _, u, _, err := packetFlags(codec, buf)
+	flags, err := getPacketFlags(codec, buf)
 	if err != nil {
 		return 0, err
 	}
 
-	layer, wantedLayer, maxLayer := down.getLayerInfo()
+	layer := down.getLayerInfo()
 
-	if tid > maxLayer {
-		if layer == maxLayer {
-			wantedLayer = tid
-			layer = tid
+	if flags.tid > layer.maxTid || flags.sid > layer.maxSid {
+		if flags.tid > layer.maxTid {
+			if layer.tid == layer.maxTid {
+				layer.wantedTid = flags.tid
+				layer.tid = flags.tid
+			}
+			layer.maxTid = flags.tid
 		}
-		maxLayer = tid
-		if wantedLayer > maxLayer {
-			wantedLayer = maxLayer
+		if flags.sid > layer.maxSid {
+			if layer.sid == layer.maxSid {
+				layer.wantedSid = flags.sid
+				layer.sid = flags.sid
+			}
+			layer.maxSid = flags.sid
 		}
-		down.setLayerInfo(layer, wantedLayer, maxLayer)
+		down.setLayerInfo(layer)
 		down.adjustLayer()
 	}
-	if start && layer != wantedLayer {
-		if u || wantedLayer < layer {
-			layer = wantedLayer
-			down.setLayerInfo(layer, wantedLayer, maxLayer)
+	if flags.start && (layer.tid != layer.wantedTid) {
+		if layer.wantedTid < layer.tid || flags.tidupsync {
+			layer.tid = layer.wantedTid
+			down.setLayerInfo(layer)
 		}
 	}
 
-	if tid > layer {
-		ok := down.packetmap.Drop(seqno, pid)
+	if flags.start && (layer.sid != layer.wantedSid) {
+		if flags.sidsync {
+			layer.sid = layer.wantedSid
+			down.setLayerInfo(layer)
+		}
+	}
+
+	if flags.tid > layer.tid || flags.sid > layer.sid ||
+		(flags.sid < layer.sid && flags.sidnonreference) {
+		ok := down.packetmap.Drop(flags.seqno, flags.pid)
 		if ok {
 			return 0, nil
 		}
 	}
 
-	ok, newseqno, piddelta := down.packetmap.Map(seqno, pid)
+	ok, newseqno, piddelta := down.packetmap.Map(flags.seqno, flags.pid)
 	if !ok {
 		return 0, nil
 	}
 
-	if newseqno == seqno && piddelta == 0 {
+	if newseqno == flags.seqno && piddelta == 0 {
 		return down.write(buf)
 	}
 
@@ -257,35 +288,35 @@ func (down *rtpDownTrack) write(buf []byte) (int, error) {
 	return n, err
 }
 
-func (t *rtpDownTrack) GetMaxBitrate() (uint64, int) {
+func (t *rtpDownTrack) GetMaxBitrate() (uint64, int, int) {
 	now := rtptime.Jiffies()
-	layer, _, _ := t.getLayerInfo()
+	layer := t.getLayerInfo()
 	r := t.maxBitrate.Get(now)
 	if r == ^uint64(0) {
 		r = 512 * 1024
 	}
 	rr := t.maxREMBBitrate.Get(now)
-	if rr == 0 || r < rr {
-		return r, int(layer)
+	if rr != 0 && rr < r {
+		r = rr
 	}
-	return rr, int(layer)
+	return r, int(layer.sid), int(layer.tid)
 }
 
 func (t *rtpDownTrack) adjustLayer() {
-	max, _ := t.GetMaxBitrate()
+	max, _, _ := t.GetMaxBitrate()
 	r, _ := t.rate.Estimate()
 	rate := uint64(r) * 8
 	if rate < max*7/8 {
-		layer, wanted, max := t.getLayerInfo()
-		if layer < max {
-			wanted = layer + 1
-			t.setLayerInfo(layer, wanted, max)
+		layer := t.getLayerInfo()
+		if layer.tid < layer.maxTid {
+			layer.wantedTid = layer.tid + 1
+			t.setLayerInfo(layer)
 		}
 	} else if rate > max*3/2 {
-		layer, wanted, max := t.getLayerInfo()
-		if layer > 0 {
-			wanted = layer - 1
-			t.setLayerInfo(layer, wanted, max)
+		layer := t.getLayerInfo()
+		if layer.tid > 0 {
+			layer.wantedTid = layer.tid - 1
+			t.setLayerInfo(layer)
 		}
 	}
 }
@@ -320,11 +351,11 @@ func (down *rtpDownConnection) flushICECandidates() error {
 }
 
 type rtpUpTrack struct {
-	track   *webrtc.TrackRemote
-	rate    *estimator.Estimator
-	cache   *packetcache.Cache
-	jitter  *jitter.Estimator
-	cname   atomic.Value
+	track  *webrtc.TrackRemote
+	rate   *estimator.Estimator
+	cache  *packetcache.Cache
+	jitter *jitter.Estimator
+	cname  atomic.Value
 
 	localCh    chan trackAction
 	readerDone chan struct{}
@@ -881,12 +912,16 @@ func sendUpRTCP(up *rtpUpConnection) error {
 		} else {
 			minrate := ^uint64(0)
 			maxrate := uint64(group.MinBitrate)
-			maxlayer := 0
+			maxsid := 0
+			maxtid := 0
 			local := t.getLocal()
 			for _, down := range local {
-				r, l := down.GetMaxBitrate()
-				if maxlayer < l {
-					maxlayer = l
+				r, sid, tid := down.GetMaxBitrate()
+				if maxsid < sid {
+					maxsid = sid
+				}
+				if maxtid < tid {
+					maxtid = tid
 				}
 				if r < group.MinBitrate {
 					r = group.MinBitrate
@@ -898,10 +933,15 @@ func sendUpRTCP(up *rtpUpConnection) error {
 					maxrate = r
 				}
 			}
+			// assume that lower spatial layers take up 1/5 of
+			// the throughput
+			if maxsid > 0 {
+				maxrate = maxrate * 5 / 4
+			}
 			// assume that each layer takes two times less
 			// throughput than the higher one. Then we've
 			// got enough slack for a factor of 2^(layers-1).
-			for i := 0; i < maxlayer; i++ {
+			for i := 0; i < maxtid; i++ {
 				if minrate < ^uint64(0)/2 {
 					minrate *= 2
 				}
diff --git a/rtpconn/rtpconn_test.go b/rtpconn/rtpconn_test.go
index b156a0c..4cbb2c8 100644
--- a/rtpconn/rtpconn_test.go
+++ b/rtpconn/rtpconn_test.go
@@ -18,19 +18,23 @@ func TestDownTrackAtomics(t *testing.T) {
 	down.setSRTime(4, 5)
 	down.maxBitrate.Set(6, rtptime.Jiffies())
 	down.maxREMBBitrate.Set(7, rtptime.Jiffies())
-	down.setLayerInfo(8, 9, 10)
+	info := layerInfo{8, 9, 10, 11, 12, 13}
+	down.setLayerInfo(info)
 
 	ntp, rtp := down.getTimeOffset()
 	rtt := down.getRTT()
 	sr, srntp := down.getSRTime()
-	br, lbr := down.GetMaxBitrate()
-	l, w, m := down.getLayerInfo()
+	br, sbr, tbr := down.GetMaxBitrate()
+	info2 := down.getLayerInfo()
 	if ntp != 1 || rtp != 2 || rtt != 3 || sr != 4 || srntp != 5 ||
-		br != 6 || lbr != 8 || l != 8 || w != 9 || m != 10 {
+		br != 6 || sbr != 8 || tbr != 11 {
 		t.Errorf(
-			"Expected 1 2 3 4 5 6 8 8 9 10, "+
-				"got %v %v %v %v %v %v %v %v %v %v",
-			ntp, rtp, rtt, sr, srntp, br, lbr, l, w, m,
+			"Expected 1 2 3 4 5 6 8 11, "+
+				"got %v %v %v %v %v %v %v %v",
+			ntp, rtp, rtt, sr, srntp, br, sbr, tbr,
 		)
 	}
+	if info2 != info {
+		t.Errorf("Expected %v, got %v", info, info2)
+	}
 }
diff --git a/rtpconn/rtpstats.go b/rtpconn/rtpstats.go
index 98f17c4..9f816bf 100644
--- a/rtpconn/rtpstats.go
+++ b/rtpconn/rtpstats.go
@@ -49,7 +49,11 @@ func (c *webClient) GetStats() *stats.Client {
 			Id: down.id,
 		}
 		for _, t := range down.tracks {
-			l, _, ml := t.getLayerInfo()
+			layer := t.getLayerInfo()
+			sid := layer.sid
+			maxSid := layer.maxSid
+			tid := layer.tid
+			maxTid := layer.maxTid
 			rate, _ := t.rate.Estimate()
 			rtt := rtptime.ToDuration(t.getRTT(),
 				rtptime.JiffiesPerSec)
@@ -57,8 +61,10 @@ func (c *webClient) GetStats() *stats.Client {
 			j := time.Duration(jitter) * time.Second /
 				time.Duration(t.track.Codec().ClockRate)
 			conns.Tracks = append(conns.Tracks, stats.Track{
-				Layer:      &l,
-				MaxLayer:   &ml,
+				Tid:        &tid,
+				MaxTid:     &maxTid,
+				Sid:        &sid,
+				MaxSid:     &maxSid,
 				Bitrate:    uint64(rate) * 8,
 				MaxBitrate: t.maxBitrate.Get(jiffies),
 				Loss:       float64(loss) / 256.0,
diff --git a/static/stats.js b/static/stats.js
index 3746d96..6ff432c 100644
--- a/static/stats.js
+++ b/static/stats.js
@@ -99,8 +99,15 @@ function formatTrack(table, track) {
     tr.appendChild(document.createElement('td'));
     tr.appendChild(document.createElement('td'));
     let td = document.createElement('td');
-    if(track.layer && track.maxLayer)
-        td.textContent = `${track.layer}/${track.maxLayer}`;
+    let layer = '';
+    if(track.sid || track.maxSid)
+        layer = layer + `s${track.sid}/${track.maxSid}`;
+    if(track.tid || track.maxTid) {
+        if(layer !== '')
+            layer = layer + '+';
+        layer = layer + `t${track.tid}/${track.maxTid}`;
+    }
+    td.textContent = layer;
     tr.appendChild(td);
     let td2 = document.createElement('td');
     if(track.maxBitrate)
diff --git a/stats/stats.go b/stats/stats.go
index 1a0d0b6..60f0637 100644
--- a/stats/stats.go
+++ b/stats/stats.go
@@ -47,8 +47,10 @@ func (d *Duration) UnmarshalJSON(buf []byte) error {
 }
 
 type Track struct {
-	Layer      *uint8  `json:"layer,omitempty"`
-	MaxLayer   *uint8  `json:"maxLayer,omitempty"`
+	Sid        *uint8  `json:"sid,omitempty"`
+	MaxSid     *uint8  `json:"maxSid,omitempty"`
+	Tid        *uint8  `json:"tid,omitempty"`
+	MaxTid     *uint8  `json:"maxTid,omitempty"`
 	Bitrate    uint64  `json:"bitrate"`
 	MaxBitrate uint64  `json:"maxBitrate,omitempty"`
 	Loss       float64 `json:"loss"`