diff --git a/estimator/estimator.go b/estimator/estimator.go index 7b00f39..ef7a14d 100644 --- a/estimator/estimator.go +++ b/estimator/estimator.go @@ -8,12 +8,14 @@ import ( type Estimator struct { interval time.Duration - count uint32 + bytes uint32 + packets uint32 mu sync.Mutex totalBytes uint32 totalPackets uint32 rate uint32 + packetRate uint32 time time.Time } @@ -26,30 +28,38 @@ func New(interval time.Duration) *Estimator { func (e *Estimator) swap(now time.Time) { interval := now.Sub(e.time) - count := atomic.SwapUint32(&e.count, 0) + bytes := atomic.SwapUint32(&e.bytes, 0) + packets := atomic.SwapUint32(&e.packets, 0) + atomic.AddUint32(&e.totalBytes, bytes) + atomic.AddUint32(&e.totalPackets, packets) + if interval < time.Millisecond { e.rate = 0 + e.packetRate = 0 } else { - e.rate = uint32(uint64(count*1000) / uint64(interval/time.Millisecond)) + e.rate = uint32(uint64(bytes*1000) / + uint64(interval/time.Millisecond)) + e.packetRate = uint32(uint64(packets*1000) / + uint64(interval/time.Millisecond)) + } e.time = now } func (e *Estimator) Accumulate(count uint32) { - atomic.AddUint32(&e.totalBytes, count) - atomic.AddUint32(&e.totalPackets, 1) - atomic.AddUint32(&e.count, count) + atomic.AddUint32(&e.bytes, count) + atomic.AddUint32(&e.packets, 1) } -func (e *Estimator) estimate(now time.Time) uint32 { +func (e *Estimator) estimate(now time.Time) (uint32, uint32) { if now.Sub(e.time) > e.interval { e.swap(now) } - return e.rate + return e.rate, e.packetRate } -func (e *Estimator) Estimate() uint32 { +func (e *Estimator) Estimate() (uint32, uint32) { now := time.Now() e.mu.Lock() @@ -58,7 +68,7 @@ func (e *Estimator) Estimate() uint32 { } func (e *Estimator) Totals() (uint32, uint32) { - b := atomic.LoadUint32(&e.totalBytes) - p := atomic.LoadUint32(&e.totalPackets) + b := atomic.LoadUint32(&e.totalBytes) + atomic.LoadUint32(&e.bytes) + p := atomic.LoadUint32(&e.totalPackets) + atomic.LoadUint32(&e.packets) return p, b } diff --git a/estimator/estimator_test.go b/estimator/estimator_test.go index ed5926e..21ae682 100644 --- a/estimator/estimator_test.go +++ b/estimator/estimator_test.go @@ -13,11 +13,14 @@ func TestEstimator(t *testing.T) { e.Accumulate(42) e.Accumulate(128) e.estimate(now.Add(time.Second)) - rate := e.estimate(now.Add(time.Second + time.Millisecond)) + rate, packetRate := e.estimate(now.Add(time.Second + time.Millisecond)) if rate != 42+128 { t.Errorf("Expected %v, got %v", 42+128, rate) } + if packetRate != 2 { + t.Errorf("Expected 2, got %v", packetRate) + } totalP, totalB := e.Totals() if totalP != 2 { @@ -26,4 +29,15 @@ func TestEstimator(t *testing.T) { if totalB != 42+128 { t.Errorf("Expected %v, got %v", 42+128, totalB) } + + e.Accumulate(12) + + totalP, totalB = e.Totals() + if totalP != 3 { + t.Errorf("Expected 2, got %v", totalP) + } + if totalB != 42+128+12 { + t.Errorf("Expected %v, got %v", 42+128, totalB) + } + } diff --git a/group.go b/group.go index c4c8c7c..806dcd9 100644 --- a/group.go +++ b/group.go @@ -573,8 +573,9 @@ func getClientStats(c *webClient) clientStats { loss := uint8(lost * 100 / expected) jitter := time.Duration(t.jitter.Jitter()) * (time.Second / time.Duration(t.jitter.HZ())) + rate, _ := t.rate.Estimate() conns.tracks = append(conns.tracks, trackStats{ - bitrate: uint64(t.rate.Estimate()) * 8, + bitrate: uint64(rate) * 8, maxBitrate: atomic.LoadUint64(&t.maxBitrate), loss: loss, jitter: jitter, @@ -590,13 +591,14 @@ func getClientStats(c *webClient) clientStats { conns := connStats{id: down.id} for _, t := range down.tracks { jiffies := rtptime.Jiffies() + rate, _ := t.rate.Estimate() rtt := rtptime.ToDuration(atomic.LoadUint64(&t.rtt), rtptime.JiffiesPerSec) loss, jitter := t.stats.Get(jiffies) j := time.Duration(jitter) * time.Second / time.Duration(t.track.Codec().ClockRate) conns.tracks = append(conns.tracks, trackStats{ - bitrate: uint64(t.rate.Estimate()) * 8, + bitrate: uint64(rate) * 8, maxBitrate: t.GetMaxBitrate(jiffies), loss: uint8(uint32(loss) * 100 / 256), rtt: rtt, diff --git a/webclient.go b/webclient.go index 22fdaff..f6f27a7 100644 --- a/webclient.go +++ b/webclient.go @@ -847,7 +847,8 @@ func (track *rtpDownTrack) updateRate(loss uint8, now uint64) { if loss < 5 { // if our actual rate is low, then we're not probing the // bottleneck - actual := 8 * uint64(track.rate.Estimate()) + r, _ := track.rate.Estimate() + actual := 8 * uint64(r) if actual >= (rate*7)/8 { // loss < 0.02, multiply by 1.05 rate = rate * 269 / 256 @@ -937,7 +938,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT } case *rtcp.TransportLayerNack: maxBitrate := track.GetMaxBitrate(jiffies) - bitrate := track.rate.Estimate() + bitrate, _ := track.rate.Estimate() if uint64(bitrate)*7/8 < maxBitrate { sendRecovery(p, track) }