1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-22 16:45:58 +01:00

Implement renegotiation of down streams.

We used to destroy and recreate down streams whenever something changed,
which turned out to be racy.  We now properly implement renegotiation,
as well as atomic replacement of a stream by another one.
This commit is contained in:
Juliusz Chroboczek 2021-02-02 19:38:15 +01:00
parent 368da133fd
commit 263258a0c1
3 changed files with 148 additions and 150 deletions

View file

@ -17,7 +17,6 @@ type Up interface {
DelLocal(Down) bool DelLocal(Down) bool
Id() string Id() string
User() (string, string) User() (string, string)
Codecs() []webrtc.RTPCodecCapability
} }
// Type UpTrack represents a track in the client to server direction. // Type UpTrack represents a track in the client to server direction.

View file

@ -78,6 +78,7 @@ type downTrackAtomics struct {
type rtpDownTrack struct { type rtpDownTrack struct {
track *webrtc.TrackLocalStaticRTP track *webrtc.TrackLocalStaticRTP
sender *webrtc.RTPSender
remote conn.UpTrack remote conn.UpTrack
ssrc webrtc.SSRC ssrc webrtc.SSRC
maxBitrate *bitrate maxBitrate *bitrate
@ -156,7 +157,7 @@ func (down *rtpDownConnection) getTracks() []*rtpDownTrack {
} }
func newDownConn(c group.Client, id string, remote conn.Up) (*rtpDownConnection, error) { func newDownConn(c group.Client, id string, remote conn.Up) (*rtpDownConnection, error) {
api := group.APIFromCodecs(remote.Codecs()) api := c.Group().API()
pc, err := api.NewPeerConnection(*ice.ICEConfiguration()) pc, err := api.NewPeerConnection(*ice.ICEConfiguration())
if err != nil { if err != nil {
return nil, err return nil, err
@ -366,17 +367,6 @@ func (up *rtpUpConnection) User() (string, string) {
return up.userId, up.username return up.userId, up.username
} }
func (up *rtpUpConnection) Codecs() []webrtc.RTPCodecCapability {
up.mu.Lock()
defer up.mu.Unlock()
codecs := make([]webrtc.RTPCodecCapability, len(up.tracks))
for i := range up.tracks {
codecs[i] = up.tracks[i].Codec()
}
return codecs
}
func (up *rtpUpConnection) AddLocal(local conn.Down) error { func (up *rtpUpConnection) AddLocal(local conn.Down) error {
up.mu.Lock() up.mu.Lock()
defer up.mu.Unlock() defer up.mu.Unlock()

View file

@ -315,55 +315,50 @@ func getConn(c *webClient, id string) iceConnection {
return nil return nil
} }
func addDownConn(c *webClient, id string, remote conn.Up) (*rtpDownConnection, error) { func addDownConn(c *webClient, remote conn.Up) (*rtpDownConnection, bool, error) {
conn, err := newDownConn(c, id, remote) id := remote.Id()
if err != nil {
return nil, err
}
err = addDownConnHelper(c, conn, remote)
if err != nil {
conn.pc.Close()
return nil, err
}
return conn, err
}
func addDownConnHelper(c *webClient, conn *rtpDownConnection, remote conn.Up) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.up != nil && c.up[conn.id] != nil { if c.up != nil && c.up[id] != nil {
return errors.New("Adding duplicate connection") return nil, false, errors.New("adding duplicate connection")
} }
if c.down == nil { if c.down == nil {
c.down = make(map[string]*rtpDownConnection) c.down = make(map[string]*rtpDownConnection)
} }
old := c.down[conn.id] if down := c.down[id]; down != nil {
if old != nil { return down, false, nil
// Avoid calling Close under a lock
go old.pc.Close()
} }
conn.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { down, err := newDownConn(c, id, remote)
sendICE(c, conn.id, candidate)
})
conn.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
if state == webrtc.ICEConnectionStateFailed {
c.action(connectionFailedAction{id: conn.id})
}
})
err := remote.AddLocal(conn)
if err != nil { if err != nil {
return err return nil, false, err
} }
c.down[conn.id] = conn down.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
return nil sendICE(c, down.id, candidate)
})
down.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
if state == webrtc.ICEConnectionStateFailed {
c.action(connectionFailedAction{id: down.id})
}
})
err = remote.AddLocal(down)
if err != nil {
down.pc.Close()
return nil, false, err
}
c.down[down.id] = down
go rtcpDownSender(down)
return down, true, nil
} }
func delDownConn(c *webClient, id string) error { func delDownConn(c *webClient, id string) error {
@ -397,46 +392,40 @@ func delDownConnHelper(c *webClient, id string) *rtpDownConnection {
return conn return conn
} }
func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrack, remoteConn conn.Up) (*webrtc.RTPSender, error) { var errUnexpectedTrackType = errors.New("unexpected track type, this shouldn't happen")
rt, ok := remoteTrack.(*rtpUpTrack)
if !ok {
return nil, errors.New("unexpected up track type")
}
conn.mu.Lock() func addDownTrackUnlocked(conn *rtpDownConnection, remoteTrack *rtpUpTrack, remoteConn conn.Up) error {
defer conn.mu.Unlock()
remoteSSRC := rt.track.SSRC()
for _, t := range conn.tracks { for _, t := range conn.tracks {
tt, ok := t.remote.(*rtpUpTrack) tt, ok := t.remote.(*rtpUpTrack)
if !ok { if !ok {
return nil, errors.New("unexpected up track type") return errUnexpectedTrackType
} }
if tt.track.SSRC() == remoteSSRC { if tt == remoteTrack {
return nil, os.ErrExist return os.ErrExist
} }
} }
local, err := webrtc.NewTrackLocalStaticRTP( local, err := webrtc.NewTrackLocalStaticRTP(
remoteTrack.Codec(), remoteTrack.Codec(),
rt.track.ID(), rt.track.StreamID(), remoteTrack.track.ID(), remoteTrack.track.StreamID(),
) )
if err != nil { if err != nil {
return nil, err return err
} }
sender, err := conn.pc.AddTrack(local) sender, err := conn.pc.AddTrack(local)
if err != nil { if err != nil {
return nil, err return err
} }
parms := sender.GetParameters() parms := sender.GetParameters()
if len(parms.Encodings) != 1 { if len(parms.Encodings) != 1 {
return nil, errors.New("got multiple encodings") return errors.New("got multiple encodings")
} }
track := &rtpDownTrack{ track := &rtpDownTrack{
track: local, track: local,
sender: sender,
ssrc: parms.Encodings[0].SSRC, ssrc: parms.Encodings[0].SSRC,
remote: remoteTrack, remote: remoteTrack,
maxBitrate: new(bitrate), maxBitrate: new(bitrate),
@ -449,7 +438,79 @@ func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrac
go rtcpDownListener(conn, track, sender) go rtcpDownListener(conn, track, sender)
return sender, nil return nil
}
func delDownTrackUnlocked(conn *rtpDownConnection, track *rtpDownTrack) error {
for i := range conn.tracks {
if conn.tracks[i] == track {
track.remote.DelLocal(track)
conn.tracks =
append(conn.tracks[:i], conn.tracks[i+1:]...)
return conn.pc.RemoveTrack(track.sender)
}
}
return os.ErrNotExist
}
func replaceTracks(conn *rtpDownConnection, remote []conn.UpTrack, remoteConn conn.Up) error {
conn.mu.Lock()
defer conn.mu.Unlock()
var add []*rtpUpTrack
var del []*rtpDownTrack
outer:
for _, rtrack := range remote {
rt, ok := rtrack.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
for _, track := range conn.tracks {
rt2, ok := track.remote.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
if rt == rt2 {
continue outer
}
}
add = append(add, rt)
}
outer2:
for _, track := range conn.tracks {
rt, ok := track.remote.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
for _, rtrack := range remote {
rt2, ok := rtrack.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
if rt == rt2 {
continue outer2
}
}
del = append(del, track)
}
for _, t := range del {
err := delDownTrackUnlocked(conn, t)
if err != nil {
return err
}
}
for _, rt := range add {
err := addDownTrackUnlocked(conn, rt, remoteConn)
if err != nil {
return err
}
}
return nil
} }
func negotiate(c *webClient, down *rtpDownConnection, restartIce bool, replace string) error { func negotiate(c *webClient, down *rtpDownConnection, restartIce bool, replace string) error {
@ -618,20 +679,6 @@ func gotICE(c *webClient, candidate *webrtc.ICECandidateInit, id string) error {
} }
func (c *webClient) setRequested(requested map[string]uint32) error { func (c *webClient) setRequested(requested map[string]uint32) error {
if c.down != nil {
down := make([]string, 0, len(c.down))
for id := range c.down {
down = append(down, id)
}
for _, id := range down {
c.write(clientMessage{
Type: "close",
Id: id,
})
delDownConn(c, id)
}
}
c.requested = requested c.requested = requested
go pushConns(c, c.group) go pushConns(c, c.group)
@ -652,40 +699,6 @@ func (c *webClient) isRequested(label string) bool {
return c.requested[label] != 0 return c.requested[label] != 0
} }
func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rtpDownConnection, error) {
requested := false
for _, t := range tracks {
if c.isRequested(t.Label()) {
requested = true
break
}
}
if !requested {
delDownConn(c, remote.Id())
return nil, nil
}
down, err := addDownConn(c, remote.Id(), remote)
if err != nil {
return nil, err
}
for _, t := range tracks {
if !c.isRequested(t.Label()) {
continue
}
_, err = addDownTrack(c, down, t, remote)
if err != nil {
delDownConn(c, down.id)
return nil, err
}
}
go rtcpDownSender(down)
return down, nil
}
func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error { func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error {
err := c.action(pushConnAction{g, id, up, tracks, replace}) err := c.action(pushConnAction{g, id, up, tracks, replace})
if err != nil { if err != nil {
@ -823,32 +836,32 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
} }
if a.conn == nil { if a.conn == nil {
if a.replace != "" { if a.replace != "" {
err := delDownConn( closeDownConnection(
c, a.replace, c, a.replace, "",
) )
if err == nil {
c.write(clientMessage{
Type: "close",
Id: a.replace,
})
}
}
err := delDownConn(c, a.id)
if err == nil {
c.write(clientMessage{
Type: "close",
Id: a.id,
})
} }
closeDownConnection(c, a.id, "")
continue continue
} }
down, err := addDownConnTracks( tracks := make([]conn.UpTrack, 0, len(a.tracks))
c, a.conn, a.tracks, for _, t := range a.tracks {
) if c.isRequested(t.Label()) {
tracks = append(tracks, t)
}
}
if len(tracks) == 0 {
closeDownConnection(c, a.id, "")
continue
}
down, _, err := addDownConn(c, a.conn)
if err != nil {
return err
}
err = replaceTracks(down, tracks, a.conn)
if err != nil { if err != nil {
return err return err
} }
if down != nil {
err = negotiate( err = negotiate(
c, down, false, a.replace, c, down, false, a.replace,
) )
@ -862,7 +875,6 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
)) ))
continue continue
} }
}
case pushConnsAction: case pushConnsAction:
g := c.group g := c.group
if g == nil || a.group != g { if g == nil || a.group != g {
@ -1011,7 +1023,11 @@ func leaveGroup(c *webClient) {
c.group = nil c.group = nil
} }
func failDownConnection(c *webClient, id string, message string) error { func closeDownConnection(c *webClient, id string, message string) error {
err := delDownConn(c, id)
if err != nil && !os.IsNotExist(err) {
log.Printf("Close down connection: %v", err)
}
if id != "" { if id != "" {
err := c.write(clientMessage{ err := c.write(clientMessage{
Type: "close", Type: "close",
@ -1207,7 +1223,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
if err != ErrUnknownId { if err != ErrUnknownId {
message = "negotiation failed" message = "negotiation failed"
} }
return failDownConnection(c, m.Id, message) return closeDownConnection(c, m.Id, message)
} }
down := getDownConn(c, m.Id) down := getDownConn(c, m.Id)
if down.negotiationNeeded > negotiationUnneeded { if down.negotiationNeeded > negotiationUnneeded {
@ -1217,7 +1233,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
"", "",
) )
if err != nil { if err != nil {
return failDownConnection( return closeDownConnection(
c, m.Id, "negotiation failed", c, m.Id, "negotiation failed",
) )
} }
@ -1227,7 +1243,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
if down != nil { if down != nil {
err := negotiate(c, down, true, "") err := negotiate(c, down, true, "")
if err != nil { if err != nil {
return failDownConnection( return closeDownConnection(
c, m.Id, "renegotiation failed", c, m.Id, "renegotiation failed",
) )
} }
@ -1242,14 +1258,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
return nil return nil
} }
case "abort": case "abort":
err := delDownConn(c, m.Id) return closeDownConnection(c, m.Id, "")
if err != nil {
log.Printf("Abort: %v", err)
}
c.write(clientMessage{
Type: "close",
Id: m.Id,
})
case "ice": case "ice":
if m.Candidate == nil { if m.Candidate == nil {
return group.ProtocolError("null candidate") return group.ProtocolError("null candidate")