1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-12-21 23:05:48 +01:00

Fix race condition in estimator.

This commit is contained in:
Juliusz Chroboczek 2022-04-20 21:27:34 +02:00
parent b5f8ea0e23
commit 461c78b0e3
3 changed files with 146 additions and 26 deletions

View file

@ -3,64 +3,85 @@
package estimator
import (
"sync/atomic"
"sync"
"time"
"github.com/jech/galene/rtptime"
)
type Estimator struct {
interval uint64
interval uint64
mu sync.Mutex
time uint64
bytes uint32
packets uint32
totalBytes uint32
totalPackets uint32
totalBytes uint64
totalPackets uint64
rate uint32
packetRate uint32
}
// New creates a new estimator that estimates rate over the last interval.
func New(interval time.Duration) *Estimator {
return new(rtptime.Now(rtptime.JiffiesPerSec), interval)
}
func new(now uint64, interval time.Duration) *Estimator {
return &Estimator{
interval: uint64(
rtptime.FromDuration(interval, rtptime.JiffiesPerSec),
),
time: rtptime.Now(rtptime.JiffiesPerSec),
time: now,
}
}
// called locked
func (e *Estimator) swap(now uint64) {
tm := atomic.LoadUint64(&e.time)
jiffies := now - tm
bytes := atomic.SwapUint32(&e.bytes, 0)
packets := atomic.SwapUint32(&e.packets, 0)
atomic.AddUint32(&e.totalBytes, bytes)
atomic.AddUint32(&e.totalPackets, packets)
jiffies := now - e.time
bytes := e.bytes
e.bytes = 0
packets := e.packets
e.packets = 0
e.totalBytes += uint64(bytes)
e.totalPackets += uint64(packets)
var rate, packetRate uint32
if jiffies >= rtptime.JiffiesPerSec/1000 {
rate = uint32((uint64(bytes)*rtptime.JiffiesPerSec + jiffies/2) / jiffies)
packetRate = uint32((uint64(packets)*rtptime.JiffiesPerSec + jiffies/2) / jiffies)
}
atomic.StoreUint32(&e.rate, rate)
atomic.StoreUint32(&e.packetRate, packetRate)
atomic.StoreUint64(&e.time, now)
e.rate = rate
e.packetRate = packetRate
e.time = now
}
// Accumulate records one packet of size bytes
func (e *Estimator) Accumulate(bytes uint32) {
atomic.AddUint32(&e.bytes, bytes)
atomic.AddUint32(&e.packets, 1)
e.mu.Lock()
if e.bytes < ^uint32(0)-bytes {
e.bytes += bytes
}
if e.packets < ^uint32(0)-1 {
e.packets += 1
}
e.mu.Unlock()
}
// called locked
func (e *Estimator) estimate(now uint64) (uint32, uint32) {
tm := atomic.LoadUint64(&e.time)
if now < tm || now-tm > e.interval {
if now < e.time {
// time went backwards
if e.time-now > e.interval {
e.time = now
e.packets = 0
e.bytes = 0
}
} else if now-e.time >= e.interval {
e.swap(now)
}
return atomic.LoadUint32(&e.rate), atomic.LoadUint32(&e.packetRate)
return e.rate, e.packetRate
}
// Estimate returns an estimate of the rate over the last interval.
@ -68,12 +89,15 @@ func (e *Estimator) estimate(now uint64) (uint32, uint32) {
// passed to New. It returns the byte rate and the packet rate, in units
// per second.
func (e *Estimator) Estimate() (uint32, uint32) {
e.mu.Lock()
defer e.mu.Unlock()
return e.estimate(rtptime.Now(rtptime.JiffiesPerSec))
}
// Totals returns the total number of bytes and packets accumulated.
func (e *Estimator) Totals() (uint32, uint32) {
b := atomic.LoadUint32(&e.totalBytes) + atomic.LoadUint32(&e.bytes)
p := atomic.LoadUint32(&e.totalPackets) + atomic.LoadUint32(&e.packets)
return p, b
func (e *Estimator) Totals() (uint64, uint64) {
e.mu.Lock()
defer e.mu.Unlock()
return e.totalPackets + uint64(e.packets),
e.totalBytes + uint64(e.bytes)
}

View file

@ -3,13 +3,15 @@ package estimator
import (
"testing"
"time"
"sync"
"sync/atomic"
"github.com/jech/galene/rtptime"
)
func TestEstimator(t *testing.T) {
now := rtptime.Jiffies()
e := New(time.Second)
e := new(now, time.Second)
e.estimate(now)
e.Accumulate(42)
@ -44,3 +46,97 @@ func TestEstimator(t *testing.T) {
}
}
func TestEstimatorMany(t *testing.T) {
now := rtptime.Jiffies()
e := new(now, time.Second)
for i := 0; i < 10000; i++ {
e.Accumulate(42)
now += rtptime.JiffiesPerSec / 1000
b, p := e.estimate(now)
if i >= 1000 {
if p != 1000 || b != p*42 {
t.Errorf("Got %v %v (%v), expected %v %v",
p, b, 1000, i, p*42,
)
}
}
}
}
func TestEstimatorParallel(t *testing.T) {
now := make([]uint64, 1)
now[0] = rtptime.Jiffies()
getNow := func() uint64 {
return atomic.LoadUint64(&now[0])
}
addNow := func(v uint64) {
atomic.AddUint64(&now[0], v)
}
e := new(getNow(), time.Second)
estimate := func() (uint32, uint32) {
e.mu.Lock()
defer e.mu.Unlock()
return e.estimate(getNow())
}
f := func(n int) {
for i := 0; i < 10000; i++ {
e.Accumulate(42)
addNow(rtptime.JiffiesPerSec / 1000)
b, p := estimate()
if i >= 1000 {
if b != p * 42 {
t.Errorf("%v: Got %v %v (%v), expected %v %v",
n, p, b, i, 1000, p*42,
)
}
}
}
}
var wg sync.WaitGroup
for i := 0; i < 16; i++ {
wg.Add(1)
go func(i int) {
f(i)
wg.Done()
}(i)
}
wg.Wait()
}
func BenchmarkEstimator(b *testing.B) {
e := New(time.Second)
e.Estimate()
time.Sleep(time.Millisecond)
e.Estimate()
b.ResetTimer()
for i := 0; i < 1000 * b.N; i++ {
e.Accumulate(100)
}
e.Estimate()
}
func BenchmarkEstimatorParallel(b *testing.B) {
e := New(time.Second)
e.Estimate()
time.Sleep(time.Millisecond)
e.Estimate()
b.ResetTimer()
b.RunParallel(func (pb *testing.PB) {
for pb.Next() {
for i := 0; i < 1000; i++ {
e.Accumulate(100)
}
}
})
e.Estimate()
}

View file

@ -1070,8 +1070,8 @@ func sendSR(conn *rtpDownConnection) error {
SSRC: uint32(t.ssrc),
NTPTime: nowNTP,
RTPTime: nowRTP,
PacketCount: p,
OctetCount: b,
PacketCount: uint32(p),
OctetCount: uint32(b),
})
t.setSRTime(jiffies, nowNTP)
}