diff --git a/client.go b/client.go index 2dd6ae6..a300b89 100644 --- a/client.go +++ b/client.go @@ -10,7 +10,6 @@ import ( "errors" "io" "log" - "math" "os" "strings" "sync" @@ -98,8 +97,6 @@ type clientMessage struct { Answer *webrtc.SessionDescription `json:"answer,omitempty"` Candidate *webrtc.ICECandidateInit `json:"candidate,omitempty"` Del bool `json:"del,omitempty"` - AudioRate int `json:"audiorate,omitempty"` - VideoRate int `json:"videorate,omitempty"` } type closeMessage struct { @@ -282,8 +279,9 @@ func addUpConn(c *client, id string) (*upConnection, error) { return } u.pairs = append(u.pairs, trackPair{ - remote: remote, - local: local, + remote: remote, + local: local, + maxBitrate: ^uint64(0), }) done := len(u.pairs) >= u.trackCount c.group.mu.Unlock() @@ -442,12 +440,24 @@ func addDownTrack(c *client, id string, track *webrtc.Track, remote *upConnectio return nil, nil, err } - go rtcpListener(c.group, conn, s) + conn.tracks = append(conn.tracks, + downTrack{track.SSRC(), new(timeStampedBitrate)}, + ) + + go rtcpListener(c.group, conn, s, + conn.tracks[len(conn.tracks)-1].maxBitrate) return conn, s, nil } -func rtcpListener(g *group, c *downConnection, s *webrtc.RTPSender) { +var epoch = time.Now() + +func msSinceEpoch() uint64 { + return uint64(time.Since(epoch) / time.Millisecond) +} + +func rtcpListener(g *group, c *downConnection, s *webrtc.RTPSender, + bitrate *timeStampedBitrate) { for { ps, err := s.ReadRTCP() if err != nil { @@ -460,16 +470,23 @@ func rtcpListener(g *group, c *downConnection, s *webrtc.RTPSender) { for _, p := range ps { switch p := p.(type) { case *rtcp.PictureLossIndication: - err := sendPLI(c.remote, p.MediaSSRC) + err := sendPLI(c.remote.pc, p.MediaSSRC) if err != nil { log.Printf("sendPLI: %v", err) } case *rtcp.ReceiverEstimatedMaximumBitrate: - bitrate := uint32(math.MaxInt32) - if p.Bitrate < math.MaxInt32 { - bitrate = uint32(p.Bitrate) - } - atomic.StoreUint32(&c.maxBitrate, bitrate) + ms := msSinceEpoch() + // this is racy -- a reader might read the + // data between the two writes. This shouldn't + // matter, we'll recover at the next sample. + atomic.StoreUint64( + &bitrate.bitrate, + p.Bitrate, + ) + atomic.StoreUint64( + &bitrate.timestamp, + uint64(ms), + ) case *rtcp.ReceiverReport: default: log.Printf("RTCP: %T", p) @@ -520,42 +537,61 @@ func splitBitrate(bitrate uint32, audio, video bool) (uint32, uint32) { return audioRate, bitrate - audioRate } -func updateBitrate(g *group, up *upConnection) (uint32, uint32) { - audio := uint32(math.MaxInt32) - video := uint32(math.MaxInt32) +func updateUpBitrate(g *group, up *upConnection) { + for i := range up.pairs { + up.pairs[i].maxBitrate = ^uint64(0) + } + + now := msSinceEpoch() + g.Range(func(c *client) bool { for _, down := range c.down { if down.remote == up { - bitrate := atomic.LoadUint32(&down.maxBitrate) - if bitrate == 0 { - bitrate = 256000 - } else if bitrate < 6000 { - bitrate = 6000 - } - hasAudio, hasVideo := trackKinds(down) - a, v := splitBitrate(bitrate, hasAudio, hasVideo) - if a < audio { - audio = a - } - if v < video { - video = v + for _, dt := range down.tracks { + ms := atomic.LoadUint64( + &dt.maxBitrate.timestamp, + ) + bitrate := atomic.LoadUint64( + &dt.maxBitrate.bitrate, + ) + if bitrate == 0 { + continue + } + + if now - ms > 5000 { + continue + } + + for i, p := range up.pairs { + if p.local.SSRC() == dt.ssrc { + if p.maxBitrate > bitrate { + up.pairs[i].maxBitrate = bitrate + break + } + } + } } } } return true }) - up.maxAudioBitrate = audio - up.maxVideoBitrate = video - return audio, video } -func sendPLI(up *upConnection, ssrc uint32) error { - // we use equal SSRC values on both sides - return up.pc.WriteRTCP([]rtcp.Packet{ +func sendPLI(pc *webrtc.PeerConnection, ssrc uint32) error { + return pc.WriteRTCP([]rtcp.Packet{ &rtcp.PictureLossIndication{MediaSSRC: ssrc}, }) } +func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error { + return pc.WriteRTCP([]rtcp.Packet{ + &rtcp.ReceiverEstimatedMaximumBitrate{ + Bitrate: bitrate, + SSRCs: []uint32{ssrc}, + }, + }) +} + func countMediaStreams(data string) (int, error) { desc := sdp.NewJSEPSessionDescription(false) err := desc.Unmarshal(data) @@ -709,7 +745,7 @@ func clientLoop(c *client, conn *websocket.Conn) error { readTime := time.Now() - ticker := time.NewTicker(2 * time.Second) + ticker := time.NewTicker(time.Second) defer ticker.Stop() slowTicker := time.NewTicker(10 * time.Second) defer slowTicker.Stop() @@ -888,16 +924,19 @@ func handleClientMessage(c *client, m clientMessage) error { func sendRateUpdate(c *client) { for _, u := range c.up { - oldaudio := u.maxAudioBitrate - oldvideo := u.maxVideoBitrate - audio, video := updateBitrate(c.group, u) - if audio != oldaudio || video != oldvideo { - c.write(clientMessage{ - Type: "maxbitrate", - Id: u.id, - AudioRate: int(audio), - VideoRate: int(video), - }) + updateUpBitrate(c.group, u) + for _, p := range u.pairs { + bitrate := p.maxBitrate + if bitrate != ^uint64(0) { + if bitrate < 6000 { + bitrate = 6000 + } + err := sendREMB(u.pc, p.remote.SSRC(), + uint64(bitrate)) + if err != nil { + log.Printf("sendREMB: %v", err) + } + } } } } diff --git a/group.go b/group.go index ed613d0..17173a2 100644 --- a/group.go +++ b/group.go @@ -19,23 +19,31 @@ import ( type trackPair struct { remote, local *webrtc.Track + maxBitrate uint64 } type upConnection struct { - id string - label string - pc *webrtc.PeerConnection - maxAudioBitrate uint32 - maxVideoBitrate uint32 - trackCount int - pairs []trackPair + id string + label string + pc *webrtc.PeerConnection + trackCount int + pairs []trackPair +} + +type timeStampedBitrate struct { + bitrate uint64 + timestamp uint64 +} +type downTrack struct { + ssrc uint32 + maxBitrate *timeStampedBitrate } type downConnection struct { - id string - pc *webrtc.PeerConnection - remote *upConnection - maxBitrate uint32 + id string + pc *webrtc.PeerConnection + remote *upConnection + tracks []downTrack } type client struct { diff --git a/static/sfu.js b/static/sfu.js index 3bbf6c1..abe4211 100644 --- a/static/sfu.js +++ b/static/sfu.js @@ -323,9 +323,6 @@ function serverConnect() { case 'ice': gotICE(m.id, m.candidate); break; - case 'maxbitrate': - setMaxBitrate(m.id, m.audiorate, m.videorate); - break; case 'label': gotLabel(m.id, m.value); break; @@ -450,49 +447,6 @@ async function gotICE(id, candidate) { conn.iceCandidates.push(candidate) } -let maxaudiorate, maxvideorate; - -async function setMaxBitrate(id, audio, video) { - let conn = up[id]; - if(!conn) - throw new Error("Setting bitrate of unknown id"); - - let senders = conn.pc.getSenders(); - for(let i = 0; i < senders.length; i++) { - let s = senders[i]; - if(!s.track) - return; - let p = s.getParameters(); - let bitrate; - if(s.track.kind == 'audio') - bitrate = audio; - else if(s.track.kind == 'video') - bitrate = video; - for(let j = 0; j < p.encodings.length; j++) { - let e = p.encodings[j]; - if(bitrate) - e.maxBitrate = bitrate; - else - delete(e.maxBitrate); - await s.setParameters(p); - } - } - - if((audio && audio < 128000) || (video && video < 256000)) { - let l = ''; - if(audio) - l = `${Math.round(audio/1000)}kbps` - if(video) { - if(l) - l = l + ' + '; - l = l + `${Math.round(video/1000)}kbps` - } - setLabel(id, l) - } else { - setLabel(id); - } -} - async function addIceCandidates(conn) { let promises = [] conn.iceCandidates.forEach(c => {