mirror of
https://github.com/jech/galene.git
synced 2024-11-09 18:25:58 +01:00
Implement renegotiation of down streams.
We used to destroy and recreate down streams whenever something changed, which turned out to be racy. We now properly implement renegotiation, as well as atomic replacement of a stream by another one.
This commit is contained in:
parent
368da133fd
commit
263258a0c1
3 changed files with 148 additions and 150 deletions
|
@ -17,7 +17,6 @@ type Up interface {
|
|||
DelLocal(Down) bool
|
||||
Id() string
|
||||
User() (string, string)
|
||||
Codecs() []webrtc.RTPCodecCapability
|
||||
}
|
||||
|
||||
// Type UpTrack represents a track in the client to server direction.
|
||||
|
|
|
@ -78,6 +78,7 @@ type downTrackAtomics struct {
|
|||
|
||||
type rtpDownTrack struct {
|
||||
track *webrtc.TrackLocalStaticRTP
|
||||
sender *webrtc.RTPSender
|
||||
remote conn.UpTrack
|
||||
ssrc webrtc.SSRC
|
||||
maxBitrate *bitrate
|
||||
|
@ -156,7 +157,7 @@ func (down *rtpDownConnection) getTracks() []*rtpDownTrack {
|
|||
}
|
||||
|
||||
func newDownConn(c group.Client, id string, remote conn.Up) (*rtpDownConnection, error) {
|
||||
api := group.APIFromCodecs(remote.Codecs())
|
||||
api := c.Group().API()
|
||||
pc, err := api.NewPeerConnection(*ice.ICEConfiguration())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -366,17 +367,6 @@ func (up *rtpUpConnection) User() (string, string) {
|
|||
return up.userId, up.username
|
||||
}
|
||||
|
||||
func (up *rtpUpConnection) Codecs() []webrtc.RTPCodecCapability {
|
||||
up.mu.Lock()
|
||||
defer up.mu.Unlock()
|
||||
|
||||
codecs := make([]webrtc.RTPCodecCapability, len(up.tracks))
|
||||
for i := range up.tracks {
|
||||
codecs[i] = up.tracks[i].Codec()
|
||||
}
|
||||
return codecs
|
||||
}
|
||||
|
||||
func (up *rtpUpConnection) AddLocal(local conn.Down) error {
|
||||
up.mu.Lock()
|
||||
defer up.mu.Unlock()
|
||||
|
|
|
@ -315,55 +315,50 @@ func getConn(c *webClient, id string) iceConnection {
|
|||
return nil
|
||||
}
|
||||
|
||||
func addDownConn(c *webClient, id string, remote conn.Up) (*rtpDownConnection, error) {
|
||||
conn, err := newDownConn(c, id, remote)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func addDownConn(c *webClient, remote conn.Up) (*rtpDownConnection, bool, error) {
|
||||
id := remote.Id()
|
||||
|
||||
err = addDownConnHelper(c, conn, remote)
|
||||
if err != nil {
|
||||
conn.pc.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
func addDownConnHelper(c *webClient, conn *rtpDownConnection, remote conn.Up) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.up != nil && c.up[conn.id] != nil {
|
||||
return errors.New("Adding duplicate connection")
|
||||
if c.up != nil && c.up[id] != nil {
|
||||
return nil, false, errors.New("adding duplicate connection")
|
||||
}
|
||||
|
||||
if c.down == nil {
|
||||
c.down = make(map[string]*rtpDownConnection)
|
||||
}
|
||||
|
||||
old := c.down[conn.id]
|
||||
if old != nil {
|
||||
// Avoid calling Close under a lock
|
||||
go old.pc.Close()
|
||||
if down := c.down[id]; down != nil {
|
||||
return down, false, nil
|
||||
}
|
||||
|
||||
conn.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
sendICE(c, conn.id, candidate)
|
||||
})
|
||||
|
||||
conn.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
|
||||
if state == webrtc.ICEConnectionStateFailed {
|
||||
c.action(connectionFailedAction{id: conn.id})
|
||||
}
|
||||
})
|
||||
|
||||
err := remote.AddLocal(conn)
|
||||
down, err := newDownConn(c, id, remote)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
c.down[conn.id] = conn
|
||||
return nil
|
||||
down.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
sendICE(c, down.id, candidate)
|
||||
})
|
||||
|
||||
down.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
|
||||
if state == webrtc.ICEConnectionStateFailed {
|
||||
c.action(connectionFailedAction{id: down.id})
|
||||
}
|
||||
})
|
||||
|
||||
err = remote.AddLocal(down)
|
||||
if err != nil {
|
||||
down.pc.Close()
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
c.down[down.id] = down
|
||||
|
||||
go rtcpDownSender(down)
|
||||
|
||||
return down, true, nil
|
||||
}
|
||||
|
||||
func delDownConn(c *webClient, id string) error {
|
||||
|
@ -397,46 +392,40 @@ func delDownConnHelper(c *webClient, id string) *rtpDownConnection {
|
|||
return conn
|
||||
}
|
||||
|
||||
func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrack, remoteConn conn.Up) (*webrtc.RTPSender, error) {
|
||||
rt, ok := remoteTrack.(*rtpUpTrack)
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected up track type")
|
||||
}
|
||||
var errUnexpectedTrackType = errors.New("unexpected track type, this shouldn't happen")
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
remoteSSRC := rt.track.SSRC()
|
||||
func addDownTrackUnlocked(conn *rtpDownConnection, remoteTrack *rtpUpTrack, remoteConn conn.Up) error {
|
||||
for _, t := range conn.tracks {
|
||||
tt, ok := t.remote.(*rtpUpTrack)
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected up track type")
|
||||
return errUnexpectedTrackType
|
||||
}
|
||||
if tt.track.SSRC() == remoteSSRC {
|
||||
return nil, os.ErrExist
|
||||
if tt == remoteTrack {
|
||||
return os.ErrExist
|
||||
}
|
||||
}
|
||||
|
||||
local, err := webrtc.NewTrackLocalStaticRTP(
|
||||
remoteTrack.Codec(),
|
||||
rt.track.ID(), rt.track.StreamID(),
|
||||
remoteTrack.track.ID(), remoteTrack.track.StreamID(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
sender, err := conn.pc.AddTrack(local)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
parms := sender.GetParameters()
|
||||
if len(parms.Encodings) != 1 {
|
||||
return nil, errors.New("got multiple encodings")
|
||||
return errors.New("got multiple encodings")
|
||||
}
|
||||
|
||||
track := &rtpDownTrack{
|
||||
track: local,
|
||||
sender: sender,
|
||||
ssrc: parms.Encodings[0].SSRC,
|
||||
remote: remoteTrack,
|
||||
maxBitrate: new(bitrate),
|
||||
|
@ -449,7 +438,79 @@ func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrac
|
|||
|
||||
go rtcpDownListener(conn, track, sender)
|
||||
|
||||
return sender, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func delDownTrackUnlocked(conn *rtpDownConnection, track *rtpDownTrack) error {
|
||||
for i := range conn.tracks {
|
||||
if conn.tracks[i] == track {
|
||||
track.remote.DelLocal(track)
|
||||
conn.tracks =
|
||||
append(conn.tracks[:i], conn.tracks[i+1:]...)
|
||||
return conn.pc.RemoveTrack(track.sender)
|
||||
}
|
||||
}
|
||||
return os.ErrNotExist
|
||||
}
|
||||
|
||||
func replaceTracks(conn *rtpDownConnection, remote []conn.UpTrack, remoteConn conn.Up) error {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
var add []*rtpUpTrack
|
||||
var del []*rtpDownTrack
|
||||
|
||||
outer:
|
||||
for _, rtrack := range remote {
|
||||
rt, ok := rtrack.(*rtpUpTrack)
|
||||
if !ok {
|
||||
return errUnexpectedTrackType
|
||||
}
|
||||
for _, track := range conn.tracks {
|
||||
rt2, ok := track.remote.(*rtpUpTrack)
|
||||
if !ok {
|
||||
return errUnexpectedTrackType
|
||||
}
|
||||
if rt == rt2 {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
add = append(add, rt)
|
||||
}
|
||||
|
||||
outer2:
|
||||
for _, track := range conn.tracks {
|
||||
rt, ok := track.remote.(*rtpUpTrack)
|
||||
if !ok {
|
||||
return errUnexpectedTrackType
|
||||
}
|
||||
for _, rtrack := range remote {
|
||||
rt2, ok := rtrack.(*rtpUpTrack)
|
||||
if !ok {
|
||||
return errUnexpectedTrackType
|
||||
}
|
||||
if rt == rt2 {
|
||||
continue outer2
|
||||
}
|
||||
}
|
||||
del = append(del, track)
|
||||
}
|
||||
|
||||
for _, t := range del {
|
||||
err := delDownTrackUnlocked(conn, t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, rt := range add {
|
||||
err := addDownTrackUnlocked(conn, rt, remoteConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func negotiate(c *webClient, down *rtpDownConnection, restartIce bool, replace string) error {
|
||||
|
@ -618,20 +679,6 @@ func gotICE(c *webClient, candidate *webrtc.ICECandidateInit, id string) error {
|
|||
}
|
||||
|
||||
func (c *webClient) setRequested(requested map[string]uint32) error {
|
||||
if c.down != nil {
|
||||
down := make([]string, 0, len(c.down))
|
||||
for id := range c.down {
|
||||
down = append(down, id)
|
||||
}
|
||||
for _, id := range down {
|
||||
c.write(clientMessage{
|
||||
Type: "close",
|
||||
Id: id,
|
||||
})
|
||||
delDownConn(c, id)
|
||||
}
|
||||
}
|
||||
|
||||
c.requested = requested
|
||||
|
||||
go pushConns(c, c.group)
|
||||
|
@ -652,40 +699,6 @@ func (c *webClient) isRequested(label string) bool {
|
|||
return c.requested[label] != 0
|
||||
}
|
||||
|
||||
func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rtpDownConnection, error) {
|
||||
requested := false
|
||||
for _, t := range tracks {
|
||||
if c.isRequested(t.Label()) {
|
||||
requested = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !requested {
|
||||
delDownConn(c, remote.Id())
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
down, err := addDownConn(c, remote.Id(), remote)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, t := range tracks {
|
||||
if !c.isRequested(t.Label()) {
|
||||
continue
|
||||
}
|
||||
_, err = addDownTrack(c, down, t, remote)
|
||||
if err != nil {
|
||||
delDownConn(c, down.id)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
go rtcpDownSender(down)
|
||||
|
||||
return down, nil
|
||||
}
|
||||
|
||||
func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error {
|
||||
err := c.action(pushConnAction{g, id, up, tracks, replace})
|
||||
if err != nil {
|
||||
|
@ -823,32 +836,32 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
|
|||
}
|
||||
if a.conn == nil {
|
||||
if a.replace != "" {
|
||||
err := delDownConn(
|
||||
c, a.replace,
|
||||
closeDownConnection(
|
||||
c, a.replace, "",
|
||||
)
|
||||
if err == nil {
|
||||
c.write(clientMessage{
|
||||
Type: "close",
|
||||
Id: a.replace,
|
||||
})
|
||||
}
|
||||
}
|
||||
err := delDownConn(c, a.id)
|
||||
if err == nil {
|
||||
c.write(clientMessage{
|
||||
Type: "close",
|
||||
Id: a.id,
|
||||
})
|
||||
}
|
||||
closeDownConnection(c, a.id, "")
|
||||
continue
|
||||
}
|
||||
down, err := addDownConnTracks(
|
||||
c, a.conn, a.tracks,
|
||||
)
|
||||
tracks := make([]conn.UpTrack, 0, len(a.tracks))
|
||||
for _, t := range a.tracks {
|
||||
if c.isRequested(t.Label()) {
|
||||
tracks = append(tracks, t)
|
||||
}
|
||||
}
|
||||
if len(tracks) == 0 {
|
||||
closeDownConnection(c, a.id, "")
|
||||
continue
|
||||
}
|
||||
|
||||
down, _, err := addDownConn(c, a.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = replaceTracks(down, tracks, a.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if down != nil {
|
||||
err = negotiate(
|
||||
c, down, false, a.replace,
|
||||
)
|
||||
|
@ -862,7 +875,6 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
|
|||
))
|
||||
continue
|
||||
}
|
||||
}
|
||||
case pushConnsAction:
|
||||
g := c.group
|
||||
if g == nil || a.group != g {
|
||||
|
@ -1011,7 +1023,11 @@ func leaveGroup(c *webClient) {
|
|||
c.group = nil
|
||||
}
|
||||
|
||||
func failDownConnection(c *webClient, id string, message string) error {
|
||||
func closeDownConnection(c *webClient, id string, message string) error {
|
||||
err := delDownConn(c, id)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
log.Printf("Close down connection: %v", err)
|
||||
}
|
||||
if id != "" {
|
||||
err := c.write(clientMessage{
|
||||
Type: "close",
|
||||
|
@ -1207,7 +1223,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
|
|||
if err != ErrUnknownId {
|
||||
message = "negotiation failed"
|
||||
}
|
||||
return failDownConnection(c, m.Id, message)
|
||||
return closeDownConnection(c, m.Id, message)
|
||||
}
|
||||
down := getDownConn(c, m.Id)
|
||||
if down.negotiationNeeded > negotiationUnneeded {
|
||||
|
@ -1217,7 +1233,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
|
|||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return failDownConnection(
|
||||
return closeDownConnection(
|
||||
c, m.Id, "negotiation failed",
|
||||
)
|
||||
}
|
||||
|
@ -1227,7 +1243,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
|
|||
if down != nil {
|
||||
err := negotiate(c, down, true, "")
|
||||
if err != nil {
|
||||
return failDownConnection(
|
||||
return closeDownConnection(
|
||||
c, m.Id, "renegotiation failed",
|
||||
)
|
||||
}
|
||||
|
@ -1242,14 +1258,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
|
|||
return nil
|
||||
}
|
||||
case "abort":
|
||||
err := delDownConn(c, m.Id)
|
||||
if err != nil {
|
||||
log.Printf("Abort: %v", err)
|
||||
}
|
||||
c.write(clientMessage{
|
||||
Type: "close",
|
||||
Id: m.Id,
|
||||
})
|
||||
return closeDownConnection(c, m.Id, "")
|
||||
case "ice":
|
||||
if m.Candidate == nil {
|
||||
return group.ProtocolError("null candidate")
|
||||
|
|
Loading…
Reference in a new issue