1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-22 00:25:58 +01:00

Fix race condition in estimator.

This commit is contained in:
Juliusz Chroboczek 2022-04-20 21:27:34 +02:00
parent b5f8ea0e23
commit 461c78b0e3
3 changed files with 146 additions and 26 deletions

View file

@ -3,7 +3,7 @@
package estimator package estimator
import ( import (
"sync/atomic" "sync"
"time" "time"
"github.com/jech/galene/rtptime" "github.com/jech/galene/rtptime"
@ -11,56 +11,77 @@ import (
type Estimator struct { type Estimator struct {
interval uint64 interval uint64
mu sync.Mutex
time uint64 time uint64
bytes uint32 bytes uint32
packets uint32 packets uint32
totalBytes uint32 totalBytes uint64
totalPackets uint32 totalPackets uint64
rate uint32 rate uint32
packetRate uint32 packetRate uint32
} }
// New creates a new estimator that estimates rate over the last interval. // New creates a new estimator that estimates rate over the last interval.
func New(interval time.Duration) *Estimator { func New(interval time.Duration) *Estimator {
return new(rtptime.Now(rtptime.JiffiesPerSec), interval)
}
func new(now uint64, interval time.Duration) *Estimator {
return &Estimator{ return &Estimator{
interval: uint64( interval: uint64(
rtptime.FromDuration(interval, rtptime.JiffiesPerSec), rtptime.FromDuration(interval, rtptime.JiffiesPerSec),
), ),
time: rtptime.Now(rtptime.JiffiesPerSec), time: now,
} }
} }
// called locked
func (e *Estimator) swap(now uint64) { func (e *Estimator) swap(now uint64) {
tm := atomic.LoadUint64(&e.time) jiffies := now - e.time
jiffies := now - tm bytes := e.bytes
bytes := atomic.SwapUint32(&e.bytes, 0) e.bytes = 0
packets := atomic.SwapUint32(&e.packets, 0) packets := e.packets
atomic.AddUint32(&e.totalBytes, bytes) e.packets = 0
atomic.AddUint32(&e.totalPackets, packets) e.totalBytes += uint64(bytes)
e.totalPackets += uint64(packets)
var rate, packetRate uint32 var rate, packetRate uint32
if jiffies >= rtptime.JiffiesPerSec/1000 { if jiffies >= rtptime.JiffiesPerSec/1000 {
rate = uint32((uint64(bytes)*rtptime.JiffiesPerSec + jiffies/2) / jiffies) rate = uint32((uint64(bytes)*rtptime.JiffiesPerSec + jiffies/2) / jiffies)
packetRate = uint32((uint64(packets)*rtptime.JiffiesPerSec + jiffies/2) / jiffies) packetRate = uint32((uint64(packets)*rtptime.JiffiesPerSec + jiffies/2) / jiffies)
} }
atomic.StoreUint32(&e.rate, rate) e.rate = rate
atomic.StoreUint32(&e.packetRate, packetRate) e.packetRate = packetRate
atomic.StoreUint64(&e.time, now) e.time = now
} }
// Accumulate records one packet of size bytes // Accumulate records one packet of size bytes
func (e *Estimator) Accumulate(bytes uint32) { func (e *Estimator) Accumulate(bytes uint32) {
atomic.AddUint32(&e.bytes, bytes) e.mu.Lock()
atomic.AddUint32(&e.packets, 1) if e.bytes < ^uint32(0)-bytes {
e.bytes += bytes
}
if e.packets < ^uint32(0)-1 {
e.packets += 1
}
e.mu.Unlock()
} }
// called locked
func (e *Estimator) estimate(now uint64) (uint32, uint32) { func (e *Estimator) estimate(now uint64) (uint32, uint32) {
tm := atomic.LoadUint64(&e.time) if now < e.time {
if now < tm || now-tm > e.interval { // time went backwards
if e.time-now > e.interval {
e.time = now
e.packets = 0
e.bytes = 0
}
} else if now-e.time >= e.interval {
e.swap(now) e.swap(now)
} }
return atomic.LoadUint32(&e.rate), atomic.LoadUint32(&e.packetRate) return e.rate, e.packetRate
} }
// Estimate returns an estimate of the rate over the last interval. // Estimate returns an estimate of the rate over the last interval.
@ -68,12 +89,15 @@ func (e *Estimator) estimate(now uint64) (uint32, uint32) {
// passed to New. It returns the byte rate and the packet rate, in units // passed to New. It returns the byte rate and the packet rate, in units
// per second. // per second.
func (e *Estimator) Estimate() (uint32, uint32) { func (e *Estimator) Estimate() (uint32, uint32) {
e.mu.Lock()
defer e.mu.Unlock()
return e.estimate(rtptime.Now(rtptime.JiffiesPerSec)) return e.estimate(rtptime.Now(rtptime.JiffiesPerSec))
} }
// Totals returns the total number of bytes and packets accumulated. // Totals returns the total number of bytes and packets accumulated.
func (e *Estimator) Totals() (uint32, uint32) { func (e *Estimator) Totals() (uint64, uint64) {
b := atomic.LoadUint32(&e.totalBytes) + atomic.LoadUint32(&e.bytes) e.mu.Lock()
p := atomic.LoadUint32(&e.totalPackets) + atomic.LoadUint32(&e.packets) defer e.mu.Unlock()
return p, b return e.totalPackets + uint64(e.packets),
e.totalBytes + uint64(e.bytes)
} }

View file

@ -3,13 +3,15 @@ package estimator
import ( import (
"testing" "testing"
"time" "time"
"sync"
"sync/atomic"
"github.com/jech/galene/rtptime" "github.com/jech/galene/rtptime"
) )
func TestEstimator(t *testing.T) { func TestEstimator(t *testing.T) {
now := rtptime.Jiffies() now := rtptime.Jiffies()
e := New(time.Second) e := new(now, time.Second)
e.estimate(now) e.estimate(now)
e.Accumulate(42) e.Accumulate(42)
@ -44,3 +46,97 @@ func TestEstimator(t *testing.T) {
} }
} }
func TestEstimatorMany(t *testing.T) {
now := rtptime.Jiffies()
e := new(now, time.Second)
for i := 0; i < 10000; i++ {
e.Accumulate(42)
now += rtptime.JiffiesPerSec / 1000
b, p := e.estimate(now)
if i >= 1000 {
if p != 1000 || b != p*42 {
t.Errorf("Got %v %v (%v), expected %v %v",
p, b, 1000, i, p*42,
)
}
}
}
}
func TestEstimatorParallel(t *testing.T) {
now := make([]uint64, 1)
now[0] = rtptime.Jiffies()
getNow := func() uint64 {
return atomic.LoadUint64(&now[0])
}
addNow := func(v uint64) {
atomic.AddUint64(&now[0], v)
}
e := new(getNow(), time.Second)
estimate := func() (uint32, uint32) {
e.mu.Lock()
defer e.mu.Unlock()
return e.estimate(getNow())
}
f := func(n int) {
for i := 0; i < 10000; i++ {
e.Accumulate(42)
addNow(rtptime.JiffiesPerSec / 1000)
b, p := estimate()
if i >= 1000 {
if b != p * 42 {
t.Errorf("%v: Got %v %v (%v), expected %v %v",
n, p, b, i, 1000, p*42,
)
}
}
}
}
var wg sync.WaitGroup
for i := 0; i < 16; i++ {
wg.Add(1)
go func(i int) {
f(i)
wg.Done()
}(i)
}
wg.Wait()
}
func BenchmarkEstimator(b *testing.B) {
e := New(time.Second)
e.Estimate()
time.Sleep(time.Millisecond)
e.Estimate()
b.ResetTimer()
for i := 0; i < 1000 * b.N; i++ {
e.Accumulate(100)
}
e.Estimate()
}
func BenchmarkEstimatorParallel(b *testing.B) {
e := New(time.Second)
e.Estimate()
time.Sleep(time.Millisecond)
e.Estimate()
b.ResetTimer()
b.RunParallel(func (pb *testing.PB) {
for pb.Next() {
for i := 0; i < 1000; i++ {
e.Accumulate(100)
}
}
})
e.Estimate()
}

View file

@ -1070,8 +1070,8 @@ func sendSR(conn *rtpDownConnection) error {
SSRC: uint32(t.ssrc), SSRC: uint32(t.ssrc),
NTPTime: nowNTP, NTPTime: nowNTP,
RTPTime: nowRTP, RTPTime: nowRTP,
PacketCount: p, PacketCount: uint32(p),
OctetCount: b, OctetCount: uint32(b),
}) })
t.setSRTime(jiffies, nowNTP) t.setSRTime(jiffies, nowNTP)
} }