1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-26 02:25:58 +01:00

Protect upConn.tracks by the upConn mutex rather than the client mutex.

Also don't rely on tracks being immutable in sendRR.
This commit is contained in:
Juliusz Chroboczek 2020-06-08 19:10:08 +02:00
parent 8ba50bd2ca
commit da97560cb3
3 changed files with 104 additions and 107 deletions

11
conn.go
View file

@ -107,17 +107,25 @@ type upConnection struct {
id string id string
label string label string
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
tracks []*upTrack
labels map[string]string labels map[string]string
iceCandidates []*webrtc.ICECandidateInit iceCandidates []*webrtc.ICECandidateInit
mu sync.Mutex mu sync.Mutex
closed bool closed bool
tracks []*upTrack
local []downConnection local []downConnection
} }
var ErrConnectionClosed = errors.New("connection is closed") var ErrConnectionClosed = errors.New("connection is closed")
func (up *upConnection) getTracks() []*upTrack {
up.mu.Lock()
defer up.mu.Unlock()
tracks := make([]*upTrack, len(up.tracks))
copy(tracks, up.tracks)
return tracks
}
func (up *upConnection) addLocal(local downConnection) error { func (up *upConnection) addLocal(local downConnection) error {
up.mu.Lock() up.mu.Lock()
defer up.mu.Unlock() defer up.mu.Unlock()
@ -206,6 +214,7 @@ func getUpMid(pc *webrtc.PeerConnection, track *webrtc.Track) string {
return "" return ""
} }
// called locked
func (up *upConnection) complete() bool { func (up *upConnection) complete() bool {
for mid, _ := range up.labels { for mid, _ := range up.labels {
found := false found := false

View file

@ -565,7 +565,8 @@ func getClientStats(c *webClient) clientStats {
for _, up := range c.up { for _, up := range c.up {
conns := connStats{id: up.id} conns := connStats{id: up.id}
for _, t := range up.tracks { tracks := up.getTracks()
for _, t := range tracks {
expected, lost, _, _ := t.cache.GetStats(false) expected, lost, _, _ := t.cache.GetStats(false)
if expected == 0 { if expected == 0 {
expected = 1 expected = 1

View file

@ -282,12 +282,12 @@ func getUpConn(c *webClient, id string) *upConnection {
return conn return conn
} }
func getUpConns(c *webClient) []string { func getUpConns(c *webClient) []*upConnection {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
up := make([]string, 0, len(c.up)) up := make([]*upConnection, 0, len(c.up))
for id := range c.up { for _, u := range c.up {
up = append(up, id) up = append(up, u)
} }
return up return up
} }
@ -337,22 +337,16 @@ func addUpConn(c *webClient, id string) (*upConnection, error) {
}) })
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) { pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
c.mu.Lock() conn.mu.Lock()
u, ok := c.up[id] defer conn.mu.Unlock()
if !ok {
log.Printf("Unknown connection")
c.mu.Unlock()
return
}
mid := getUpMid(pc, remote) mid := getUpMid(pc, remote)
if mid == "" { if mid == "" {
log.Printf("Couldn't get track's mid") log.Printf("Couldn't get track's mid")
c.mu.Unlock()
return return
} }
label, ok := u.labels[mid] label, ok := conn.labels[mid]
if !ok { if !ok {
log.Printf("Couldn't get track's label") log.Printf("Couldn't get track's label")
isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo
@ -373,25 +367,24 @@ func addUpConn(c *webClient, id string) (*upConnection, error) {
localCh: make(chan localTrackAction, 2), localCh: make(chan localTrackAction, 2),
writerDone: make(chan struct{}), writerDone: make(chan struct{}),
} }
u.tracks = append(u.tracks, track)
var tracks []*upTrack conn.tracks = append(conn.tracks, track)
if u.complete() {
tracks = make([]*upTrack, len(u.tracks))
copy(tracks, u.tracks)
}
if remote.Kind() == webrtc.RTPCodecTypeVideo { if remote.Kind() == webrtc.RTPCodecTypeVideo {
atomic.AddUint32(&c.group.videoCount, 1) atomic.AddUint32(&c.group.videoCount, 1)
} }
c.mu.Unlock()
go readLoop(conn, track) go readLoop(conn, track)
go rtcpUpListener(conn, track, receiver) go rtcpUpListener(conn, track, receiver)
if tracks != nil { if conn.complete() {
// cannot call getTracks, we're locked
tracks := make([]*upTrack, len(conn.tracks))
copy(tracks, conn.tracks)
clients := c.group.getClients(c) clients := c.group.getClients(c)
for _, cc := range clients { for _, cc := range clients {
cc.pushConn(u, tracks, u.label) cc.pushConn(conn, tracks, conn.label)
} }
go rtcpUpSender(conn) go rtcpUpSender(conn)
} }
@ -573,7 +566,7 @@ func rtcpUpListener(conn *upConnection, track *upTrack, r *webrtc.RTPReceiver) {
} }
} }
if(firstSR) { if firstSR {
// this is the first SR we got for at least one track, // this is the first SR we got for at least one track,
// quickly propagate the time offsets downstream // quickly propagate the time offsets downstream
local := conn.getLocal() local := conn.getLocal()
@ -591,6 +584,9 @@ func rtcpUpListener(conn *upConnection, track *upTrack, r *webrtc.RTPReceiver) {
} }
func sendRR(conn *upConnection) error { func sendRR(conn *upConnection) error {
conn.mu.Lock()
defer conn.mu.Unlock()
if len(conn.tracks) == 0 { if len(conn.tracks) == 0 {
return nil return nil
} }
@ -650,6 +646,8 @@ func rtcpUpSender(conn *upConnection) {
} }
func sendSR(conn *rtpDownConnection) error { func sendSR(conn *rtpDownConnection) error {
// since this is only called after all tracks have been created,
// there is no need for locking.
packets := make([]rtcp.Packet, 0, len(conn.tracks)) packets := make([]rtcp.Packet, 0, len(conn.tracks))
now := time.Now() now := time.Now()
@ -716,17 +714,19 @@ func rtcpDownSender(conn *rtpDownConnection) {
func delUpConn(c *webClient, id string) bool { func delUpConn(c *webClient, id string) bool {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock()
if c.up == nil { if c.up == nil {
c.mu.Unlock()
return false return false
} }
conn := c.up[id] conn := c.up[id]
if conn == nil { if conn == nil {
c.mu.Unlock()
return false return false
} }
delete(c.up, id)
c.mu.Unlock()
conn.mu.Lock()
for _, track := range conn.tracks { for _, track := range conn.tracks {
if track.track.Kind() == webrtc.RTPCodecTypeVideo { if track.track.Kind() == webrtc.RTPCodecTypeVideo {
count := atomic.AddUint32(&c.group.videoCount, count := atomic.AddUint32(&c.group.videoCount,
@ -737,9 +737,9 @@ func delUpConn(c *webClient, id string) bool {
} }
} }
} }
conn.mu.Unlock()
conn.Close() conn.Close()
delete(c.up, id)
return true return true
} }
@ -1007,59 +1007,59 @@ func handleReport(track *rtpDownTrack, report rtcp.ReceptionReport, jiffies uint
} }
} }
func updateUpTrack(up *upConnection, maxVideoRate uint64) { func updateUpTrack(track *upTrack, maxVideoRate uint64) uint64 {
now := rtptime.Jiffies() now := rtptime.Jiffies()
for _, track := range up.tracks { isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo clockrate := track.track.Codec().ClockRate
clockrate := track.track.Codec().ClockRate minrate := uint64(minAudioRate)
minrate := uint64(minAudioRate) rate := ^uint64(0)
rate := ^uint64(0) if isvideo {
if isvideo { minrate = minVideoRate
minrate = minVideoRate rate = maxVideoRate
rate = maxVideoRate if rate < minrate {
if rate < minrate { rate = minrate
rate = minrate
}
} }
local := track.getLocal()
var maxrto uint64
for _, l := range local {
bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) {
continue
}
if bitrate <= minrate {
rate = minrate
break
}
if rate > bitrate {
rate = bitrate
}
ll, ok := l.(*rtpDownTrack)
if ok {
_, j := ll.stats.Get(now)
jitter := uint64(j) *
(rtptime.JiffiesPerSec /
uint64(clockrate))
rtt := atomic.LoadUint64(&ll.rtt)
rto := rtt + 4*jitter
if rto > maxrto {
maxrto = rto
}
}
}
track.maxBitrate = rate
_, r := track.rate.Estimate()
packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec)
if packets < 32 {
packets = 32
}
if packets > 256 {
packets = 256
}
track.cache.ResizeCond(packets)
} }
local := track.getLocal()
var maxrto uint64
for _, l := range local {
bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) {
continue
}
if bitrate <= minrate {
rate = minrate
break
}
if rate > bitrate {
rate = bitrate
}
ll, ok := l.(*rtpDownTrack)
if ok {
_, j := ll.stats.Get(now)
jitter := uint64(j) *
(rtptime.JiffiesPerSec /
uint64(clockrate))
rtt := atomic.LoadUint64(&ll.rtt)
rto := rtt + 4*jitter
if rto > maxrto {
maxrto = rto
}
}
}
track.maxBitrate = rate
_, r := track.rate.Estimate()
packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec)
if packets < 32 {
packets = 32
}
if packets > 256 {
packets = 256
}
track.cache.ResizeCond(packets)
return rate
} }
var ErrUnsupportedFeedback = errors.New("unsupported feedback type") var ErrUnsupportedFeedback = errors.New("unsupported feedback type")
@ -1468,8 +1468,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
}) })
case pushConnsAction: case pushConnsAction:
for _, u := range c.up { for _, u := range c.up {
tracks := make([]*upTrack, len(u.tracks)) tracks := u.getTracks()
copy(tracks, u.tracks)
go a.c.pushConn(u, tracks, u.label) go a.c.pushConn(u, tracks, u.label)
} }
case connectionFailedAction: case connectionFailedAction:
@ -1490,12 +1489,12 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
Permissions: c.permissions, Permissions: c.permissions,
}) })
if !c.permissions.Present { if !c.permissions.Present {
ids := getUpConns(c) up := getUpConns(c)
for _, id := range ids { for _, u := range up {
found := delUpConn(c, id) found := delUpConn(c, u.id)
if found { if found {
failConnection( failConnection(
c, id, c, u.id,
"permission denied", "permission denied",
) )
} }
@ -1727,13 +1726,6 @@ func handleClientMessage(c *webClient, m clientMessage) error {
} }
func sendRateUpdate(c *webClient) { func sendRateUpdate(c *webClient) {
type remb struct {
pc *webrtc.PeerConnection
ssrc uint32
bitrate uint64
}
rembs := make([]remb, 0)
maxVideoRate := ^uint64(0) maxVideoRate := ^uint64(0)
count := atomic.LoadUint32(&c.group.videoCount) count := atomic.LoadUint32(&c.group.videoCount)
if count >= 3 { if count >= 3 {
@ -1743,27 +1735,22 @@ func sendRateUpdate(c *webClient) {
} }
} }
c.mu.Lock() up := getUpConns(c)
for _, u := range c.up {
updateUpTrack(u, maxVideoRate) for _, u := range up {
for _, t := range u.tracks { tracks := u.getTracks()
for _, t := range tracks {
rate := updateUpTrack(t, maxVideoRate)
if !t.hasRtcpFb("goog-remb", "") { if !t.hasRtcpFb("goog-remb", "") {
continue continue
} }
bitrate := t.maxBitrate if rate == ^uint64(0) {
if bitrate == ^uint64(0) {
continue continue
} }
rembs = append(rembs, err := sendREMB(u.pc, t.track.SSRC(), rate)
remb{u.pc, t.track.SSRC(), bitrate}) if err != nil {
} log.Printf("sendREMB: %v", err)
} }
c.mu.Unlock()
for _, r := range rembs {
err := sendREMB(r.pc, r.ssrc, r.bitrate)
if err != nil {
log.Printf("sendREMB: %v", err)
} }
} }
} }