mirror of
https://github.com/jech/galene.git
synced 2024-12-22 23:35:46 +01:00
Implement rate estimation.
This commit is contained in:
parent
10526d474e
commit
5dd27e5067
4 changed files with 88 additions and 4 deletions
12
client.go
12
client.go
|
@ -16,6 +16,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"sfu/estimator"
|
||||
"sfu/packetcache"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
@ -290,6 +291,7 @@ func addUpConn(c *client, id string) (*upConnection, error) {
|
|||
track := &upTrack{
|
||||
track: remote,
|
||||
cache: packetcache.New(96),
|
||||
rate: estimator.New(time.Second),
|
||||
maxBitrate: ^uint64(0),
|
||||
}
|
||||
u.tracks = append(u.tracks, track)
|
||||
|
@ -324,21 +326,22 @@ func upLoop(conn *upConnection, track *upTrack) {
|
|||
localTime = now
|
||||
}
|
||||
|
||||
i, err := track.track.Read(buf)
|
||||
bytes, err := track.track.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Printf("%v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
track.rate.Add(uint32(bytes))
|
||||
|
||||
err = packet.Unmarshal(buf[:i])
|
||||
err = packet.Unmarshal(buf[:bytes])
|
||||
if err != nil {
|
||||
log.Printf("%v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
first := track.cache.Store(packet.SequenceNumber, buf[:i])
|
||||
first := track.cache.Store(packet.SequenceNumber, buf[:bytes])
|
||||
if packet.SequenceNumber-first > 24 {
|
||||
first, bitmap := track.cache.BitmapGet()
|
||||
if bitmap != ^uint16(0) {
|
||||
|
@ -357,6 +360,7 @@ func upLoop(conn *upConnection, track *upTrack) {
|
|||
if err != nil && err != io.ErrClosedPipe {
|
||||
log.Printf("%v", err)
|
||||
}
|
||||
l.rate.Add(uint32(bytes))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -568,6 +572,7 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn
|
|||
track: local,
|
||||
remote: remoteTrack,
|
||||
maxBitrate: new(timeStampedBitrate),
|
||||
rate: estimator.New(time.Second),
|
||||
}
|
||||
conn.tracks = append(conn.tracks, track)
|
||||
remoteTrack.addLocal(track)
|
||||
|
@ -758,6 +763,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) {
|
|||
if err != nil {
|
||||
log.Printf("%v", err)
|
||||
}
|
||||
track.rate.Add(uint32(len(raw)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
54
estimator/estimator.go
Normal file
54
estimator/estimator.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package estimator
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Estimator struct {
|
||||
interval time.Duration
|
||||
count uint32
|
||||
|
||||
mu sync.Mutex
|
||||
rate uint32
|
||||
time time.Time
|
||||
}
|
||||
|
||||
func New(interval time.Duration) *Estimator {
|
||||
return &Estimator{
|
||||
interval: interval,
|
||||
time: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Estimator) swap(now time.Time) {
|
||||
interval := now.Sub(e.time)
|
||||
count := atomic.SwapUint32(&e.count, 0)
|
||||
if interval < time.Millisecond {
|
||||
e.rate = 0
|
||||
} else {
|
||||
e.rate = uint32(uint64(count*1000) / uint64(interval/time.Millisecond))
|
||||
}
|
||||
e.time = now
|
||||
}
|
||||
|
||||
func (e *Estimator) Add(count uint32) {
|
||||
atomic.AddUint32(&e.count, count)
|
||||
}
|
||||
|
||||
func (e *Estimator) estimate(now time.Time) uint32 {
|
||||
if now.Sub(e.time) > e.interval {
|
||||
e.swap(now)
|
||||
}
|
||||
|
||||
return e.rate
|
||||
}
|
||||
|
||||
func (e *Estimator) Estimate() uint32 {
|
||||
now := time.Now()
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.estimate(now)
|
||||
}
|
21
estimator/estimator_test.go
Normal file
21
estimator/estimator_test.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package estimator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEstimator(t *testing.T) {
|
||||
now := time.Now()
|
||||
e := New(time.Second)
|
||||
|
||||
e.estimate(now)
|
||||
e.Add(42)
|
||||
e.Add(128)
|
||||
e.estimate(now.Add(time.Second))
|
||||
rate := e.estimate(now.Add(time.Second + time.Millisecond))
|
||||
|
||||
if rate != 42+128 {
|
||||
t.Errorf("Expected %v, got %v", 42+128, rate)
|
||||
}
|
||||
}
|
5
group.go
5
group.go
|
@ -16,6 +16,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"sfu/estimator"
|
||||
"sfu/packetcache"
|
||||
|
||||
"github.com/pion/webrtc/v2"
|
||||
|
@ -23,6 +24,7 @@ import (
|
|||
|
||||
type upTrack struct {
|
||||
track *webrtc.Track
|
||||
rate *estimator.Estimator
|
||||
cache *packetcache.Cache
|
||||
maxBitrate uint64
|
||||
lastPLI uint64
|
||||
|
@ -76,6 +78,7 @@ type downTrack struct {
|
|||
remote *upTrack
|
||||
isMuted uint32
|
||||
maxBitrate *timeStampedBitrate
|
||||
rate *estimator.Estimator
|
||||
loss uint32
|
||||
}
|
||||
|
||||
|
@ -706,7 +709,7 @@ func getClientStats(c *client) clientStats {
|
|||
loss := atomic.LoadUint32(&t.loss)
|
||||
conns.tracks = append(conns.tracks, trackStats{
|
||||
bitrate: atomic.LoadUint64(&t.maxBitrate.bitrate),
|
||||
loss: uint8((loss * 100) / 256),
|
||||
loss: uint8((loss * 100) / 256),
|
||||
})
|
||||
}
|
||||
cs.down = append(cs.down, conns)
|
||||
|
|
Loading…
Reference in a new issue