From a4d074170414d9c5d257616fc7cd8aaa03d19d0d Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Fri, 12 Jun 2020 17:39:16 +0200 Subject: [PATCH] Perform congestion control at the connection level. REMB applies to the whole transport, not to individual tracks. --- conn.go | 2 +- disk.go | 8 +-- group.go | 28 +++++---- rtpconn.go | 171 ++++++++++++++++++++++++++------------------------- webclient.go | 40 +++--------- webserver.go | 20 ++++-- 6 files changed, 131 insertions(+), 138 deletions(-) diff --git a/conn.go b/conn.go index 4155529..f099e12 100644 --- a/conn.go +++ b/conn.go @@ -34,11 +34,11 @@ type upTrack interface { } type downConnection interface { + GetMaxBitrate(now uint64) uint64 } type downTrack interface { WriteRTP(packat *rtp.Packet) error Accumulate(bytes uint32) - GetMaxBitrate(now uint64) uint64 setTimeOffset(ntp uint64, rtp uint32) } diff --git a/disk.go b/disk.go index f1c1b00..818a241 100644 --- a/disk.go +++ b/disk.go @@ -382,10 +382,10 @@ func (conn *diskConn) initWriter(width, height uint32) error { return nil } +func (down *diskConn) GetMaxBitrate(now uint64) uint64 { + return ^uint64(0) +} + func (t *diskTrack) Accumulate(bytes uint32) { return } - -func (down *diskTrack) GetMaxBitrate(now uint64) uint64 { - return ^uint64(0) -} diff --git a/group.go b/group.go index 2ab4f94..c3759e7 100644 --- a/group.go +++ b/group.go @@ -37,8 +37,7 @@ type chatHistoryEntry struct { } const ( - minVideoRate = 200000 - minAudioRate = 9600 + minBitrate = 200000 ) type group struct { @@ -506,8 +505,9 @@ type clientStats struct { } type connStats struct { - id string - tracks []trackStats + id string + maxBitrate uint64 + tracks []trackStats } type trackStats struct { @@ -560,7 +560,9 @@ func getClientStats(c *webClient) clientStats { } for _, up := range c.up { - conns := connStats{id: up.id} + conns := connStats{ + id: up.id, + } tracks := up.getTracks() for _, t := range tracks { expected, lost, _, _ := t.cache.GetStats(false) @@ -572,10 +574,9 @@ func getClientStats(c *webClient) clientStats { (time.Second / time.Duration(t.jitter.HZ())) rate, _ := t.rate.Estimate() conns.tracks = append(conns.tracks, trackStats{ - bitrate: uint64(rate) * 8, - maxBitrate: atomic.LoadUint64(&t.maxBitrate), - loss: loss, - jitter: jitter, + bitrate: uint64(rate) * 8, + loss: loss, + jitter: jitter, }) } cs.up = append(cs.up, conns) @@ -584,10 +585,13 @@ func getClientStats(c *webClient) clientStats { return cs.up[i].id < cs.up[j].id }) + jiffies := rtptime.Jiffies() for _, down := range c.down { - conns := connStats{id: down.id} + conns := connStats{ + id: down.id, + maxBitrate: down.GetMaxBitrate(jiffies), + } for _, t := range down.tracks { - jiffies := rtptime.Jiffies() rate, _ := t.rate.Estimate() rtt := rtptime.ToDuration(atomic.LoadUint64(&t.rtt), rtptime.JiffiesPerSec) @@ -596,7 +600,7 @@ func getClientStats(c *webClient) clientStats { time.Duration(t.track.Codec().ClockRate) conns.tracks = append(conns.tracks, trackStats{ bitrate: uint64(rate) * 8, - maxBitrate: t.GetMaxBitrate(jiffies), + maxBitrate: t.maxBitrate.Get(jiffies), loss: uint8(uint32(loss) * 100 / 256), rtt: rtt, jitter: j, diff --git a/rtpconn.go b/rtpconn.go index 2f5acda..bb50599 100644 --- a/rtpconn.go +++ b/rtpconn.go @@ -70,17 +70,16 @@ type iceConnection interface { } type rtpDownTrack struct { - track *webrtc.Track - remote upTrack - maxLossBitrate *bitrate - maxREMBBitrate *bitrate - rate *estimator.Estimator - stats *receiverStats - srTime uint64 - srNTPTime uint64 - remoteNTPTime uint64 - remoteRTPTime uint32 - rtt uint64 + track *webrtc.Track + remote upTrack + maxBitrate *bitrate + rate *estimator.Estimator + stats *receiverStats + srTime uint64 + srNTPTime uint64 + remoteNTPTime uint64 + remoteRTPTime uint32 + rtt uint64 } func (down *rtpDownTrack) WriteRTP(packet *rtp.Packet) error { @@ -91,26 +90,18 @@ func (down *rtpDownTrack) Accumulate(bytes uint32) { down.rate.Accumulate(bytes) } -func (down *rtpDownTrack) GetMaxBitrate(now uint64) uint64 { - br1 := down.maxLossBitrate.Get(now) - br2 := down.maxREMBBitrate.Get(now) - if br1 < br2 { - return br1 - } - return br2 -} - func (down *rtpDownTrack) setTimeOffset(ntp uint64, rtp uint32) { atomic.StoreUint64(&down.remoteNTPTime, ntp) atomic.StoreUint32(&down.remoteRTPTime, rtp) } type rtpDownConnection struct { - id string - pc *webrtc.PeerConnection - remote upConnection - tracks []*rtpDownTrack - iceCandidates []*webrtc.ICECandidateInit + id string + pc *webrtc.PeerConnection + remote upConnection + tracks []*rtpDownTrack + maxREMBBitrate *bitrate + iceCandidates []*webrtc.ICECandidateInit } func newDownConn(id string, remote upConnection) (*rtpDownConnection, error) { @@ -124,14 +115,35 @@ func newDownConn(id string, remote upConnection) (*rtpDownConnection, error) { }) conn := &rtpDownConnection{ - id: id, - pc: pc, - remote: remote, + id: id, + pc: pc, + remote: remote, + maxREMBBitrate: new(bitrate), } return conn, nil } +func (down *rtpDownConnection) GetMaxBitrate(now uint64) uint64 { + rate := down.maxREMBBitrate.Get(now) + var trackRate uint64 + for _, t := range down.tracks { + r := t.maxBitrate.Get(now) + if r == ^uint64(0) { + if t.track.Kind() == webrtc.RTPCodecTypeAudio { + r = 128 * 1024 + } else { + r = 512 * 1024 + } + } + trackRate += r + } + if trackRate < rate { + return trackRate + } + return rate +} + func (down *rtpDownConnection) addICECandidate(candidate *webrtc.ICECandidateInit) error { if down.pc.RemoteDescription() != nil { return down.pc.AddICECandidate(*candidate) @@ -162,15 +174,14 @@ func (down *rtpDownConnection) flushICECandidates() error { } type rtpUpTrack struct { - track *webrtc.Track - label string - rate *estimator.Estimator - cache *packetcache.Cache - jitter *jitter.Estimator - maxBitrate uint64 - lastPLI uint64 - lastFIR uint64 - firSeqno uint32 + track *webrtc.Track + label string + rate *estimator.Estimator + cache *packetcache.Cache + jitter *jitter.Estimator + lastPLI uint64 + lastFIR uint64 + firSeqno uint32 localCh chan localTrackAction writerDone chan struct{} @@ -422,7 +433,6 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) { cache: packetcache.New(32), rate: estimator.New(time.Second), jitter: jitter.New(remote.Codec().ClockRate), - maxBitrate: ^uint64(0), localCh: make(chan localTrackAction, 2), writerDone: make(chan struct{}), } @@ -690,15 +700,6 @@ func sendFIR(pc *webrtc.PeerConnection, ssrc uint32, seqno uint8) error { }) } -func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error { - return pc.WriteRTCP([]rtcp.Packet{ - &rtcp.ReceiverEstimatedMaximumBitrate{ - Bitrate: bitrate, - SSRCs: []uint32{ssrc}, - }, - }) -} - func (up *rtpUpConnection) sendNACK(track *rtpUpTrack, first uint16, bitmap uint16) error { if !track.hasRtcpFb("nack", "") { return nil @@ -797,7 +798,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei } } -func sendRR(conn *rtpUpConnection) error { +func sendUpRTCP(conn *rtpUpConnection) error { conn.mu.Lock() defer conn.mu.Unlock() @@ -813,6 +814,7 @@ func sendRR(conn *rtpUpConnection) error { reports := make([]rtcp.ReceptionReport, 0, len(conn.tracks)) for _, t := range conn.tracks { + updateUpTrack(t) expected, lost, totalLost, eseqno := t.cache.GetStats(true) if expected == 0 { expected = 1 @@ -843,17 +845,46 @@ func sendRR(conn *rtpUpConnection) error { }) } - return conn.pc.WriteRTCP([]rtcp.Packet{ + packets := []rtcp.Packet{ &rtcp.ReceiverReport{ Reports: reports, }, - }) + } + + rate := ^uint64(0) + for _, l := range conn.local { + r := l.GetMaxBitrate(now) + if r < rate { + rate = r + } + } + if rate < minBitrate { + rate = minBitrate + } + + var ssrcs []uint32 + for _, t := range conn.tracks { + if t.hasRtcpFb("goog-remb", "") { + continue + } + ssrcs = append(ssrcs, t.track.SSRC()) + } + + if len(ssrcs) > 0 { + packets = append(packets, + &rtcp.ReceiverEstimatedMaximumBitrate{ + Bitrate: rate, + SSRCs: ssrcs, + }, + ) + } + return conn.pc.WriteRTCP(packets) } func rtcpUpSender(conn *rtpUpConnection) { for { time.Sleep(time.Second) - err := sendRR(conn) + err := sendUpRTCP(conn) if err != nil { if err == io.EOF || err == io.ErrClosedPipe { return @@ -936,7 +967,7 @@ const ( ) func (track *rtpDownTrack) updateRate(loss uint8, now uint64) { - rate := track.maxLossBitrate.Get(now) + rate := track.maxBitrate.Get(now) if rate < minLossRate || rate > maxLossRate { // no recent feedback, reset rate = initLossRate @@ -962,7 +993,7 @@ func (track *rtpDownTrack) updateRate(loss uint8, now uint64) { } // update unconditionally, to set the timestamp - track.maxLossBitrate.Set(rate, now) + track.maxBitrate.Set(rate, now) } func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RTPSender) { @@ -1034,7 +1065,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT log.Printf("sendFIR: %v", err) } case *rtcp.ReceiverEstimatedMaximumBitrate: - track.maxREMBBitrate.Set(p.Bitrate, jiffies) + conn.maxREMBBitrate.Set(p.Bitrate, jiffies) case *rtcp.ReceiverReport: for _, r := range p.Reports { if r.SSRC == track.track.SSRC() { @@ -1048,11 +1079,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT } } case *rtcp.TransportLayerNack: - maxBitrate := track.GetMaxBitrate(jiffies) - bitrate, _ := track.rate.Estimate() - if uint64(bitrate)*7/8 < maxBitrate { - sendRecovery(p, track) - } + sendRecovery(p, track) } } } @@ -1086,37 +1113,18 @@ func handleReport(track *rtpDownTrack, report rtcp.ReceptionReport, jiffies uint } } -func updateUpTrack(track *rtpUpTrack) uint64 { +func updateUpTrack(track *rtpUpTrack) { now := rtptime.Jiffies() - isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo clockrate := track.track.Codec().ClockRate - minrate := uint64(minAudioRate) - rate := ^uint64(0) - if isvideo { - minrate = minVideoRate - rate = ^uint64(0) - } local := track.getLocal() var maxrto uint64 for _, l := range local { - bitrate := l.GetMaxBitrate(now) - if bitrate == ^uint64(0) { - continue - } - if bitrate <= minrate { - rate = minrate - break - } - if rate > bitrate { - rate = bitrate - } ll, ok := l.(*rtpDownTrack) if ok { _, j := ll.stats.Get(now) jitter := uint64(j) * - (rtptime.JiffiesPerSec / - uint64(clockrate)) + (rtptime.JiffiesPerSec / uint64(clockrate)) rtt := atomic.LoadUint64(&ll.rtt) rto := rtt + 4*jitter if rto > maxrto { @@ -1124,7 +1132,6 @@ func updateUpTrack(track *rtpUpTrack) uint64 { } } } - track.maxBitrate = rate _, r := track.rate.Estimate() packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec) if packets < 32 { @@ -1134,6 +1141,4 @@ func updateUpTrack(track *rtpUpTrack) uint64 { packets = 256 } track.cache.ResizeCond(packets) - - return rate } diff --git a/webclient.go b/webclient.go index 2300758..a2e1943 100644 --- a/webclient.go +++ b/webclient.go @@ -380,12 +380,11 @@ func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack upTrack, re } track := &rtpDownTrack{ - track: local, - remote: remoteTrack, - maxLossBitrate: new(bitrate), - maxREMBBitrate: new(bitrate), - stats: new(receiverStats), - rate: estimator.New(time.Second), + track: local, + remote: remoteTrack, + maxBitrate: new(bitrate), + stats: new(receiverStats), + rate: estimator.New(time.Second), } conn.tracks = append(conn.tracks, track) @@ -692,10 +691,8 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { readTime := time.Now() - ticker := time.NewTicker(time.Second) + ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() - slowTicker := time.NewTicker(10 * time.Second) - defer slowTicker.Stop() for { select { @@ -766,7 +763,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { go a.c.pushConn(u.id, u, ts, u.label) } case connectionFailedAction: - down := getDownConn(c, a.id); + down := getDownConn(c, a.id) if down == nil { log.Printf("Failed indication for " + "unknown connection") @@ -804,8 +801,6 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { return errors.New("unexpected action") } case <-ticker.C: - sendRateUpdate(c) - case <-slowTicker.C: if time.Since(readTime) > 90*time.Second { return errors.New("client is dead") } @@ -1022,27 +1017,6 @@ func handleClientMessage(c *webClient, m clientMessage) error { return nil } -func sendRateUpdate(c *webClient) { - up := getUpConns(c) - - for _, u := range up { - tracks := u.getTracks() - for _, t := range tracks { - rate := updateUpTrack(t) - if !t.hasRtcpFb("goog-remb", "") { - continue - } - if rate == ^uint64(0) { - continue - } - err := sendREMB(u.pc, t.track.SSRC(), rate) - if err != nil { - log.Printf("sendREMB: %v", err) - } - } - } -} - func clientReader(conn *websocket.Conn, read chan<- interface{}, done <-chan struct{}) { defer close(read) for { diff --git a/webserver.go b/webserver.go index 55a87a9..f59ee44 100644 --- a/webserver.go +++ b/webserver.go @@ -171,16 +171,26 @@ func statsHandler(w http.ResponseWriter, r *http.Request) { for _, cs := range gs.clients { fmt.Fprintf(w, "%v\n", cs.id) for _, up := range cs.up { - fmt.Fprintf(w, "Up%v\n", + fmt.Fprintf(w, "Up%v", up.id) + if up.maxBitrate > 0 { + fmt.Fprintf(w, "%v", + up.maxBitrate) + } + fmt.Fprintf(w, "\n") for _, t := range up.tracks { printTrack(w, t) } } - for _, up := range cs.down { - fmt.Fprintf(w, "Down %v\n", - up.id) - for _, t := range up.tracks { + for _, down := range cs.down { + fmt.Fprintf(w, "Down %v", + down.id) + if down.maxBitrate > 0 { + fmt.Fprintf(w, "%v", + down.maxBitrate) + } + fmt.Fprintf(w, "\n") + for _, t := range down.tracks { printTrack(w, t) } }