diff --git a/client.go b/client.go index bf5c8bd..e3dae11 100644 --- a/client.go +++ b/client.go @@ -589,10 +589,12 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn } track := &downTrack{ - track: local, - remote: remoteTrack, - maxBitrate: new(timeStampedBitrate), - rate: estimator.New(time.Second), + track: local, + remote: remoteTrack, + maxLossBitrate: new(bitrate), + maxREMBBitrate: new(bitrate), + stats: new(receiverStats), + rate: estimator.New(time.Second), } conn.tracks = append(conn.tracks, track) remoteTrack.addLocal(track) @@ -602,6 +604,41 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn return conn, s, nil } +const ( + minLossRate = 9600 + initLossRate = 512 * 1000 + maxLossRate = 1 << 30 +) + +func (track *downTrack) updateRate(loss uint8, now uint64) { + rate := track.maxLossBitrate.Get(now) + if rate > maxLossRate { + // no recent feedback, reset + rate = initLossRate + } + if loss < 5 { + // if our actual rate is low, then we're not probing the + // bottleneck + actual := 8 * uint64(track.rate.Estimate()) + if actual >= (rate*7)/8 { + // loss < 0.02, multiply by 1.05 + rate = rate * 269 / 256 + if rate > maxLossRate { + rate = maxLossRate + } + } + } else if loss > 25 { + // loss > 0.1, multiply by (1 - loss/2) + rate = rate * (512 - uint64(loss)) / 512 + if rate < minLossRate { + rate = minLossRate + } + } + + // update unconditionally, to set the timestamp + track.maxLossBitrate.Set(rate, now) +} + func rtcpDownListener(g *group, conn *downConnection, track *downTrack, s *webrtc.RTPSender) { for { ps, err := s.ReadRTCP() @@ -620,23 +657,26 @@ func rtcpDownListener(g *group, conn *downConnection, track *downTrack, s *webrt log.Printf("sendPLI: %v", err) } case *rtcp.ReceiverEstimatedMaximumBitrate: - track.maxBitrate.Set(p.Bitrate, - mono.Microseconds(), + track.maxREMBBitrate.Set( + p.Bitrate, mono.Microseconds(), ) case *rtcp.ReceiverReport: for _, r := range p.Reports { if r.SSRC == track.track.SSRC() { - atomic.StoreUint32( - &track.loss, - uint32(r.FractionLost), + now := mono.Microseconds() + track.stats.Set( + r.FractionLost, + r.Jitter, + now, + ) + track.updateRate( + r.FractionLost, + now, ) - atomic.StoreUint32( - &track.jitter, - r.Jitter) } } case *rtcp.TransportLayerNack: - maxBitrate := track.maxBitrate.Get( + maxBitrate := track.GetMaxBitrate( mono.Microseconds(), ) bitrate := track.rate.Estimate() @@ -675,7 +715,7 @@ func updateUpBitrate(up *upConnection) { track.maxBitrate = ^uint64(0) local := track.getLocal() for _, l := range local { - bitrate := l.maxBitrate.Get(now) + bitrate := l.GetMaxBitrate(now) if bitrate == ^uint64(0) { continue } diff --git a/group.go b/group.go index 5a195ef..4095aa2 100644 --- a/group.go +++ b/group.go @@ -72,34 +72,63 @@ type upConnection struct { tracks []*upTrack } -type timeStampedBitrate struct { +type bitrate struct { bitrate uint64 microseconds uint64 } -func (tb *timeStampedBitrate) Set(bitrate, us uint64) { +func (br *bitrate) Set(bitrate uint64, now uint64) { // this is racy -- a reader might read the // data between the two writes. This shouldn't // matter, we'll recover at the next sample. - atomic.StoreUint64(&tb.bitrate, bitrate) - atomic.StoreUint64(&tb.microseconds, us) + atomic.StoreUint64(&br.bitrate, bitrate) + atomic.StoreUint64(&br.microseconds, now) } -func (tb *timeStampedBitrate) Get(now uint64) uint64 { - ts := atomic.LoadUint64(&tb.microseconds) +func (br *bitrate) Get(now uint64) uint64 { + ts := atomic.LoadUint64(&br.microseconds) if now < ts || now > ts+4000000 { return ^uint64(0) } - return atomic.LoadUint64(&tb.bitrate) + return atomic.LoadUint64(&br.bitrate) +} + +type receiverStats struct { + loss uint32 + jitter uint32 + microseconds uint64 +} + +func (s *receiverStats) Set(loss uint8, jitter uint32, now uint64) { + atomic.StoreUint32(&s.loss, uint32(loss)) + atomic.StoreUint32(&s.jitter, jitter) + atomic.StoreUint64(&s.microseconds, now) +} + +func (s *receiverStats) Get(now uint64) (uint8, uint32) { + ts := atomic.LoadUint64(&s.microseconds) + if now < ts || now > ts+4000000 { + return 0, 0 + } + return uint8(atomic.LoadUint32(&s.loss)), atomic.LoadUint32(&s.jitter) } type downTrack struct { - track *webrtc.Track - remote *upTrack - maxBitrate *timeStampedBitrate - rate *estimator.Estimator - loss uint32 - jitter uint32 + track *webrtc.Track + remote *upTrack + maxLossBitrate *bitrate + maxREMBBitrate *bitrate + rate *estimator.Estimator + stats *receiverStats +} + +func (down *downTrack) GetMaxBitrate(now uint64) uint64 { + br1 := down.maxLossBitrate.Get(now) + br2 := down.maxREMBBitrate.Get(now) + if br1 < br2 { + return br1 + } + return br2 } type downConnection struct { @@ -725,15 +754,14 @@ func getClientStats(c *client) clientStats { for _, down := range c.down { conns := connStats{id: down.id} for _, t := range down.tracks { - loss := atomic.LoadUint32(&t.loss) - jitter := time.Duration(atomic.LoadUint32(&t.jitter)) * - time.Second / + loss, jitter := t.stats.Get(mono.Microseconds()) + j := time.Duration(jitter) * time.Second / time.Duration(t.track.Codec().ClockRate) conns.tracks = append(conns.tracks, trackStats{ bitrate: uint64(t.rate.Estimate()) * 8, - maxBitrate: t.maxBitrate.Get(mono.Microseconds()), - loss: uint8((loss * 100) / 256), - jitter: jitter, + maxBitrate: t.GetMaxBitrate(mono.Microseconds()), + loss: uint8(uint32(loss) * 100 / 256), + jitter: j, }) } cs.down = append(cs.down, conns)