From 263258a0c1afef62f81eab3fcde92adf41c55016 Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Tue, 2 Feb 2021 19:38:15 +0100 Subject: [PATCH] Implement renegotiation of down streams. We used to destroy and recreate down streams whenever something changed, which turned out to be racy. We now properly implement renegotiation, as well as atomic replacement of a stream by another one. --- conn/conn.go | 1 - rtpconn/rtpconn.go | 14 +-- rtpconn/webclient.go | 283 ++++++++++++++++++++++--------------------- 3 files changed, 148 insertions(+), 150 deletions(-) diff --git a/conn/conn.go b/conn/conn.go index 0116eea..322c663 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -17,7 +17,6 @@ type Up interface { DelLocal(Down) bool Id() string User() (string, string) - Codecs() []webrtc.RTPCodecCapability } // Type UpTrack represents a track in the client to server direction. diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go index 6931a81..c2f62ee 100644 --- a/rtpconn/rtpconn.go +++ b/rtpconn/rtpconn.go @@ -78,6 +78,7 @@ type downTrackAtomics struct { type rtpDownTrack struct { track *webrtc.TrackLocalStaticRTP + sender *webrtc.RTPSender remote conn.UpTrack ssrc webrtc.SSRC maxBitrate *bitrate @@ -156,7 +157,7 @@ func (down *rtpDownConnection) getTracks() []*rtpDownTrack { } func newDownConn(c group.Client, id string, remote conn.Up) (*rtpDownConnection, error) { - api := group.APIFromCodecs(remote.Codecs()) + api := c.Group().API() pc, err := api.NewPeerConnection(*ice.ICEConfiguration()) if err != nil { return nil, err @@ -366,17 +367,6 @@ func (up *rtpUpConnection) User() (string, string) { return up.userId, up.username } -func (up *rtpUpConnection) Codecs() []webrtc.RTPCodecCapability { - up.mu.Lock() - defer up.mu.Unlock() - - codecs := make([]webrtc.RTPCodecCapability, len(up.tracks)) - for i := range up.tracks { - codecs[i] = up.tracks[i].Codec() - } - return codecs -} - func (up *rtpUpConnection) AddLocal(local conn.Down) error { up.mu.Lock() defer up.mu.Unlock() diff --git a/rtpconn/webclient.go b/rtpconn/webclient.go index 345290f..62d16cf 100644 --- a/rtpconn/webclient.go +++ b/rtpconn/webclient.go @@ -315,55 +315,50 @@ func getConn(c *webClient, id string) iceConnection { return nil } -func addDownConn(c *webClient, id string, remote conn.Up) (*rtpDownConnection, error) { - conn, err := newDownConn(c, id, remote) - if err != nil { - return nil, err - } +func addDownConn(c *webClient, remote conn.Up) (*rtpDownConnection, bool, error) { + id := remote.Id() - err = addDownConnHelper(c, conn, remote) - if err != nil { - conn.pc.Close() - return nil, err - } - return conn, err -} - -func addDownConnHelper(c *webClient, conn *rtpDownConnection, remote conn.Up) error { c.mu.Lock() defer c.mu.Unlock() - if c.up != nil && c.up[conn.id] != nil { - return errors.New("Adding duplicate connection") + if c.up != nil && c.up[id] != nil { + return nil, false, errors.New("adding duplicate connection") } if c.down == nil { c.down = make(map[string]*rtpDownConnection) } - old := c.down[conn.id] - if old != nil { - // Avoid calling Close under a lock - go old.pc.Close() + if down := c.down[id]; down != nil { + return down, false, nil } - conn.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { - sendICE(c, conn.id, candidate) + down, err := newDownConn(c, id, remote) + if err != nil { + return nil, false, err + } + + down.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + sendICE(c, down.id, candidate) }) - conn.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { + down.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { if state == webrtc.ICEConnectionStateFailed { - c.action(connectionFailedAction{id: conn.id}) + c.action(connectionFailedAction{id: down.id}) } }) - err := remote.AddLocal(conn) + err = remote.AddLocal(down) if err != nil { - return err + down.pc.Close() + return nil, false, err } - c.down[conn.id] = conn - return nil + c.down[down.id] = down + + go rtcpDownSender(down) + + return down, true, nil } func delDownConn(c *webClient, id string) error { @@ -397,46 +392,40 @@ func delDownConnHelper(c *webClient, id string) *rtpDownConnection { return conn } -func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrack, remoteConn conn.Up) (*webrtc.RTPSender, error) { - rt, ok := remoteTrack.(*rtpUpTrack) - if !ok { - return nil, errors.New("unexpected up track type") - } +var errUnexpectedTrackType = errors.New("unexpected track type, this shouldn't happen") - conn.mu.Lock() - defer conn.mu.Unlock() - - remoteSSRC := rt.track.SSRC() +func addDownTrackUnlocked(conn *rtpDownConnection, remoteTrack *rtpUpTrack, remoteConn conn.Up) error { for _, t := range conn.tracks { tt, ok := t.remote.(*rtpUpTrack) if !ok { - return nil, errors.New("unexpected up track type") + return errUnexpectedTrackType } - if tt.track.SSRC() == remoteSSRC { - return nil, os.ErrExist + if tt == remoteTrack { + return os.ErrExist } } local, err := webrtc.NewTrackLocalStaticRTP( remoteTrack.Codec(), - rt.track.ID(), rt.track.StreamID(), + remoteTrack.track.ID(), remoteTrack.track.StreamID(), ) if err != nil { - return nil, err + return err } sender, err := conn.pc.AddTrack(local) if err != nil { - return nil, err + return err } parms := sender.GetParameters() if len(parms.Encodings) != 1 { - return nil, errors.New("got multiple encodings") + return errors.New("got multiple encodings") } track := &rtpDownTrack{ track: local, + sender: sender, ssrc: parms.Encodings[0].SSRC, remote: remoteTrack, maxBitrate: new(bitrate), @@ -449,7 +438,79 @@ func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrac go rtcpDownListener(conn, track, sender) - return sender, nil + return nil +} + +func delDownTrackUnlocked(conn *rtpDownConnection, track *rtpDownTrack) error { + for i := range conn.tracks { + if conn.tracks[i] == track { + track.remote.DelLocal(track) + conn.tracks = + append(conn.tracks[:i], conn.tracks[i+1:]...) + return conn.pc.RemoveTrack(track.sender) + } + } + return os.ErrNotExist +} + +func replaceTracks(conn *rtpDownConnection, remote []conn.UpTrack, remoteConn conn.Up) error { + conn.mu.Lock() + defer conn.mu.Unlock() + + var add []*rtpUpTrack + var del []*rtpDownTrack + +outer: + for _, rtrack := range remote { + rt, ok := rtrack.(*rtpUpTrack) + if !ok { + return errUnexpectedTrackType + } + for _, track := range conn.tracks { + rt2, ok := track.remote.(*rtpUpTrack) + if !ok { + return errUnexpectedTrackType + } + if rt == rt2 { + continue outer + } + } + add = append(add, rt) + } + +outer2: + for _, track := range conn.tracks { + rt, ok := track.remote.(*rtpUpTrack) + if !ok { + return errUnexpectedTrackType + } + for _, rtrack := range remote { + rt2, ok := rtrack.(*rtpUpTrack) + if !ok { + return errUnexpectedTrackType + } + if rt == rt2 { + continue outer2 + } + } + del = append(del, track) + } + + for _, t := range del { + err := delDownTrackUnlocked(conn, t) + if err != nil { + return err + } + } + + for _, rt := range add { + err := addDownTrackUnlocked(conn, rt, remoteConn) + if err != nil { + return err + } + } + + return nil } func negotiate(c *webClient, down *rtpDownConnection, restartIce bool, replace string) error { @@ -618,20 +679,6 @@ func gotICE(c *webClient, candidate *webrtc.ICECandidateInit, id string) error { } func (c *webClient) setRequested(requested map[string]uint32) error { - if c.down != nil { - down := make([]string, 0, len(c.down)) - for id := range c.down { - down = append(down, id) - } - for _, id := range down { - c.write(clientMessage{ - Type: "close", - Id: id, - }) - delDownConn(c, id) - } - } - c.requested = requested go pushConns(c, c.group) @@ -652,40 +699,6 @@ func (c *webClient) isRequested(label string) bool { return c.requested[label] != 0 } -func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rtpDownConnection, error) { - requested := false - for _, t := range tracks { - if c.isRequested(t.Label()) { - requested = true - break - } - } - if !requested { - delDownConn(c, remote.Id()) - return nil, nil - } - - down, err := addDownConn(c, remote.Id(), remote) - if err != nil { - return nil, err - } - - for _, t := range tracks { - if !c.isRequested(t.Label()) { - continue - } - _, err = addDownTrack(c, down, t, remote) - if err != nil { - delDownConn(c, down.id) - return nil, err - } - } - - go rtcpDownSender(down) - - return down, nil -} - func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error { err := c.action(pushConnAction{g, id, up, tracks, replace}) if err != nil { @@ -823,45 +836,44 @@ func clientLoop(c *webClient, ws *websocket.Conn) error { } if a.conn == nil { if a.replace != "" { - err := delDownConn( - c, a.replace, + closeDownConnection( + c, a.replace, "", ) - if err == nil { - c.write(clientMessage{ - Type: "close", - Id: a.replace, - }) - } - } - err := delDownConn(c, a.id) - if err == nil { - c.write(clientMessage{ - Type: "close", - Id: a.id, - }) } + closeDownConnection(c, a.id, "") continue } - down, err := addDownConnTracks( - c, a.conn, a.tracks, - ) + tracks := make([]conn.UpTrack, 0, len(a.tracks)) + for _, t := range a.tracks { + if c.isRequested(t.Label()) { + tracks = append(tracks, t) + } + } + if len(tracks) == 0 { + closeDownConnection(c, a.id, "") + continue + } + + down, _, err := addDownConn(c, a.conn) if err != nil { return err } - if down != nil { - err = negotiate( - c, down, false, a.replace, - ) - if err != nil { - log.Printf( - "Negotiation failed: %v", - err) - delDownConn(c, down.id) - c.error(group.UserError( - "Negotiation failed", - )) - continue - } + err = replaceTracks(down, tracks, a.conn) + if err != nil { + return err + } + err = negotiate( + c, down, false, a.replace, + ) + if err != nil { + log.Printf( + "Negotiation failed: %v", + err) + delDownConn(c, down.id) + c.error(group.UserError( + "Negotiation failed", + )) + continue } case pushConnsAction: g := c.group @@ -1011,7 +1023,11 @@ func leaveGroup(c *webClient) { c.group = nil } -func failDownConnection(c *webClient, id string, message string) error { +func closeDownConnection(c *webClient, id string, message string) error { + err := delDownConn(c, id) + if err != nil && !os.IsNotExist(err) { + log.Printf("Close down connection: %v", err) + } if id != "" { err := c.write(clientMessage{ Type: "close", @@ -1207,7 +1223,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { if err != ErrUnknownId { message = "negotiation failed" } - return failDownConnection(c, m.Id, message) + return closeDownConnection(c, m.Id, message) } down := getDownConn(c, m.Id) if down.negotiationNeeded > negotiationUnneeded { @@ -1217,7 +1233,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { "", ) if err != nil { - return failDownConnection( + return closeDownConnection( c, m.Id, "negotiation failed", ) } @@ -1227,7 +1243,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { if down != nil { err := negotiate(c, down, true, "") if err != nil { - return failDownConnection( + return closeDownConnection( c, m.Id, "renegotiation failed", ) } @@ -1242,14 +1258,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { return nil } case "abort": - err := delDownConn(c, m.Id) - if err != nil { - log.Printf("Abort: %v", err) - } - c.write(clientMessage{ - Type: "close", - Id: m.Id, - }) + return closeDownConnection(c, m.Id, "") case "ice": if m.Candidate == nil { return group.ProtocolError("null candidate")