1
Fork 0

Implement rate estimation.

This commit is contained in:
Juliusz Chroboczek 2020-04-30 20:15:52 +02:00
parent 10526d474e
commit 5dd27e5067
4 changed files with 88 additions and 4 deletions

View File

@ -16,6 +16,7 @@ import (
"sync/atomic"
"time"
"sfu/estimator"
"sfu/packetcache"
"github.com/gorilla/websocket"
@ -290,6 +291,7 @@ func addUpConn(c *client, id string) (*upConnection, error) {
track := &upTrack{
track: remote,
cache: packetcache.New(96),
rate: estimator.New(time.Second),
maxBitrate: ^uint64(0),
}
u.tracks = append(u.tracks, track)
@ -324,21 +326,22 @@ func upLoop(conn *upConnection, track *upTrack) {
localTime = now
}
i, err := track.track.Read(buf)
bytes, err := track.track.Read(buf)
if err != nil {
if err != io.EOF {
log.Printf("%v", err)
}
break
}
track.rate.Add(uint32(bytes))
err = packet.Unmarshal(buf[:i])
err = packet.Unmarshal(buf[:bytes])
if err != nil {
log.Printf("%v", err)
continue
}
first := track.cache.Store(packet.SequenceNumber, buf[:i])
first := track.cache.Store(packet.SequenceNumber, buf[:bytes])
if packet.SequenceNumber-first > 24 {
first, bitmap := track.cache.BitmapGet()
if bitmap != ^uint16(0) {
@ -357,6 +360,7 @@ func upLoop(conn *upConnection, track *upTrack) {
if err != nil && err != io.ErrClosedPipe {
log.Printf("%v", err)
}
l.rate.Add(uint32(bytes))
}
}
}
@ -568,6 +572,7 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn
track: local,
remote: remoteTrack,
maxBitrate: new(timeStampedBitrate),
rate: estimator.New(time.Second),
}
conn.tracks = append(conn.tracks, track)
remoteTrack.addLocal(track)
@ -758,6 +763,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) {
if err != nil {
log.Printf("%v", err)
}
track.rate.Add(uint32(len(raw)))
}
}
}

54
estimator/estimator.go Normal file
View File

@ -0,0 +1,54 @@
package estimator
import (
"sync"
"sync/atomic"
"time"
)
type Estimator struct {
interval time.Duration
count uint32
mu sync.Mutex
rate uint32
time time.Time
}
func New(interval time.Duration) *Estimator {
return &Estimator{
interval: interval,
time: time.Now(),
}
}
func (e *Estimator) swap(now time.Time) {
interval := now.Sub(e.time)
count := atomic.SwapUint32(&e.count, 0)
if interval < time.Millisecond {
e.rate = 0
} else {
e.rate = uint32(uint64(count*1000) / uint64(interval/time.Millisecond))
}
e.time = now
}
func (e *Estimator) Add(count uint32) {
atomic.AddUint32(&e.count, count)
}
func (e *Estimator) estimate(now time.Time) uint32 {
if now.Sub(e.time) > e.interval {
e.swap(now)
}
return e.rate
}
func (e *Estimator) Estimate() uint32 {
now := time.Now()
e.mu.Lock()
defer e.mu.Unlock()
return e.estimate(now)
}

View File

@ -0,0 +1,21 @@
package estimator
import (
"testing"
"time"
)
func TestEstimator(t *testing.T) {
now := time.Now()
e := New(time.Second)
e.estimate(now)
e.Add(42)
e.Add(128)
e.estimate(now.Add(time.Second))
rate := e.estimate(now.Add(time.Second + time.Millisecond))
if rate != 42+128 {
t.Errorf("Expected %v, got %v", 42+128, rate)
}
}

View File

@ -16,6 +16,7 @@ import (
"sync/atomic"
"time"
"sfu/estimator"
"sfu/packetcache"
"github.com/pion/webrtc/v2"
@ -23,6 +24,7 @@ import (
type upTrack struct {
track *webrtc.Track
rate *estimator.Estimator
cache *packetcache.Cache
maxBitrate uint64
lastPLI uint64
@ -76,6 +78,7 @@ type downTrack struct {
remote *upTrack
isMuted uint32
maxBitrate *timeStampedBitrate
rate *estimator.Estimator
loss uint32
}