1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-22 16:45:58 +01:00

Perform congestion control at the connection level.

REMB applies to the whole transport, not to individual tracks.
This commit is contained in:
Juliusz Chroboczek 2020-06-12 17:39:16 +02:00
parent 903e499dd6
commit a4d0741704
6 changed files with 131 additions and 138 deletions

View file

@ -34,11 +34,11 @@ type upTrack interface {
} }
type downConnection interface { type downConnection interface {
GetMaxBitrate(now uint64) uint64
} }
type downTrack interface { type downTrack interface {
WriteRTP(packat *rtp.Packet) error WriteRTP(packat *rtp.Packet) error
Accumulate(bytes uint32) Accumulate(bytes uint32)
GetMaxBitrate(now uint64) uint64
setTimeOffset(ntp uint64, rtp uint32) setTimeOffset(ntp uint64, rtp uint32)
} }

View file

@ -382,10 +382,10 @@ func (conn *diskConn) initWriter(width, height uint32) error {
return nil return nil
} }
func (down *diskConn) GetMaxBitrate(now uint64) uint64 {
return ^uint64(0)
}
func (t *diskTrack) Accumulate(bytes uint32) { func (t *diskTrack) Accumulate(bytes uint32) {
return return
} }
func (down *diskTrack) GetMaxBitrate(now uint64) uint64 {
return ^uint64(0)
}

View file

@ -37,8 +37,7 @@ type chatHistoryEntry struct {
} }
const ( const (
minVideoRate = 200000 minBitrate = 200000
minAudioRate = 9600
) )
type group struct { type group struct {
@ -506,8 +505,9 @@ type clientStats struct {
} }
type connStats struct { type connStats struct {
id string id string
tracks []trackStats maxBitrate uint64
tracks []trackStats
} }
type trackStats struct { type trackStats struct {
@ -560,7 +560,9 @@ func getClientStats(c *webClient) clientStats {
} }
for _, up := range c.up { for _, up := range c.up {
conns := connStats{id: up.id} conns := connStats{
id: up.id,
}
tracks := up.getTracks() tracks := up.getTracks()
for _, t := range tracks { for _, t := range tracks {
expected, lost, _, _ := t.cache.GetStats(false) expected, lost, _, _ := t.cache.GetStats(false)
@ -572,10 +574,9 @@ func getClientStats(c *webClient) clientStats {
(time.Second / time.Duration(t.jitter.HZ())) (time.Second / time.Duration(t.jitter.HZ()))
rate, _ := t.rate.Estimate() rate, _ := t.rate.Estimate()
conns.tracks = append(conns.tracks, trackStats{ conns.tracks = append(conns.tracks, trackStats{
bitrate: uint64(rate) * 8, bitrate: uint64(rate) * 8,
maxBitrate: atomic.LoadUint64(&t.maxBitrate), loss: loss,
loss: loss, jitter: jitter,
jitter: jitter,
}) })
} }
cs.up = append(cs.up, conns) cs.up = append(cs.up, conns)
@ -584,10 +585,13 @@ func getClientStats(c *webClient) clientStats {
return cs.up[i].id < cs.up[j].id return cs.up[i].id < cs.up[j].id
}) })
jiffies := rtptime.Jiffies()
for _, down := range c.down { for _, down := range c.down {
conns := connStats{id: down.id} conns := connStats{
id: down.id,
maxBitrate: down.GetMaxBitrate(jiffies),
}
for _, t := range down.tracks { for _, t := range down.tracks {
jiffies := rtptime.Jiffies()
rate, _ := t.rate.Estimate() rate, _ := t.rate.Estimate()
rtt := rtptime.ToDuration(atomic.LoadUint64(&t.rtt), rtt := rtptime.ToDuration(atomic.LoadUint64(&t.rtt),
rtptime.JiffiesPerSec) rtptime.JiffiesPerSec)
@ -596,7 +600,7 @@ func getClientStats(c *webClient) clientStats {
time.Duration(t.track.Codec().ClockRate) time.Duration(t.track.Codec().ClockRate)
conns.tracks = append(conns.tracks, trackStats{ conns.tracks = append(conns.tracks, trackStats{
bitrate: uint64(rate) * 8, bitrate: uint64(rate) * 8,
maxBitrate: t.GetMaxBitrate(jiffies), maxBitrate: t.maxBitrate.Get(jiffies),
loss: uint8(uint32(loss) * 100 / 256), loss: uint8(uint32(loss) * 100 / 256),
rtt: rtt, rtt: rtt,
jitter: j, jitter: j,

View file

@ -70,17 +70,16 @@ type iceConnection interface {
} }
type rtpDownTrack struct { type rtpDownTrack struct {
track *webrtc.Track track *webrtc.Track
remote upTrack remote upTrack
maxLossBitrate *bitrate maxBitrate *bitrate
maxREMBBitrate *bitrate rate *estimator.Estimator
rate *estimator.Estimator stats *receiverStats
stats *receiverStats srTime uint64
srTime uint64 srNTPTime uint64
srNTPTime uint64 remoteNTPTime uint64
remoteNTPTime uint64 remoteRTPTime uint32
remoteRTPTime uint32 rtt uint64
rtt uint64
} }
func (down *rtpDownTrack) WriteRTP(packet *rtp.Packet) error { func (down *rtpDownTrack) WriteRTP(packet *rtp.Packet) error {
@ -91,26 +90,18 @@ func (down *rtpDownTrack) Accumulate(bytes uint32) {
down.rate.Accumulate(bytes) down.rate.Accumulate(bytes)
} }
func (down *rtpDownTrack) GetMaxBitrate(now uint64) uint64 {
br1 := down.maxLossBitrate.Get(now)
br2 := down.maxREMBBitrate.Get(now)
if br1 < br2 {
return br1
}
return br2
}
func (down *rtpDownTrack) setTimeOffset(ntp uint64, rtp uint32) { func (down *rtpDownTrack) setTimeOffset(ntp uint64, rtp uint32) {
atomic.StoreUint64(&down.remoteNTPTime, ntp) atomic.StoreUint64(&down.remoteNTPTime, ntp)
atomic.StoreUint32(&down.remoteRTPTime, rtp) atomic.StoreUint32(&down.remoteRTPTime, rtp)
} }
type rtpDownConnection struct { type rtpDownConnection struct {
id string id string
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
remote upConnection remote upConnection
tracks []*rtpDownTrack tracks []*rtpDownTrack
iceCandidates []*webrtc.ICECandidateInit maxREMBBitrate *bitrate
iceCandidates []*webrtc.ICECandidateInit
} }
func newDownConn(id string, remote upConnection) (*rtpDownConnection, error) { func newDownConn(id string, remote upConnection) (*rtpDownConnection, error) {
@ -124,14 +115,35 @@ func newDownConn(id string, remote upConnection) (*rtpDownConnection, error) {
}) })
conn := &rtpDownConnection{ conn := &rtpDownConnection{
id: id, id: id,
pc: pc, pc: pc,
remote: remote, remote: remote,
maxREMBBitrate: new(bitrate),
} }
return conn, nil return conn, nil
} }
func (down *rtpDownConnection) GetMaxBitrate(now uint64) uint64 {
rate := down.maxREMBBitrate.Get(now)
var trackRate uint64
for _, t := range down.tracks {
r := t.maxBitrate.Get(now)
if r == ^uint64(0) {
if t.track.Kind() == webrtc.RTPCodecTypeAudio {
r = 128 * 1024
} else {
r = 512 * 1024
}
}
trackRate += r
}
if trackRate < rate {
return trackRate
}
return rate
}
func (down *rtpDownConnection) addICECandidate(candidate *webrtc.ICECandidateInit) error { func (down *rtpDownConnection) addICECandidate(candidate *webrtc.ICECandidateInit) error {
if down.pc.RemoteDescription() != nil { if down.pc.RemoteDescription() != nil {
return down.pc.AddICECandidate(*candidate) return down.pc.AddICECandidate(*candidate)
@ -162,15 +174,14 @@ func (down *rtpDownConnection) flushICECandidates() error {
} }
type rtpUpTrack struct { type rtpUpTrack struct {
track *webrtc.Track track *webrtc.Track
label string label string
rate *estimator.Estimator rate *estimator.Estimator
cache *packetcache.Cache cache *packetcache.Cache
jitter *jitter.Estimator jitter *jitter.Estimator
maxBitrate uint64 lastPLI uint64
lastPLI uint64 lastFIR uint64
lastFIR uint64 firSeqno uint32
firSeqno uint32
localCh chan localTrackAction localCh chan localTrackAction
writerDone chan struct{} writerDone chan struct{}
@ -422,7 +433,6 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) {
cache: packetcache.New(32), cache: packetcache.New(32),
rate: estimator.New(time.Second), rate: estimator.New(time.Second),
jitter: jitter.New(remote.Codec().ClockRate), jitter: jitter.New(remote.Codec().ClockRate),
maxBitrate: ^uint64(0),
localCh: make(chan localTrackAction, 2), localCh: make(chan localTrackAction, 2),
writerDone: make(chan struct{}), writerDone: make(chan struct{}),
} }
@ -690,15 +700,6 @@ func sendFIR(pc *webrtc.PeerConnection, ssrc uint32, seqno uint8) error {
}) })
} }
func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error {
return pc.WriteRTCP([]rtcp.Packet{
&rtcp.ReceiverEstimatedMaximumBitrate{
Bitrate: bitrate,
SSRCs: []uint32{ssrc},
},
})
}
func (up *rtpUpConnection) sendNACK(track *rtpUpTrack, first uint16, bitmap uint16) error { func (up *rtpUpConnection) sendNACK(track *rtpUpTrack, first uint16, bitmap uint16) error {
if !track.hasRtcpFb("nack", "") { if !track.hasRtcpFb("nack", "") {
return nil return nil
@ -797,7 +798,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei
} }
} }
func sendRR(conn *rtpUpConnection) error { func sendUpRTCP(conn *rtpUpConnection) error {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@ -813,6 +814,7 @@ func sendRR(conn *rtpUpConnection) error {
reports := make([]rtcp.ReceptionReport, 0, len(conn.tracks)) reports := make([]rtcp.ReceptionReport, 0, len(conn.tracks))
for _, t := range conn.tracks { for _, t := range conn.tracks {
updateUpTrack(t)
expected, lost, totalLost, eseqno := t.cache.GetStats(true) expected, lost, totalLost, eseqno := t.cache.GetStats(true)
if expected == 0 { if expected == 0 {
expected = 1 expected = 1
@ -843,17 +845,46 @@ func sendRR(conn *rtpUpConnection) error {
}) })
} }
return conn.pc.WriteRTCP([]rtcp.Packet{ packets := []rtcp.Packet{
&rtcp.ReceiverReport{ &rtcp.ReceiverReport{
Reports: reports, Reports: reports,
}, },
}) }
rate := ^uint64(0)
for _, l := range conn.local {
r := l.GetMaxBitrate(now)
if r < rate {
rate = r
}
}
if rate < minBitrate {
rate = minBitrate
}
var ssrcs []uint32
for _, t := range conn.tracks {
if t.hasRtcpFb("goog-remb", "") {
continue
}
ssrcs = append(ssrcs, t.track.SSRC())
}
if len(ssrcs) > 0 {
packets = append(packets,
&rtcp.ReceiverEstimatedMaximumBitrate{
Bitrate: rate,
SSRCs: ssrcs,
},
)
}
return conn.pc.WriteRTCP(packets)
} }
func rtcpUpSender(conn *rtpUpConnection) { func rtcpUpSender(conn *rtpUpConnection) {
for { for {
time.Sleep(time.Second) time.Sleep(time.Second)
err := sendRR(conn) err := sendUpRTCP(conn)
if err != nil { if err != nil {
if err == io.EOF || err == io.ErrClosedPipe { if err == io.EOF || err == io.ErrClosedPipe {
return return
@ -936,7 +967,7 @@ const (
) )
func (track *rtpDownTrack) updateRate(loss uint8, now uint64) { func (track *rtpDownTrack) updateRate(loss uint8, now uint64) {
rate := track.maxLossBitrate.Get(now) rate := track.maxBitrate.Get(now)
if rate < minLossRate || rate > maxLossRate { if rate < minLossRate || rate > maxLossRate {
// no recent feedback, reset // no recent feedback, reset
rate = initLossRate rate = initLossRate
@ -962,7 +993,7 @@ func (track *rtpDownTrack) updateRate(loss uint8, now uint64) {
} }
// update unconditionally, to set the timestamp // update unconditionally, to set the timestamp
track.maxLossBitrate.Set(rate, now) track.maxBitrate.Set(rate, now)
} }
func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RTPSender) { func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RTPSender) {
@ -1034,7 +1065,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT
log.Printf("sendFIR: %v", err) log.Printf("sendFIR: %v", err)
} }
case *rtcp.ReceiverEstimatedMaximumBitrate: case *rtcp.ReceiverEstimatedMaximumBitrate:
track.maxREMBBitrate.Set(p.Bitrate, jiffies) conn.maxREMBBitrate.Set(p.Bitrate, jiffies)
case *rtcp.ReceiverReport: case *rtcp.ReceiverReport:
for _, r := range p.Reports { for _, r := range p.Reports {
if r.SSRC == track.track.SSRC() { if r.SSRC == track.track.SSRC() {
@ -1048,11 +1079,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT
} }
} }
case *rtcp.TransportLayerNack: case *rtcp.TransportLayerNack:
maxBitrate := track.GetMaxBitrate(jiffies) sendRecovery(p, track)
bitrate, _ := track.rate.Estimate()
if uint64(bitrate)*7/8 < maxBitrate {
sendRecovery(p, track)
}
} }
} }
} }
@ -1086,37 +1113,18 @@ func handleReport(track *rtpDownTrack, report rtcp.ReceptionReport, jiffies uint
} }
} }
func updateUpTrack(track *rtpUpTrack) uint64 { func updateUpTrack(track *rtpUpTrack) {
now := rtptime.Jiffies() now := rtptime.Jiffies()
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
clockrate := track.track.Codec().ClockRate clockrate := track.track.Codec().ClockRate
minrate := uint64(minAudioRate)
rate := ^uint64(0)
if isvideo {
minrate = minVideoRate
rate = ^uint64(0)
}
local := track.getLocal() local := track.getLocal()
var maxrto uint64 var maxrto uint64
for _, l := range local { for _, l := range local {
bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) {
continue
}
if bitrate <= minrate {
rate = minrate
break
}
if rate > bitrate {
rate = bitrate
}
ll, ok := l.(*rtpDownTrack) ll, ok := l.(*rtpDownTrack)
if ok { if ok {
_, j := ll.stats.Get(now) _, j := ll.stats.Get(now)
jitter := uint64(j) * jitter := uint64(j) *
(rtptime.JiffiesPerSec / (rtptime.JiffiesPerSec / uint64(clockrate))
uint64(clockrate))
rtt := atomic.LoadUint64(&ll.rtt) rtt := atomic.LoadUint64(&ll.rtt)
rto := rtt + 4*jitter rto := rtt + 4*jitter
if rto > maxrto { if rto > maxrto {
@ -1124,7 +1132,6 @@ func updateUpTrack(track *rtpUpTrack) uint64 {
} }
} }
} }
track.maxBitrate = rate
_, r := track.rate.Estimate() _, r := track.rate.Estimate()
packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec) packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec)
if packets < 32 { if packets < 32 {
@ -1134,6 +1141,4 @@ func updateUpTrack(track *rtpUpTrack) uint64 {
packets = 256 packets = 256
} }
track.cache.ResizeCond(packets) track.cache.ResizeCond(packets)
return rate
} }

View file

@ -380,12 +380,11 @@ func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack upTrack, re
} }
track := &rtpDownTrack{ track := &rtpDownTrack{
track: local, track: local,
remote: remoteTrack, remote: remoteTrack,
maxLossBitrate: new(bitrate), maxBitrate: new(bitrate),
maxREMBBitrate: new(bitrate), stats: new(receiverStats),
stats: new(receiverStats), rate: estimator.New(time.Second),
rate: estimator.New(time.Second),
} }
conn.tracks = append(conn.tracks, track) conn.tracks = append(conn.tracks, track)
@ -692,10 +691,8 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
readTime := time.Now() readTime := time.Now()
ticker := time.NewTicker(time.Second) ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop() defer ticker.Stop()
slowTicker := time.NewTicker(10 * time.Second)
defer slowTicker.Stop()
for { for {
select { select {
@ -766,7 +763,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
go a.c.pushConn(u.id, u, ts, u.label) go a.c.pushConn(u.id, u, ts, u.label)
} }
case connectionFailedAction: case connectionFailedAction:
down := getDownConn(c, a.id); down := getDownConn(c, a.id)
if down == nil { if down == nil {
log.Printf("Failed indication for " + log.Printf("Failed indication for " +
"unknown connection") "unknown connection")
@ -804,8 +801,6 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
return errors.New("unexpected action") return errors.New("unexpected action")
} }
case <-ticker.C: case <-ticker.C:
sendRateUpdate(c)
case <-slowTicker.C:
if time.Since(readTime) > 90*time.Second { if time.Since(readTime) > 90*time.Second {
return errors.New("client is dead") return errors.New("client is dead")
} }
@ -1022,27 +1017,6 @@ func handleClientMessage(c *webClient, m clientMessage) error {
return nil return nil
} }
func sendRateUpdate(c *webClient) {
up := getUpConns(c)
for _, u := range up {
tracks := u.getTracks()
for _, t := range tracks {
rate := updateUpTrack(t)
if !t.hasRtcpFb("goog-remb", "") {
continue
}
if rate == ^uint64(0) {
continue
}
err := sendREMB(u.pc, t.track.SSRC(), rate)
if err != nil {
log.Printf("sendREMB: %v", err)
}
}
}
}
func clientReader(conn *websocket.Conn, read chan<- interface{}, done <-chan struct{}) { func clientReader(conn *websocket.Conn, read chan<- interface{}, done <-chan struct{}) {
defer close(read) defer close(read)
for { for {

View file

@ -171,16 +171,26 @@ func statsHandler(w http.ResponseWriter, r *http.Request) {
for _, cs := range gs.clients { for _, cs := range gs.clients {
fmt.Fprintf(w, "<tr><td>%v</td></tr>\n", cs.id) fmt.Fprintf(w, "<tr><td>%v</td></tr>\n", cs.id)
for _, up := range cs.up { for _, up := range cs.up {
fmt.Fprintf(w, "<tr><td></td><td>Up</td><td>%v</td></tr>\n", fmt.Fprintf(w, "<tr><td></td><td>Up</td><td>%v</td>",
up.id) up.id)
if up.maxBitrate > 0 {
fmt.Fprintf(w, "<td>%v</td>",
up.maxBitrate)
}
fmt.Fprintf(w, "</tr>\n")
for _, t := range up.tracks { for _, t := range up.tracks {
printTrack(w, t) printTrack(w, t)
} }
} }
for _, up := range cs.down { for _, down := range cs.down {
fmt.Fprintf(w, "<tr><td></td><td>Down</td><td> %v</td></tr>\n", fmt.Fprintf(w, "<tr><td></td><td>Down</td><td> %v</td>",
up.id) down.id)
for _, t := range up.tracks { if down.maxBitrate > 0 {
fmt.Fprintf(w, "<td>%v</td>",
down.maxBitrate)
}
fmt.Fprintf(w, "</tr>\n")
for _, t := range down.tracks {
printTrack(w, t) printTrack(w, t)
} }
} }