mirror of
https://github.com/jech/galene.git
synced 2024-11-09 18:25:58 +01:00
Implement rate estimation.
This commit is contained in:
parent
10526d474e
commit
5dd27e5067
4 changed files with 88 additions and 4 deletions
12
client.go
12
client.go
|
@ -16,6 +16,7 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"sfu/estimator"
|
||||||
"sfu/packetcache"
|
"sfu/packetcache"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
@ -290,6 +291,7 @@ func addUpConn(c *client, id string) (*upConnection, error) {
|
||||||
track := &upTrack{
|
track := &upTrack{
|
||||||
track: remote,
|
track: remote,
|
||||||
cache: packetcache.New(96),
|
cache: packetcache.New(96),
|
||||||
|
rate: estimator.New(time.Second),
|
||||||
maxBitrate: ^uint64(0),
|
maxBitrate: ^uint64(0),
|
||||||
}
|
}
|
||||||
u.tracks = append(u.tracks, track)
|
u.tracks = append(u.tracks, track)
|
||||||
|
@ -324,21 +326,22 @@ func upLoop(conn *upConnection, track *upTrack) {
|
||||||
localTime = now
|
localTime = now
|
||||||
}
|
}
|
||||||
|
|
||||||
i, err := track.track.Read(buf)
|
bytes, err := track.track.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
log.Printf("%v", err)
|
log.Printf("%v", err)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
track.rate.Add(uint32(bytes))
|
||||||
|
|
||||||
err = packet.Unmarshal(buf[:i])
|
err = packet.Unmarshal(buf[:bytes])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("%v", err)
|
log.Printf("%v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
first := track.cache.Store(packet.SequenceNumber, buf[:i])
|
first := track.cache.Store(packet.SequenceNumber, buf[:bytes])
|
||||||
if packet.SequenceNumber-first > 24 {
|
if packet.SequenceNumber-first > 24 {
|
||||||
first, bitmap := track.cache.BitmapGet()
|
first, bitmap := track.cache.BitmapGet()
|
||||||
if bitmap != ^uint16(0) {
|
if bitmap != ^uint16(0) {
|
||||||
|
@ -357,6 +360,7 @@ func upLoop(conn *upConnection, track *upTrack) {
|
||||||
if err != nil && err != io.ErrClosedPipe {
|
if err != nil && err != io.ErrClosedPipe {
|
||||||
log.Printf("%v", err)
|
log.Printf("%v", err)
|
||||||
}
|
}
|
||||||
|
l.rate.Add(uint32(bytes))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -568,6 +572,7 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn
|
||||||
track: local,
|
track: local,
|
||||||
remote: remoteTrack,
|
remote: remoteTrack,
|
||||||
maxBitrate: new(timeStampedBitrate),
|
maxBitrate: new(timeStampedBitrate),
|
||||||
|
rate: estimator.New(time.Second),
|
||||||
}
|
}
|
||||||
conn.tracks = append(conn.tracks, track)
|
conn.tracks = append(conn.tracks, track)
|
||||||
remoteTrack.addLocal(track)
|
remoteTrack.addLocal(track)
|
||||||
|
@ -758,6 +763,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("%v", err)
|
log.Printf("%v", err)
|
||||||
}
|
}
|
||||||
|
track.rate.Add(uint32(len(raw)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
54
estimator/estimator.go
Normal file
54
estimator/estimator.go
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
package estimator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Estimator struct {
|
||||||
|
interval time.Duration
|
||||||
|
count uint32
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
rate uint32
|
||||||
|
time time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(interval time.Duration) *Estimator {
|
||||||
|
return &Estimator{
|
||||||
|
interval: interval,
|
||||||
|
time: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Estimator) swap(now time.Time) {
|
||||||
|
interval := now.Sub(e.time)
|
||||||
|
count := atomic.SwapUint32(&e.count, 0)
|
||||||
|
if interval < time.Millisecond {
|
||||||
|
e.rate = 0
|
||||||
|
} else {
|
||||||
|
e.rate = uint32(uint64(count*1000) / uint64(interval/time.Millisecond))
|
||||||
|
}
|
||||||
|
e.time = now
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Estimator) Add(count uint32) {
|
||||||
|
atomic.AddUint32(&e.count, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Estimator) estimate(now time.Time) uint32 {
|
||||||
|
if now.Sub(e.time) > e.interval {
|
||||||
|
e.swap(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.rate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Estimator) Estimate() uint32 {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
return e.estimate(now)
|
||||||
|
}
|
21
estimator/estimator_test.go
Normal file
21
estimator/estimator_test.go
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
package estimator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEstimator(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
e := New(time.Second)
|
||||||
|
|
||||||
|
e.estimate(now)
|
||||||
|
e.Add(42)
|
||||||
|
e.Add(128)
|
||||||
|
e.estimate(now.Add(time.Second))
|
||||||
|
rate := e.estimate(now.Add(time.Second + time.Millisecond))
|
||||||
|
|
||||||
|
if rate != 42+128 {
|
||||||
|
t.Errorf("Expected %v, got %v", 42+128, rate)
|
||||||
|
}
|
||||||
|
}
|
5
group.go
5
group.go
|
@ -16,6 +16,7 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"sfu/estimator"
|
||||||
"sfu/packetcache"
|
"sfu/packetcache"
|
||||||
|
|
||||||
"github.com/pion/webrtc/v2"
|
"github.com/pion/webrtc/v2"
|
||||||
|
@ -23,6 +24,7 @@ import (
|
||||||
|
|
||||||
type upTrack struct {
|
type upTrack struct {
|
||||||
track *webrtc.Track
|
track *webrtc.Track
|
||||||
|
rate *estimator.Estimator
|
||||||
cache *packetcache.Cache
|
cache *packetcache.Cache
|
||||||
maxBitrate uint64
|
maxBitrate uint64
|
||||||
lastPLI uint64
|
lastPLI uint64
|
||||||
|
@ -76,6 +78,7 @@ type downTrack struct {
|
||||||
remote *upTrack
|
remote *upTrack
|
||||||
isMuted uint32
|
isMuted uint32
|
||||||
maxBitrate *timeStampedBitrate
|
maxBitrate *timeStampedBitrate
|
||||||
|
rate *estimator.Estimator
|
||||||
loss uint32
|
loss uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -706,7 +709,7 @@ func getClientStats(c *client) clientStats {
|
||||||
loss := atomic.LoadUint32(&t.loss)
|
loss := atomic.LoadUint32(&t.loss)
|
||||||
conns.tracks = append(conns.tracks, trackStats{
|
conns.tracks = append(conns.tracks, trackStats{
|
||||||
bitrate: atomic.LoadUint64(&t.maxBitrate.bitrate),
|
bitrate: atomic.LoadUint64(&t.maxBitrate.bitrate),
|
||||||
loss: uint8((loss * 100) / 256),
|
loss: uint8((loss * 100) / 256),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
cs.down = append(cs.down, conns)
|
cs.down = append(cs.down, conns)
|
||||||
|
|
Loading…
Reference in a new issue