From 5dd27e506725ddb07676ade24e2124e580eec180 Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Thu, 30 Apr 2020 20:15:52 +0200 Subject: [PATCH] Implement rate estimation. --- client.go | 12 ++++++--- estimator/estimator.go | 54 +++++++++++++++++++++++++++++++++++++ estimator/estimator_test.go | 21 +++++++++++++++ group.go | 5 +++- 4 files changed, 88 insertions(+), 4 deletions(-) create mode 100644 estimator/estimator.go create mode 100644 estimator/estimator_test.go diff --git a/client.go b/client.go index c709efb..52417e7 100644 --- a/client.go +++ b/client.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "time" + "sfu/estimator" "sfu/packetcache" "github.com/gorilla/websocket" @@ -290,6 +291,7 @@ func addUpConn(c *client, id string) (*upConnection, error) { track := &upTrack{ track: remote, cache: packetcache.New(96), + rate: estimator.New(time.Second), maxBitrate: ^uint64(0), } u.tracks = append(u.tracks, track) @@ -324,21 +326,22 @@ func upLoop(conn *upConnection, track *upTrack) { localTime = now } - i, err := track.track.Read(buf) + bytes, err := track.track.Read(buf) if err != nil { if err != io.EOF { log.Printf("%v", err) } break } + track.rate.Add(uint32(bytes)) - err = packet.Unmarshal(buf[:i]) + err = packet.Unmarshal(buf[:bytes]) if err != nil { log.Printf("%v", err) continue } - first := track.cache.Store(packet.SequenceNumber, buf[:i]) + first := track.cache.Store(packet.SequenceNumber, buf[:bytes]) if packet.SequenceNumber-first > 24 { first, bitmap := track.cache.BitmapGet() if bitmap != ^uint16(0) { @@ -357,6 +360,7 @@ func upLoop(conn *upConnection, track *upTrack) { if err != nil && err != io.ErrClosedPipe { log.Printf("%v", err) } + l.rate.Add(uint32(bytes)) } } } @@ -568,6 +572,7 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn track: local, remote: remoteTrack, maxBitrate: new(timeStampedBitrate), + rate: estimator.New(time.Second), } conn.tracks = append(conn.tracks, track) remoteTrack.addLocal(track) @@ -758,6 +763,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) { if err != nil { log.Printf("%v", err) } + track.rate.Add(uint32(len(raw))) } } } diff --git a/estimator/estimator.go b/estimator/estimator.go new file mode 100644 index 0000000..a5800a3 --- /dev/null +++ b/estimator/estimator.go @@ -0,0 +1,54 @@ +package estimator + +import ( + "sync" + "sync/atomic" + "time" +) + +type Estimator struct { + interval time.Duration + count uint32 + + mu sync.Mutex + rate uint32 + time time.Time +} + +func New(interval time.Duration) *Estimator { + return &Estimator{ + interval: interval, + time: time.Now(), + } +} + +func (e *Estimator) swap(now time.Time) { + interval := now.Sub(e.time) + count := atomic.SwapUint32(&e.count, 0) + if interval < time.Millisecond { + e.rate = 0 + } else { + e.rate = uint32(uint64(count*1000) / uint64(interval/time.Millisecond)) + } + e.time = now +} + +func (e *Estimator) Add(count uint32) { + atomic.AddUint32(&e.count, count) +} + +func (e *Estimator) estimate(now time.Time) uint32 { + if now.Sub(e.time) > e.interval { + e.swap(now) + } + + return e.rate +} + +func (e *Estimator) Estimate() uint32 { + now := time.Now() + + e.mu.Lock() + defer e.mu.Unlock() + return e.estimate(now) +} diff --git a/estimator/estimator_test.go b/estimator/estimator_test.go new file mode 100644 index 0000000..a0a9933 --- /dev/null +++ b/estimator/estimator_test.go @@ -0,0 +1,21 @@ +package estimator + +import ( + "testing" + "time" +) + +func TestEstimator(t *testing.T) { + now := time.Now() + e := New(time.Second) + + e.estimate(now) + e.Add(42) + e.Add(128) + e.estimate(now.Add(time.Second)) + rate := e.estimate(now.Add(time.Second + time.Millisecond)) + + if rate != 42+128 { + t.Errorf("Expected %v, got %v", 42+128, rate) + } +} diff --git a/group.go b/group.go index dc1913d..8ff5bd7 100644 --- a/group.go +++ b/group.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "time" + "sfu/estimator" "sfu/packetcache" "github.com/pion/webrtc/v2" @@ -23,6 +24,7 @@ import ( type upTrack struct { track *webrtc.Track + rate *estimator.Estimator cache *packetcache.Cache maxBitrate uint64 lastPLI uint64 @@ -76,6 +78,7 @@ type downTrack struct { remote *upTrack isMuted uint32 maxBitrate *timeStampedBitrate + rate *estimator.Estimator loss uint32 } @@ -706,7 +709,7 @@ func getClientStats(c *client) clientStats { loss := atomic.LoadUint32(&t.loss) conns.tracks = append(conns.tracks, trackStats{ bitrate: atomic.LoadUint64(&t.maxBitrate.bitrate), - loss: uint8((loss * 100) / 256), + loss: uint8((loss * 100) / 256), }) } cs.down = append(cs.down, conns)