1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-09 18:25:58 +01:00

Protect upConn.tracks by the upConn mutex rather than the client mutex.

Also don't rely on tracks being immutable in sendRR.
This commit is contained in:
Juliusz Chroboczek 2020-06-08 19:10:08 +02:00
parent 8ba50bd2ca
commit da97560cb3
3 changed files with 104 additions and 107 deletions

11
conn.go
View file

@ -107,17 +107,25 @@ type upConnection struct {
id string
label string
pc *webrtc.PeerConnection
tracks []*upTrack
labels map[string]string
iceCandidates []*webrtc.ICECandidateInit
mu sync.Mutex
closed bool
tracks []*upTrack
local []downConnection
}
var ErrConnectionClosed = errors.New("connection is closed")
func (up *upConnection) getTracks() []*upTrack {
up.mu.Lock()
defer up.mu.Unlock()
tracks := make([]*upTrack, len(up.tracks))
copy(tracks, up.tracks)
return tracks
}
func (up *upConnection) addLocal(local downConnection) error {
up.mu.Lock()
defer up.mu.Unlock()
@ -206,6 +214,7 @@ func getUpMid(pc *webrtc.PeerConnection, track *webrtc.Track) string {
return ""
}
// called locked
func (up *upConnection) complete() bool {
for mid, _ := range up.labels {
found := false

View file

@ -565,7 +565,8 @@ func getClientStats(c *webClient) clientStats {
for _, up := range c.up {
conns := connStats{id: up.id}
for _, t := range up.tracks {
tracks := up.getTracks()
for _, t := range tracks {
expected, lost, _, _ := t.cache.GetStats(false)
if expected == 0 {
expected = 1

View file

@ -282,12 +282,12 @@ func getUpConn(c *webClient, id string) *upConnection {
return conn
}
func getUpConns(c *webClient) []string {
func getUpConns(c *webClient) []*upConnection {
c.mu.Lock()
defer c.mu.Unlock()
up := make([]string, 0, len(c.up))
for id := range c.up {
up = append(up, id)
up := make([]*upConnection, 0, len(c.up))
for _, u := range c.up {
up = append(up, u)
}
return up
}
@ -337,22 +337,16 @@ func addUpConn(c *webClient, id string) (*upConnection, error) {
})
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
c.mu.Lock()
u, ok := c.up[id]
if !ok {
log.Printf("Unknown connection")
c.mu.Unlock()
return
}
conn.mu.Lock()
defer conn.mu.Unlock()
mid := getUpMid(pc, remote)
if mid == "" {
log.Printf("Couldn't get track's mid")
c.mu.Unlock()
return
}
label, ok := u.labels[mid]
label, ok := conn.labels[mid]
if !ok {
log.Printf("Couldn't get track's label")
isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo
@ -373,25 +367,24 @@ func addUpConn(c *webClient, id string) (*upConnection, error) {
localCh: make(chan localTrackAction, 2),
writerDone: make(chan struct{}),
}
u.tracks = append(u.tracks, track)
var tracks []*upTrack
if u.complete() {
tracks = make([]*upTrack, len(u.tracks))
copy(tracks, u.tracks)
}
conn.tracks = append(conn.tracks, track)
if remote.Kind() == webrtc.RTPCodecTypeVideo {
atomic.AddUint32(&c.group.videoCount, 1)
}
c.mu.Unlock()
go readLoop(conn, track)
go rtcpUpListener(conn, track, receiver)
if tracks != nil {
if conn.complete() {
// cannot call getTracks, we're locked
tracks := make([]*upTrack, len(conn.tracks))
copy(tracks, conn.tracks)
clients := c.group.getClients(c)
for _, cc := range clients {
cc.pushConn(u, tracks, u.label)
cc.pushConn(conn, tracks, conn.label)
}
go rtcpUpSender(conn)
}
@ -573,7 +566,7 @@ func rtcpUpListener(conn *upConnection, track *upTrack, r *webrtc.RTPReceiver) {
}
}
if(firstSR) {
if firstSR {
// this is the first SR we got for at least one track,
// quickly propagate the time offsets downstream
local := conn.getLocal()
@ -591,6 +584,9 @@ func rtcpUpListener(conn *upConnection, track *upTrack, r *webrtc.RTPReceiver) {
}
func sendRR(conn *upConnection) error {
conn.mu.Lock()
defer conn.mu.Unlock()
if len(conn.tracks) == 0 {
return nil
}
@ -650,6 +646,8 @@ func rtcpUpSender(conn *upConnection) {
}
func sendSR(conn *rtpDownConnection) error {
// since this is only called after all tracks have been created,
// there is no need for locking.
packets := make([]rtcp.Packet, 0, len(conn.tracks))
now := time.Now()
@ -716,17 +714,19 @@ func rtcpDownSender(conn *rtpDownConnection) {
func delUpConn(c *webClient, id string) bool {
c.mu.Lock()
defer c.mu.Unlock()
if c.up == nil {
c.mu.Unlock()
return false
}
conn := c.up[id]
if conn == nil {
c.mu.Unlock()
return false
}
delete(c.up, id)
c.mu.Unlock()
conn.mu.Lock()
for _, track := range conn.tracks {
if track.track.Kind() == webrtc.RTPCodecTypeVideo {
count := atomic.AddUint32(&c.group.videoCount,
@ -737,9 +737,9 @@ func delUpConn(c *webClient, id string) bool {
}
}
}
conn.mu.Unlock()
conn.Close()
delete(c.up, id)
return true
}
@ -1007,59 +1007,59 @@ func handleReport(track *rtpDownTrack, report rtcp.ReceptionReport, jiffies uint
}
}
func updateUpTrack(up *upConnection, maxVideoRate uint64) {
func updateUpTrack(track *upTrack, maxVideoRate uint64) uint64 {
now := rtptime.Jiffies()
for _, track := range up.tracks {
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
clockrate := track.track.Codec().ClockRate
minrate := uint64(minAudioRate)
rate := ^uint64(0)
if isvideo {
minrate = minVideoRate
rate = maxVideoRate
if rate < minrate {
rate = minrate
}
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
clockrate := track.track.Codec().ClockRate
minrate := uint64(minAudioRate)
rate := ^uint64(0)
if isvideo {
minrate = minVideoRate
rate = maxVideoRate
if rate < minrate {
rate = minrate
}
local := track.getLocal()
var maxrto uint64
for _, l := range local {
bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) {
continue
}
if bitrate <= minrate {
rate = minrate
break
}
if rate > bitrate {
rate = bitrate
}
ll, ok := l.(*rtpDownTrack)
if ok {
_, j := ll.stats.Get(now)
jitter := uint64(j) *
(rtptime.JiffiesPerSec /
uint64(clockrate))
rtt := atomic.LoadUint64(&ll.rtt)
rto := rtt + 4*jitter
if rto > maxrto {
maxrto = rto
}
}
}
track.maxBitrate = rate
_, r := track.rate.Estimate()
packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec)
if packets < 32 {
packets = 32
}
if packets > 256 {
packets = 256
}
track.cache.ResizeCond(packets)
}
local := track.getLocal()
var maxrto uint64
for _, l := range local {
bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) {
continue
}
if bitrate <= minrate {
rate = minrate
break
}
if rate > bitrate {
rate = bitrate
}
ll, ok := l.(*rtpDownTrack)
if ok {
_, j := ll.stats.Get(now)
jitter := uint64(j) *
(rtptime.JiffiesPerSec /
uint64(clockrate))
rtt := atomic.LoadUint64(&ll.rtt)
rto := rtt + 4*jitter
if rto > maxrto {
maxrto = rto
}
}
}
track.maxBitrate = rate
_, r := track.rate.Estimate()
packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec)
if packets < 32 {
packets = 32
}
if packets > 256 {
packets = 256
}
track.cache.ResizeCond(packets)
return rate
}
var ErrUnsupportedFeedback = errors.New("unsupported feedback type")
@ -1468,8 +1468,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
})
case pushConnsAction:
for _, u := range c.up {
tracks := make([]*upTrack, len(u.tracks))
copy(tracks, u.tracks)
tracks := u.getTracks()
go a.c.pushConn(u, tracks, u.label)
}
case connectionFailedAction:
@ -1490,12 +1489,12 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
Permissions: c.permissions,
})
if !c.permissions.Present {
ids := getUpConns(c)
for _, id := range ids {
found := delUpConn(c, id)
up := getUpConns(c)
for _, u := range up {
found := delUpConn(c, u.id)
if found {
failConnection(
c, id,
c, u.id,
"permission denied",
)
}
@ -1727,13 +1726,6 @@ func handleClientMessage(c *webClient, m clientMessage) error {
}
func sendRateUpdate(c *webClient) {
type remb struct {
pc *webrtc.PeerConnection
ssrc uint32
bitrate uint64
}
rembs := make([]remb, 0)
maxVideoRate := ^uint64(0)
count := atomic.LoadUint32(&c.group.videoCount)
if count >= 3 {
@ -1743,27 +1735,22 @@ func sendRateUpdate(c *webClient) {
}
}
c.mu.Lock()
for _, u := range c.up {
updateUpTrack(u, maxVideoRate)
for _, t := range u.tracks {
up := getUpConns(c)
for _, u := range up {
tracks := u.getTracks()
for _, t := range tracks {
rate := updateUpTrack(t, maxVideoRate)
if !t.hasRtcpFb("goog-remb", "") {
continue
}
bitrate := t.maxBitrate
if bitrate == ^uint64(0) {
if rate == ^uint64(0) {
continue
}
rembs = append(rembs,
remb{u.pc, t.track.SSRC(), bitrate})
}
}
c.mu.Unlock()
for _, r := range rembs {
err := sendREMB(r.pc, r.ssrc, r.bitrate)
if err != nil {
log.Printf("sendREMB: %v", err)
err := sendREMB(u.pc, t.track.SSRC(), rate)
if err != nil {
log.Printf("sendREMB: %v", err)
}
}
}
}