1
Fork 0

Perform congestion control at the connection level.

REMB applies to the whole transport, not to individual tracks.
This commit is contained in:
Juliusz Chroboczek 2020-06-12 17:39:16 +02:00
parent 903e499dd6
commit a4d0741704
6 changed files with 131 additions and 138 deletions

View File

@ -34,11 +34,11 @@ type upTrack interface {
}
type downConnection interface {
GetMaxBitrate(now uint64) uint64
}
type downTrack interface {
WriteRTP(packat *rtp.Packet) error
Accumulate(bytes uint32)
GetMaxBitrate(now uint64) uint64
setTimeOffset(ntp uint64, rtp uint32)
}

View File

@ -382,10 +382,10 @@ func (conn *diskConn) initWriter(width, height uint32) error {
return nil
}
func (down *diskConn) GetMaxBitrate(now uint64) uint64 {
return ^uint64(0)
}
func (t *diskTrack) Accumulate(bytes uint32) {
return
}
func (down *diskTrack) GetMaxBitrate(now uint64) uint64 {
return ^uint64(0)
}

View File

@ -37,8 +37,7 @@ type chatHistoryEntry struct {
}
const (
minVideoRate = 200000
minAudioRate = 9600
minBitrate = 200000
)
type group struct {
@ -506,8 +505,9 @@ type clientStats struct {
}
type connStats struct {
id string
tracks []trackStats
id string
maxBitrate uint64
tracks []trackStats
}
type trackStats struct {
@ -560,7 +560,9 @@ func getClientStats(c *webClient) clientStats {
}
for _, up := range c.up {
conns := connStats{id: up.id}
conns := connStats{
id: up.id,
}
tracks := up.getTracks()
for _, t := range tracks {
expected, lost, _, _ := t.cache.GetStats(false)
@ -572,10 +574,9 @@ func getClientStats(c *webClient) clientStats {
(time.Second / time.Duration(t.jitter.HZ()))
rate, _ := t.rate.Estimate()
conns.tracks = append(conns.tracks, trackStats{
bitrate: uint64(rate) * 8,
maxBitrate: atomic.LoadUint64(&t.maxBitrate),
loss: loss,
jitter: jitter,
bitrate: uint64(rate) * 8,
loss: loss,
jitter: jitter,
})
}
cs.up = append(cs.up, conns)
@ -584,10 +585,13 @@ func getClientStats(c *webClient) clientStats {
return cs.up[i].id < cs.up[j].id
})
jiffies := rtptime.Jiffies()
for _, down := range c.down {
conns := connStats{id: down.id}
conns := connStats{
id: down.id,
maxBitrate: down.GetMaxBitrate(jiffies),
}
for _, t := range down.tracks {
jiffies := rtptime.Jiffies()
rate, _ := t.rate.Estimate()
rtt := rtptime.ToDuration(atomic.LoadUint64(&t.rtt),
rtptime.JiffiesPerSec)
@ -596,7 +600,7 @@ func getClientStats(c *webClient) clientStats {
time.Duration(t.track.Codec().ClockRate)
conns.tracks = append(conns.tracks, trackStats{
bitrate: uint64(rate) * 8,
maxBitrate: t.GetMaxBitrate(jiffies),
maxBitrate: t.maxBitrate.Get(jiffies),
loss: uint8(uint32(loss) * 100 / 256),
rtt: rtt,
jitter: j,

View File

@ -70,17 +70,16 @@ type iceConnection interface {
}
type rtpDownTrack struct {
track *webrtc.Track
remote upTrack
maxLossBitrate *bitrate
maxREMBBitrate *bitrate
rate *estimator.Estimator
stats *receiverStats
srTime uint64
srNTPTime uint64
remoteNTPTime uint64
remoteRTPTime uint32
rtt uint64
track *webrtc.Track
remote upTrack
maxBitrate *bitrate
rate *estimator.Estimator
stats *receiverStats
srTime uint64
srNTPTime uint64
remoteNTPTime uint64
remoteRTPTime uint32
rtt uint64
}
func (down *rtpDownTrack) WriteRTP(packet *rtp.Packet) error {
@ -91,26 +90,18 @@ func (down *rtpDownTrack) Accumulate(bytes uint32) {
down.rate.Accumulate(bytes)
}
func (down *rtpDownTrack) GetMaxBitrate(now uint64) uint64 {
br1 := down.maxLossBitrate.Get(now)
br2 := down.maxREMBBitrate.Get(now)
if br1 < br2 {
return br1
}
return br2
}
func (down *rtpDownTrack) setTimeOffset(ntp uint64, rtp uint32) {
atomic.StoreUint64(&down.remoteNTPTime, ntp)
atomic.StoreUint32(&down.remoteRTPTime, rtp)
}
type rtpDownConnection struct {
id string
pc *webrtc.PeerConnection
remote upConnection
tracks []*rtpDownTrack
iceCandidates []*webrtc.ICECandidateInit
id string
pc *webrtc.PeerConnection
remote upConnection
tracks []*rtpDownTrack
maxREMBBitrate *bitrate
iceCandidates []*webrtc.ICECandidateInit
}
func newDownConn(id string, remote upConnection) (*rtpDownConnection, error) {
@ -124,14 +115,35 @@ func newDownConn(id string, remote upConnection) (*rtpDownConnection, error) {
})
conn := &rtpDownConnection{
id: id,
pc: pc,
remote: remote,
id: id,
pc: pc,
remote: remote,
maxREMBBitrate: new(bitrate),
}
return conn, nil
}
func (down *rtpDownConnection) GetMaxBitrate(now uint64) uint64 {
rate := down.maxREMBBitrate.Get(now)
var trackRate uint64
for _, t := range down.tracks {
r := t.maxBitrate.Get(now)
if r == ^uint64(0) {
if t.track.Kind() == webrtc.RTPCodecTypeAudio {
r = 128 * 1024
} else {
r = 512 * 1024
}
}
trackRate += r
}
if trackRate < rate {
return trackRate
}
return rate
}
func (down *rtpDownConnection) addICECandidate(candidate *webrtc.ICECandidateInit) error {
if down.pc.RemoteDescription() != nil {
return down.pc.AddICECandidate(*candidate)
@ -162,15 +174,14 @@ func (down *rtpDownConnection) flushICECandidates() error {
}
type rtpUpTrack struct {
track *webrtc.Track
label string
rate *estimator.Estimator
cache *packetcache.Cache
jitter *jitter.Estimator
maxBitrate uint64
lastPLI uint64
lastFIR uint64
firSeqno uint32
track *webrtc.Track
label string
rate *estimator.Estimator
cache *packetcache.Cache
jitter *jitter.Estimator
lastPLI uint64
lastFIR uint64
firSeqno uint32
localCh chan localTrackAction
writerDone chan struct{}
@ -422,7 +433,6 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) {
cache: packetcache.New(32),
rate: estimator.New(time.Second),
jitter: jitter.New(remote.Codec().ClockRate),
maxBitrate: ^uint64(0),
localCh: make(chan localTrackAction, 2),
writerDone: make(chan struct{}),
}
@ -690,15 +700,6 @@ func sendFIR(pc *webrtc.PeerConnection, ssrc uint32, seqno uint8) error {
})
}
func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error {
return pc.WriteRTCP([]rtcp.Packet{
&rtcp.ReceiverEstimatedMaximumBitrate{
Bitrate: bitrate,
SSRCs: []uint32{ssrc},
},
})
}
func (up *rtpUpConnection) sendNACK(track *rtpUpTrack, first uint16, bitmap uint16) error {
if !track.hasRtcpFb("nack", "") {
return nil
@ -797,7 +798,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei
}
}
func sendRR(conn *rtpUpConnection) error {
func sendUpRTCP(conn *rtpUpConnection) error {
conn.mu.Lock()
defer conn.mu.Unlock()
@ -813,6 +814,7 @@ func sendRR(conn *rtpUpConnection) error {
reports := make([]rtcp.ReceptionReport, 0, len(conn.tracks))
for _, t := range conn.tracks {
updateUpTrack(t)
expected, lost, totalLost, eseqno := t.cache.GetStats(true)
if expected == 0 {
expected = 1
@ -843,17 +845,46 @@ func sendRR(conn *rtpUpConnection) error {
})
}
return conn.pc.WriteRTCP([]rtcp.Packet{
packets := []rtcp.Packet{
&rtcp.ReceiverReport{
Reports: reports,
},
})
}
rate := ^uint64(0)
for _, l := range conn.local {
r := l.GetMaxBitrate(now)
if r < rate {
rate = r
}
}
if rate < minBitrate {
rate = minBitrate
}
var ssrcs []uint32
for _, t := range conn.tracks {
if t.hasRtcpFb("goog-remb", "") {
continue
}
ssrcs = append(ssrcs, t.track.SSRC())
}
if len(ssrcs) > 0 {
packets = append(packets,
&rtcp.ReceiverEstimatedMaximumBitrate{
Bitrate: rate,
SSRCs: ssrcs,
},
)
}
return conn.pc.WriteRTCP(packets)
}
func rtcpUpSender(conn *rtpUpConnection) {
for {
time.Sleep(time.Second)
err := sendRR(conn)
err := sendUpRTCP(conn)
if err != nil {
if err == io.EOF || err == io.ErrClosedPipe {
return
@ -936,7 +967,7 @@ const (
)
func (track *rtpDownTrack) updateRate(loss uint8, now uint64) {
rate := track.maxLossBitrate.Get(now)
rate := track.maxBitrate.Get(now)
if rate < minLossRate || rate > maxLossRate {
// no recent feedback, reset
rate = initLossRate
@ -962,7 +993,7 @@ func (track *rtpDownTrack) updateRate(loss uint8, now uint64) {
}
// update unconditionally, to set the timestamp
track.maxLossBitrate.Set(rate, now)
track.maxBitrate.Set(rate, now)
}
func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RTPSender) {
@ -1034,7 +1065,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT
log.Printf("sendFIR: %v", err)
}
case *rtcp.ReceiverEstimatedMaximumBitrate:
track.maxREMBBitrate.Set(p.Bitrate, jiffies)
conn.maxREMBBitrate.Set(p.Bitrate, jiffies)
case *rtcp.ReceiverReport:
for _, r := range p.Reports {
if r.SSRC == track.track.SSRC() {
@ -1048,11 +1079,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT
}
}
case *rtcp.TransportLayerNack:
maxBitrate := track.GetMaxBitrate(jiffies)
bitrate, _ := track.rate.Estimate()
if uint64(bitrate)*7/8 < maxBitrate {
sendRecovery(p, track)
}
sendRecovery(p, track)
}
}
}
@ -1086,37 +1113,18 @@ func handleReport(track *rtpDownTrack, report rtcp.ReceptionReport, jiffies uint
}
}
func updateUpTrack(track *rtpUpTrack) uint64 {
func updateUpTrack(track *rtpUpTrack) {
now := rtptime.Jiffies()
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
clockrate := track.track.Codec().ClockRate
minrate := uint64(minAudioRate)
rate := ^uint64(0)
if isvideo {
minrate = minVideoRate
rate = ^uint64(0)
}
local := track.getLocal()
var maxrto uint64
for _, l := range local {
bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) {
continue
}
if bitrate <= minrate {
rate = minrate
break
}
if rate > bitrate {
rate = bitrate
}
ll, ok := l.(*rtpDownTrack)
if ok {
_, j := ll.stats.Get(now)
jitter := uint64(j) *
(rtptime.JiffiesPerSec /
uint64(clockrate))
(rtptime.JiffiesPerSec / uint64(clockrate))
rtt := atomic.LoadUint64(&ll.rtt)
rto := rtt + 4*jitter
if rto > maxrto {
@ -1124,7 +1132,6 @@ func updateUpTrack(track *rtpUpTrack) uint64 {
}
}
}
track.maxBitrate = rate
_, r := track.rate.Estimate()
packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec)
if packets < 32 {
@ -1134,6 +1141,4 @@ func updateUpTrack(track *rtpUpTrack) uint64 {
packets = 256
}
track.cache.ResizeCond(packets)
return rate
}

View File

@ -380,12 +380,11 @@ func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack upTrack, re
}
track := &rtpDownTrack{
track: local,
remote: remoteTrack,
maxLossBitrate: new(bitrate),
maxREMBBitrate: new(bitrate),
stats: new(receiverStats),
rate: estimator.New(time.Second),
track: local,
remote: remoteTrack,
maxBitrate: new(bitrate),
stats: new(receiverStats),
rate: estimator.New(time.Second),
}
conn.tracks = append(conn.tracks, track)
@ -692,10 +691,8 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
readTime := time.Now()
ticker := time.NewTicker(time.Second)
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
slowTicker := time.NewTicker(10 * time.Second)
defer slowTicker.Stop()
for {
select {
@ -766,7 +763,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
go a.c.pushConn(u.id, u, ts, u.label)
}
case connectionFailedAction:
down := getDownConn(c, a.id);
down := getDownConn(c, a.id)
if down == nil {
log.Printf("Failed indication for " +
"unknown connection")
@ -804,8 +801,6 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
return errors.New("unexpected action")
}
case <-ticker.C:
sendRateUpdate(c)
case <-slowTicker.C:
if time.Since(readTime) > 90*time.Second {
return errors.New("client is dead")
}
@ -1022,27 +1017,6 @@ func handleClientMessage(c *webClient, m clientMessage) error {
return nil
}
func sendRateUpdate(c *webClient) {
up := getUpConns(c)
for _, u := range up {
tracks := u.getTracks()
for _, t := range tracks {
rate := updateUpTrack(t)
if !t.hasRtcpFb("goog-remb", "") {
continue
}
if rate == ^uint64(0) {
continue
}
err := sendREMB(u.pc, t.track.SSRC(), rate)
if err != nil {
log.Printf("sendREMB: %v", err)
}
}
}
}
func clientReader(conn *websocket.Conn, read chan<- interface{}, done <-chan struct{}) {
defer close(read)
for {

View File

@ -171,16 +171,26 @@ func statsHandler(w http.ResponseWriter, r *http.Request) {
for _, cs := range gs.clients {
fmt.Fprintf(w, "<tr><td>%v</td></tr>\n", cs.id)
for _, up := range cs.up {
fmt.Fprintf(w, "<tr><td></td><td>Up</td><td>%v</td></tr>\n",
fmt.Fprintf(w, "<tr><td></td><td>Up</td><td>%v</td>",
up.id)
if up.maxBitrate > 0 {
fmt.Fprintf(w, "<td>%v</td>",
up.maxBitrate)
}
fmt.Fprintf(w, "</tr>\n")
for _, t := range up.tracks {
printTrack(w, t)
}
}
for _, up := range cs.down {
fmt.Fprintf(w, "<tr><td></td><td>Down</td><td> %v</td></tr>\n",
up.id)
for _, t := range up.tracks {
for _, down := range cs.down {
fmt.Fprintf(w, "<tr><td></td><td>Down</td><td> %v</td>",
down.id)
if down.maxBitrate > 0 {
fmt.Fprintf(w, "<td>%v</td>",
down.maxBitrate)
}
fmt.Fprintf(w, "</tr>\n")
for _, t := range down.tracks {
printTrack(w, t)
}
}