1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-26 02:25:58 +01:00

Implement loss-based congestion control on the down side.

This commit is contained in:
Juliusz Chroboczek 2020-05-03 12:25:10 +02:00
parent 5205c0773b
commit ae7e32a36a
2 changed files with 101 additions and 33 deletions

View file

@ -591,7 +591,9 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn
track := &downTrack{ track := &downTrack{
track: local, track: local,
remote: remoteTrack, remote: remoteTrack,
maxBitrate: new(timeStampedBitrate), maxLossBitrate: new(bitrate),
maxREMBBitrate: new(bitrate),
stats: new(receiverStats),
rate: estimator.New(time.Second), rate: estimator.New(time.Second),
} }
conn.tracks = append(conn.tracks, track) conn.tracks = append(conn.tracks, track)
@ -602,6 +604,41 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn
return conn, s, nil return conn, s, nil
} }
const (
minLossRate = 9600
initLossRate = 512 * 1000
maxLossRate = 1 << 30
)
func (track *downTrack) updateRate(loss uint8, now uint64) {
rate := track.maxLossBitrate.Get(now)
if rate > maxLossRate {
// no recent feedback, reset
rate = initLossRate
}
if loss < 5 {
// if our actual rate is low, then we're not probing the
// bottleneck
actual := 8 * uint64(track.rate.Estimate())
if actual >= (rate*7)/8 {
// loss < 0.02, multiply by 1.05
rate = rate * 269 / 256
if rate > maxLossRate {
rate = maxLossRate
}
}
} else if loss > 25 {
// loss > 0.1, multiply by (1 - loss/2)
rate = rate * (512 - uint64(loss)) / 512
if rate < minLossRate {
rate = minLossRate
}
}
// update unconditionally, to set the timestamp
track.maxLossBitrate.Set(rate, now)
}
func rtcpDownListener(g *group, conn *downConnection, track *downTrack, s *webrtc.RTPSender) { func rtcpDownListener(g *group, conn *downConnection, track *downTrack, s *webrtc.RTPSender) {
for { for {
ps, err := s.ReadRTCP() ps, err := s.ReadRTCP()
@ -620,23 +657,26 @@ func rtcpDownListener(g *group, conn *downConnection, track *downTrack, s *webrt
log.Printf("sendPLI: %v", err) log.Printf("sendPLI: %v", err)
} }
case *rtcp.ReceiverEstimatedMaximumBitrate: case *rtcp.ReceiverEstimatedMaximumBitrate:
track.maxBitrate.Set(p.Bitrate, track.maxREMBBitrate.Set(
mono.Microseconds(), p.Bitrate, mono.Microseconds(),
) )
case *rtcp.ReceiverReport: case *rtcp.ReceiverReport:
for _, r := range p.Reports { for _, r := range p.Reports {
if r.SSRC == track.track.SSRC() { if r.SSRC == track.track.SSRC() {
atomic.StoreUint32( now := mono.Microseconds()
&track.loss, track.stats.Set(
uint32(r.FractionLost), r.FractionLost,
r.Jitter,
now,
)
track.updateRate(
r.FractionLost,
now,
) )
atomic.StoreUint32(
&track.jitter,
r.Jitter)
} }
} }
case *rtcp.TransportLayerNack: case *rtcp.TransportLayerNack:
maxBitrate := track.maxBitrate.Get( maxBitrate := track.GetMaxBitrate(
mono.Microseconds(), mono.Microseconds(),
) )
bitrate := track.rate.Estimate() bitrate := track.rate.Estimate()
@ -675,7 +715,7 @@ func updateUpBitrate(up *upConnection) {
track.maxBitrate = ^uint64(0) track.maxBitrate = ^uint64(0)
local := track.getLocal() local := track.getLocal()
for _, l := range local { for _, l := range local {
bitrate := l.maxBitrate.Get(now) bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) { if bitrate == ^uint64(0) {
continue continue
} }

View file

@ -72,34 +72,63 @@ type upConnection struct {
tracks []*upTrack tracks []*upTrack
} }
type timeStampedBitrate struct { type bitrate struct {
bitrate uint64 bitrate uint64
microseconds uint64 microseconds uint64
} }
func (tb *timeStampedBitrate) Set(bitrate, us uint64) { func (br *bitrate) Set(bitrate uint64, now uint64) {
// this is racy -- a reader might read the // this is racy -- a reader might read the
// data between the two writes. This shouldn't // data between the two writes. This shouldn't
// matter, we'll recover at the next sample. // matter, we'll recover at the next sample.
atomic.StoreUint64(&tb.bitrate, bitrate) atomic.StoreUint64(&br.bitrate, bitrate)
atomic.StoreUint64(&tb.microseconds, us) atomic.StoreUint64(&br.microseconds, now)
} }
func (tb *timeStampedBitrate) Get(now uint64) uint64 { func (br *bitrate) Get(now uint64) uint64 {
ts := atomic.LoadUint64(&tb.microseconds) ts := atomic.LoadUint64(&br.microseconds)
if now < ts || now > ts+4000000 { if now < ts || now > ts+4000000 {
return ^uint64(0) return ^uint64(0)
} }
return atomic.LoadUint64(&tb.bitrate) return atomic.LoadUint64(&br.bitrate)
}
type receiverStats struct {
loss uint32
jitter uint32
microseconds uint64
}
func (s *receiverStats) Set(loss uint8, jitter uint32, now uint64) {
atomic.StoreUint32(&s.loss, uint32(loss))
atomic.StoreUint32(&s.jitter, jitter)
atomic.StoreUint64(&s.microseconds, now)
}
func (s *receiverStats) Get(now uint64) (uint8, uint32) {
ts := atomic.LoadUint64(&s.microseconds)
if now < ts || now > ts+4000000 {
return 0, 0
}
return uint8(atomic.LoadUint32(&s.loss)), atomic.LoadUint32(&s.jitter)
} }
type downTrack struct { type downTrack struct {
track *webrtc.Track track *webrtc.Track
remote *upTrack remote *upTrack
maxBitrate *timeStampedBitrate maxLossBitrate *bitrate
maxREMBBitrate *bitrate
rate *estimator.Estimator rate *estimator.Estimator
loss uint32 stats *receiverStats
jitter uint32 }
func (down *downTrack) GetMaxBitrate(now uint64) uint64 {
br1 := down.maxLossBitrate.Get(now)
br2 := down.maxREMBBitrate.Get(now)
if br1 < br2 {
return br1
}
return br2
} }
type downConnection struct { type downConnection struct {
@ -725,15 +754,14 @@ func getClientStats(c *client) clientStats {
for _, down := range c.down { for _, down := range c.down {
conns := connStats{id: down.id} conns := connStats{id: down.id}
for _, t := range down.tracks { for _, t := range down.tracks {
loss := atomic.LoadUint32(&t.loss) loss, jitter := t.stats.Get(mono.Microseconds())
jitter := time.Duration(atomic.LoadUint32(&t.jitter)) * j := time.Duration(jitter) * time.Second /
time.Second /
time.Duration(t.track.Codec().ClockRate) time.Duration(t.track.Codec().ClockRate)
conns.tracks = append(conns.tracks, trackStats{ conns.tracks = append(conns.tracks, trackStats{
bitrate: uint64(t.rate.Estimate()) * 8, bitrate: uint64(t.rate.Estimate()) * 8,
maxBitrate: t.maxBitrate.Get(mono.Microseconds()), maxBitrate: t.GetMaxBitrate(mono.Microseconds()),
loss: uint8((loss * 100) / 256), loss: uint8(uint32(loss) * 100 / 256),
jitter: jitter, jitter: j,
}) })
} }
cs.down = append(cs.down, conns) cs.down = append(cs.down, conns)