mirror of
https://github.com/jech/galene.git
synced 2024-11-10 02:35:58 +01:00
Implement loss-based congestion control on the down side.
This commit is contained in:
parent
5205c0773b
commit
ae7e32a36a
2 changed files with 101 additions and 33 deletions
68
client.go
68
client.go
|
@ -589,10 +589,12 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn
|
||||||
}
|
}
|
||||||
|
|
||||||
track := &downTrack{
|
track := &downTrack{
|
||||||
track: local,
|
track: local,
|
||||||
remote: remoteTrack,
|
remote: remoteTrack,
|
||||||
maxBitrate: new(timeStampedBitrate),
|
maxLossBitrate: new(bitrate),
|
||||||
rate: estimator.New(time.Second),
|
maxREMBBitrate: new(bitrate),
|
||||||
|
stats: new(receiverStats),
|
||||||
|
rate: estimator.New(time.Second),
|
||||||
}
|
}
|
||||||
conn.tracks = append(conn.tracks, track)
|
conn.tracks = append(conn.tracks, track)
|
||||||
remoteTrack.addLocal(track)
|
remoteTrack.addLocal(track)
|
||||||
|
@ -602,6 +604,41 @@ func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConn
|
||||||
return conn, s, nil
|
return conn, s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
minLossRate = 9600
|
||||||
|
initLossRate = 512 * 1000
|
||||||
|
maxLossRate = 1 << 30
|
||||||
|
)
|
||||||
|
|
||||||
|
func (track *downTrack) updateRate(loss uint8, now uint64) {
|
||||||
|
rate := track.maxLossBitrate.Get(now)
|
||||||
|
if rate > maxLossRate {
|
||||||
|
// no recent feedback, reset
|
||||||
|
rate = initLossRate
|
||||||
|
}
|
||||||
|
if loss < 5 {
|
||||||
|
// if our actual rate is low, then we're not probing the
|
||||||
|
// bottleneck
|
||||||
|
actual := 8 * uint64(track.rate.Estimate())
|
||||||
|
if actual >= (rate*7)/8 {
|
||||||
|
// loss < 0.02, multiply by 1.05
|
||||||
|
rate = rate * 269 / 256
|
||||||
|
if rate > maxLossRate {
|
||||||
|
rate = maxLossRate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if loss > 25 {
|
||||||
|
// loss > 0.1, multiply by (1 - loss/2)
|
||||||
|
rate = rate * (512 - uint64(loss)) / 512
|
||||||
|
if rate < minLossRate {
|
||||||
|
rate = minLossRate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// update unconditionally, to set the timestamp
|
||||||
|
track.maxLossBitrate.Set(rate, now)
|
||||||
|
}
|
||||||
|
|
||||||
func rtcpDownListener(g *group, conn *downConnection, track *downTrack, s *webrtc.RTPSender) {
|
func rtcpDownListener(g *group, conn *downConnection, track *downTrack, s *webrtc.RTPSender) {
|
||||||
for {
|
for {
|
||||||
ps, err := s.ReadRTCP()
|
ps, err := s.ReadRTCP()
|
||||||
|
@ -620,23 +657,26 @@ func rtcpDownListener(g *group, conn *downConnection, track *downTrack, s *webrt
|
||||||
log.Printf("sendPLI: %v", err)
|
log.Printf("sendPLI: %v", err)
|
||||||
}
|
}
|
||||||
case *rtcp.ReceiverEstimatedMaximumBitrate:
|
case *rtcp.ReceiverEstimatedMaximumBitrate:
|
||||||
track.maxBitrate.Set(p.Bitrate,
|
track.maxREMBBitrate.Set(
|
||||||
mono.Microseconds(),
|
p.Bitrate, mono.Microseconds(),
|
||||||
)
|
)
|
||||||
case *rtcp.ReceiverReport:
|
case *rtcp.ReceiverReport:
|
||||||
for _, r := range p.Reports {
|
for _, r := range p.Reports {
|
||||||
if r.SSRC == track.track.SSRC() {
|
if r.SSRC == track.track.SSRC() {
|
||||||
atomic.StoreUint32(
|
now := mono.Microseconds()
|
||||||
&track.loss,
|
track.stats.Set(
|
||||||
uint32(r.FractionLost),
|
r.FractionLost,
|
||||||
|
r.Jitter,
|
||||||
|
now,
|
||||||
|
)
|
||||||
|
track.updateRate(
|
||||||
|
r.FractionLost,
|
||||||
|
now,
|
||||||
)
|
)
|
||||||
atomic.StoreUint32(
|
|
||||||
&track.jitter,
|
|
||||||
r.Jitter)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case *rtcp.TransportLayerNack:
|
case *rtcp.TransportLayerNack:
|
||||||
maxBitrate := track.maxBitrate.Get(
|
maxBitrate := track.GetMaxBitrate(
|
||||||
mono.Microseconds(),
|
mono.Microseconds(),
|
||||||
)
|
)
|
||||||
bitrate := track.rate.Estimate()
|
bitrate := track.rate.Estimate()
|
||||||
|
@ -675,7 +715,7 @@ func updateUpBitrate(up *upConnection) {
|
||||||
track.maxBitrate = ^uint64(0)
|
track.maxBitrate = ^uint64(0)
|
||||||
local := track.getLocal()
|
local := track.getLocal()
|
||||||
for _, l := range local {
|
for _, l := range local {
|
||||||
bitrate := l.maxBitrate.Get(now)
|
bitrate := l.GetMaxBitrate(now)
|
||||||
if bitrate == ^uint64(0) {
|
if bitrate == ^uint64(0) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
66
group.go
66
group.go
|
@ -72,34 +72,63 @@ type upConnection struct {
|
||||||
tracks []*upTrack
|
tracks []*upTrack
|
||||||
}
|
}
|
||||||
|
|
||||||
type timeStampedBitrate struct {
|
type bitrate struct {
|
||||||
bitrate uint64
|
bitrate uint64
|
||||||
microseconds uint64
|
microseconds uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tb *timeStampedBitrate) Set(bitrate, us uint64) {
|
func (br *bitrate) Set(bitrate uint64, now uint64) {
|
||||||
// this is racy -- a reader might read the
|
// this is racy -- a reader might read the
|
||||||
// data between the two writes. This shouldn't
|
// data between the two writes. This shouldn't
|
||||||
// matter, we'll recover at the next sample.
|
// matter, we'll recover at the next sample.
|
||||||
atomic.StoreUint64(&tb.bitrate, bitrate)
|
atomic.StoreUint64(&br.bitrate, bitrate)
|
||||||
atomic.StoreUint64(&tb.microseconds, us)
|
atomic.StoreUint64(&br.microseconds, now)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tb *timeStampedBitrate) Get(now uint64) uint64 {
|
func (br *bitrate) Get(now uint64) uint64 {
|
||||||
ts := atomic.LoadUint64(&tb.microseconds)
|
ts := atomic.LoadUint64(&br.microseconds)
|
||||||
if now < ts || now > ts+4000000 {
|
if now < ts || now > ts+4000000 {
|
||||||
return ^uint64(0)
|
return ^uint64(0)
|
||||||
}
|
}
|
||||||
return atomic.LoadUint64(&tb.bitrate)
|
return atomic.LoadUint64(&br.bitrate)
|
||||||
|
}
|
||||||
|
|
||||||
|
type receiverStats struct {
|
||||||
|
loss uint32
|
||||||
|
jitter uint32
|
||||||
|
microseconds uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *receiverStats) Set(loss uint8, jitter uint32, now uint64) {
|
||||||
|
atomic.StoreUint32(&s.loss, uint32(loss))
|
||||||
|
atomic.StoreUint32(&s.jitter, jitter)
|
||||||
|
atomic.StoreUint64(&s.microseconds, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *receiverStats) Get(now uint64) (uint8, uint32) {
|
||||||
|
ts := atomic.LoadUint64(&s.microseconds)
|
||||||
|
if now < ts || now > ts+4000000 {
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
return uint8(atomic.LoadUint32(&s.loss)), atomic.LoadUint32(&s.jitter)
|
||||||
}
|
}
|
||||||
|
|
||||||
type downTrack struct {
|
type downTrack struct {
|
||||||
track *webrtc.Track
|
track *webrtc.Track
|
||||||
remote *upTrack
|
remote *upTrack
|
||||||
maxBitrate *timeStampedBitrate
|
maxLossBitrate *bitrate
|
||||||
rate *estimator.Estimator
|
maxREMBBitrate *bitrate
|
||||||
loss uint32
|
rate *estimator.Estimator
|
||||||
jitter uint32
|
stats *receiverStats
|
||||||
|
}
|
||||||
|
|
||||||
|
func (down *downTrack) GetMaxBitrate(now uint64) uint64 {
|
||||||
|
br1 := down.maxLossBitrate.Get(now)
|
||||||
|
br2 := down.maxREMBBitrate.Get(now)
|
||||||
|
if br1 < br2 {
|
||||||
|
return br1
|
||||||
|
}
|
||||||
|
return br2
|
||||||
}
|
}
|
||||||
|
|
||||||
type downConnection struct {
|
type downConnection struct {
|
||||||
|
@ -725,15 +754,14 @@ func getClientStats(c *client) clientStats {
|
||||||
for _, down := range c.down {
|
for _, down := range c.down {
|
||||||
conns := connStats{id: down.id}
|
conns := connStats{id: down.id}
|
||||||
for _, t := range down.tracks {
|
for _, t := range down.tracks {
|
||||||
loss := atomic.LoadUint32(&t.loss)
|
loss, jitter := t.stats.Get(mono.Microseconds())
|
||||||
jitter := time.Duration(atomic.LoadUint32(&t.jitter)) *
|
j := time.Duration(jitter) * time.Second /
|
||||||
time.Second /
|
|
||||||
time.Duration(t.track.Codec().ClockRate)
|
time.Duration(t.track.Codec().ClockRate)
|
||||||
conns.tracks = append(conns.tracks, trackStats{
|
conns.tracks = append(conns.tracks, trackStats{
|
||||||
bitrate: uint64(t.rate.Estimate()) * 8,
|
bitrate: uint64(t.rate.Estimate()) * 8,
|
||||||
maxBitrate: t.maxBitrate.Get(mono.Microseconds()),
|
maxBitrate: t.GetMaxBitrate(mono.Microseconds()),
|
||||||
loss: uint8((loss * 100) / 256),
|
loss: uint8(uint32(loss) * 100 / 256),
|
||||||
jitter: jitter,
|
jitter: j,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
cs.down = append(cs.down, conns)
|
cs.down = append(cs.down, conns)
|
||||||
|
|
Loading…
Reference in a new issue