1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-22 16:45:58 +01:00

Implement rate estimation.

This commit is contained in:
Juliusz Chroboczek 2020-04-30 20:15:52 +02:00
parent 10526d474e
commit 5dd27e5067
4 changed files with 88 additions and 4 deletions

View file

@ -16,6 +16,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"sfu/estimator"
"sfu/packetcache" "sfu/packetcache"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -290,6 +291,7 @@ func addUpConn(c *client, id string) (*upConnection, error) {
track := &upTrack{ track := &upTrack{
track: remote, track: remote,
cache: packetcache.New(96), cache: packetcache.New(96),
rate: estimator.New(time.Second),
maxBitrate: ^uint64(0), maxBitrate: ^uint64(0),
} }
u.tracks = append(u.tracks, track) u.tracks = append(u.tracks, track)
@ -324,21 +326,22 @@ func upLoop(conn *upConnection, track *upTrack) {
localTime = now localTime = now
} }
i, err := track.track.Read(buf) bytes, err := track.track.Read(buf)
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Printf("%v", err) log.Printf("%v", err)
} }
break break
} }
track.rate.Add(uint32(bytes))
err = packet.Unmarshal(buf[:i]) err = packet.Unmarshal(buf[:bytes])
if err != nil { if err != nil {
log.Printf("%v", err) log.Printf("%v", err)
continue continue
} }
first := track.cache.Store(packet.SequenceNumber, buf[:i]) first := track.cache.Store(packet.SequenceNumber, buf[:bytes])
if packet.SequenceNumber-first > 24 { if packet.SequenceNumber-first > 24 {
first, bitmap := track.cache.BitmapGet() first, bitmap := track.cache.BitmapGet()
if bitmap != ^uint16(0) { if bitmap != ^uint16(0) {
@ -357,6 +360,7 @@ func upLoop(conn *upConnection, track *upTrack) {
if err != nil && err != io.ErrClosedPipe { if err != nil && err != io.ErrClosedPipe {
log.Printf("%v", err) log.Printf("%v", err)
} }
l.rate.Add(uint32(bytes))
} }
} }
} }
@ -568,6 +572,7 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn
track: local, track: local,
remote: remoteTrack, remote: remoteTrack,
maxBitrate: new(timeStampedBitrate), maxBitrate: new(timeStampedBitrate),
rate: estimator.New(time.Second),
} }
conn.tracks = append(conn.tracks, track) conn.tracks = append(conn.tracks, track)
remoteTrack.addLocal(track) remoteTrack.addLocal(track)
@ -758,6 +763,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) {
if err != nil { if err != nil {
log.Printf("%v", err) log.Printf("%v", err)
} }
track.rate.Add(uint32(len(raw)))
} }
} }
} }

54
estimator/estimator.go Normal file
View file

@ -0,0 +1,54 @@
package estimator
import (
"sync"
"sync/atomic"
"time"
)
type Estimator struct {
interval time.Duration
count uint32
mu sync.Mutex
rate uint32
time time.Time
}
func New(interval time.Duration) *Estimator {
return &Estimator{
interval: interval,
time: time.Now(),
}
}
func (e *Estimator) swap(now time.Time) {
interval := now.Sub(e.time)
count := atomic.SwapUint32(&e.count, 0)
if interval < time.Millisecond {
e.rate = 0
} else {
e.rate = uint32(uint64(count*1000) / uint64(interval/time.Millisecond))
}
e.time = now
}
func (e *Estimator) Add(count uint32) {
atomic.AddUint32(&e.count, count)
}
func (e *Estimator) estimate(now time.Time) uint32 {
if now.Sub(e.time) > e.interval {
e.swap(now)
}
return e.rate
}
func (e *Estimator) Estimate() uint32 {
now := time.Now()
e.mu.Lock()
defer e.mu.Unlock()
return e.estimate(now)
}

View file

@ -0,0 +1,21 @@
package estimator
import (
"testing"
"time"
)
func TestEstimator(t *testing.T) {
now := time.Now()
e := New(time.Second)
e.estimate(now)
e.Add(42)
e.Add(128)
e.estimate(now.Add(time.Second))
rate := e.estimate(now.Add(time.Second + time.Millisecond))
if rate != 42+128 {
t.Errorf("Expected %v, got %v", 42+128, rate)
}
}

View file

@ -16,6 +16,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"sfu/estimator"
"sfu/packetcache" "sfu/packetcache"
"github.com/pion/webrtc/v2" "github.com/pion/webrtc/v2"
@ -23,6 +24,7 @@ import (
type upTrack struct { type upTrack struct {
track *webrtc.Track track *webrtc.Track
rate *estimator.Estimator
cache *packetcache.Cache cache *packetcache.Cache
maxBitrate uint64 maxBitrate uint64
lastPLI uint64 lastPLI uint64
@ -76,6 +78,7 @@ type downTrack struct {
remote *upTrack remote *upTrack
isMuted uint32 isMuted uint32
maxBitrate *timeStampedBitrate maxBitrate *timeStampedBitrate
rate *estimator.Estimator
loss uint32 loss uint32
} }
@ -706,7 +709,7 @@ func getClientStats(c *client) clientStats {
loss := atomic.LoadUint32(&t.loss) loss := atomic.LoadUint32(&t.loss)
conns.tracks = append(conns.tracks, trackStats{ conns.tracks = append(conns.tracks, trackStats{
bitrate: atomic.LoadUint64(&t.maxBitrate.bitrate), bitrate: atomic.LoadUint64(&t.maxBitrate.bitrate),
loss: uint8((loss * 100) / 256), loss: uint8((loss * 100) / 256),
}) })
} }
cs.down = append(cs.down, conns) cs.down = append(cs.down, conns)