diff --git a/conn/conn.go b/conn/conn.go index 0116eea..322c663 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -17,7 +17,6 @@ type Up interface { DelLocal(Down) bool Id() string User() (string, string) - Codecs() []webrtc.RTPCodecCapability } // Type UpTrack represents a track in the client to server direction. diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go index 6931a81..c2f62ee 100644 --- a/rtpconn/rtpconn.go +++ b/rtpconn/rtpconn.go @@ -78,6 +78,7 @@ type downTrackAtomics struct { type rtpDownTrack struct { track *webrtc.TrackLocalStaticRTP + sender *webrtc.RTPSender remote conn.UpTrack ssrc webrtc.SSRC maxBitrate *bitrate @@ -156,7 +157,7 @@ func (down *rtpDownConnection) getTracks() []*rtpDownTrack { } func newDownConn(c group.Client, id string, remote conn.Up) (*rtpDownConnection, error) { - api := group.APIFromCodecs(remote.Codecs()) + api := c.Group().API() pc, err := api.NewPeerConnection(*ice.ICEConfiguration()) if err != nil { return nil, err @@ -366,17 +367,6 @@ func (up *rtpUpConnection) User() (string, string) { return up.userId, up.username } -func (up *rtpUpConnection) Codecs() []webrtc.RTPCodecCapability { - up.mu.Lock() - defer up.mu.Unlock() - - codecs := make([]webrtc.RTPCodecCapability, len(up.tracks)) - for i := range up.tracks { - codecs[i] = up.tracks[i].Codec() - } - return codecs -} - func (up *rtpUpConnection) AddLocal(local conn.Down) error { up.mu.Lock() defer up.mu.Unlock() diff --git a/rtpconn/webclient.go b/rtpconn/webclient.go index 345290f..62d16cf 100644 --- a/rtpconn/webclient.go +++ b/rtpconn/webclient.go @@ -315,55 +315,50 @@ func getConn(c *webClient, id string) iceConnection { return nil } -func addDownConn(c *webClient, id string, remote conn.Up) (*rtpDownConnection, error) { - conn, err := newDownConn(c, id, remote) - if err != nil { - return nil, err - } +func addDownConn(c *webClient, remote conn.Up) (*rtpDownConnection, bool, error) { + id := remote.Id() - err = addDownConnHelper(c, conn, remote) - if err != nil { - conn.pc.Close() - return nil, err - } - return conn, err -} - -func addDownConnHelper(c *webClient, conn *rtpDownConnection, remote conn.Up) error { c.mu.Lock() defer c.mu.Unlock() - if c.up != nil && c.up[conn.id] != nil { - return errors.New("Adding duplicate connection") + if c.up != nil && c.up[id] != nil { + return nil, false, errors.New("adding duplicate connection") } if c.down == nil { c.down = make(map[string]*rtpDownConnection) } - old := c.down[conn.id] - if old != nil { - // Avoid calling Close under a lock - go old.pc.Close() + if down := c.down[id]; down != nil { + return down, false, nil } - conn.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { - sendICE(c, conn.id, candidate) + down, err := newDownConn(c, id, remote) + if err != nil { + return nil, false, err + } + + down.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + sendICE(c, down.id, candidate) }) - conn.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { + down.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { if state == webrtc.ICEConnectionStateFailed { - c.action(connectionFailedAction{id: conn.id}) + c.action(connectionFailedAction{id: down.id}) } }) - err := remote.AddLocal(conn) + err = remote.AddLocal(down) if err != nil { - return err + down.pc.Close() + return nil, false, err } - c.down[conn.id] = conn - return nil + c.down[down.id] = down + + go rtcpDownSender(down) + + return down, true, nil } func delDownConn(c *webClient, id string) error { @@ -397,46 +392,40 @@ func delDownConnHelper(c *webClient, id string) *rtpDownConnection { return conn } -func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrack, remoteConn conn.Up) (*webrtc.RTPSender, error) { - rt, ok := remoteTrack.(*rtpUpTrack) - if !ok { - return nil, errors.New("unexpected up track type") - } +var errUnexpectedTrackType = errors.New("unexpected track type, this shouldn't happen") - conn.mu.Lock() - defer conn.mu.Unlock() - - remoteSSRC := rt.track.SSRC() +func addDownTrackUnlocked(conn *rtpDownConnection, remoteTrack *rtpUpTrack, remoteConn conn.Up) error { for _, t := range conn.tracks { tt, ok := t.remote.(*rtpUpTrack) if !ok { - return nil, errors.New("unexpected up track type") + return errUnexpectedTrackType } - if tt.track.SSRC() == remoteSSRC { - return nil, os.ErrExist + if tt == remoteTrack { + return os.ErrExist } } local, err := webrtc.NewTrackLocalStaticRTP( remoteTrack.Codec(), - rt.track.ID(), rt.track.StreamID(), + remoteTrack.track.ID(), remoteTrack.track.StreamID(), ) if err != nil { - return nil, err + return err } sender, err := conn.pc.AddTrack(local) if err != nil { - return nil, err + return err } parms := sender.GetParameters() if len(parms.Encodings) != 1 { - return nil, errors.New("got multiple encodings") + return errors.New("got multiple encodings") } track := &rtpDownTrack{ track: local, + sender: sender, ssrc: parms.Encodings[0].SSRC, remote: remoteTrack, maxBitrate: new(bitrate), @@ -449,7 +438,79 @@ func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrac go rtcpDownListener(conn, track, sender) - return sender, nil + return nil +} + +func delDownTrackUnlocked(conn *rtpDownConnection, track *rtpDownTrack) error { + for i := range conn.tracks { + if conn.tracks[i] == track { + track.remote.DelLocal(track) + conn.tracks = + append(conn.tracks[:i], conn.tracks[i+1:]...) + return conn.pc.RemoveTrack(track.sender) + } + } + return os.ErrNotExist +} + +func replaceTracks(conn *rtpDownConnection, remote []conn.UpTrack, remoteConn conn.Up) error { + conn.mu.Lock() + defer conn.mu.Unlock() + + var add []*rtpUpTrack + var del []*rtpDownTrack + +outer: + for _, rtrack := range remote { + rt, ok := rtrack.(*rtpUpTrack) + if !ok { + return errUnexpectedTrackType + } + for _, track := range conn.tracks { + rt2, ok := track.remote.(*rtpUpTrack) + if !ok { + return errUnexpectedTrackType + } + if rt == rt2 { + continue outer + } + } + add = append(add, rt) + } + +outer2: + for _, track := range conn.tracks { + rt, ok := track.remote.(*rtpUpTrack) + if !ok { + return errUnexpectedTrackType + } + for _, rtrack := range remote { + rt2, ok := rtrack.(*rtpUpTrack) + if !ok { + return errUnexpectedTrackType + } + if rt == rt2 { + continue outer2 + } + } + del = append(del, track) + } + + for _, t := range del { + err := delDownTrackUnlocked(conn, t) + if err != nil { + return err + } + } + + for _, rt := range add { + err := addDownTrackUnlocked(conn, rt, remoteConn) + if err != nil { + return err + } + } + + return nil } func negotiate(c *webClient, down *rtpDownConnection, restartIce bool, replace string) error { @@ -618,20 +679,6 @@ func gotICE(c *webClient, candidate *webrtc.ICECandidateInit, id string) error { } func (c *webClient) setRequested(requested map[string]uint32) error { - if c.down != nil { - down := make([]string, 0, len(c.down)) - for id := range c.down { - down = append(down, id) - } - for _, id := range down { - c.write(clientMessage{ - Type: "close", - Id: id, - }) - delDownConn(c, id) - } - } - c.requested = requested go pushConns(c, c.group) @@ -652,40 +699,6 @@ func (c *webClient) isRequested(label string) bool { return c.requested[label] != 0 } -func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rtpDownConnection, error) { - requested := false - for _, t := range tracks { - if c.isRequested(t.Label()) { - requested = true - break - } - } - if !requested { - delDownConn(c, remote.Id()) - return nil, nil - } - - down, err := addDownConn(c, remote.Id(), remote) - if err != nil { - return nil, err - } - - for _, t := range tracks { - if !c.isRequested(t.Label()) { - continue - } - _, err = addDownTrack(c, down, t, remote) - if err != nil { - delDownConn(c, down.id) - return nil, err - } - } - - go rtcpDownSender(down) - - return down, nil -} - func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error { err := c.action(pushConnAction{g, id, up, tracks, replace}) if err != nil { @@ -823,45 +836,44 @@ func clientLoop(c *webClient, ws *websocket.Conn) error { } if a.conn == nil { if a.replace != "" { - err := delDownConn( - c, a.replace, + closeDownConnection( + c, a.replace, "", ) - if err == nil { - c.write(clientMessage{ - Type: "close", - Id: a.replace, - }) - } - } - err := delDownConn(c, a.id) - if err == nil { - c.write(clientMessage{ - Type: "close", - Id: a.id, - }) } + closeDownConnection(c, a.id, "") continue } - down, err := addDownConnTracks( - c, a.conn, a.tracks, - ) + tracks := make([]conn.UpTrack, 0, len(a.tracks)) + for _, t := range a.tracks { + if c.isRequested(t.Label()) { + tracks = append(tracks, t) + } + } + if len(tracks) == 0 { + closeDownConnection(c, a.id, "") + continue + } + + down, _, err := addDownConn(c, a.conn) if err != nil { return err } - if down != nil { - err = negotiate( - c, down, false, a.replace, - ) - if err != nil { - log.Printf( - "Negotiation failed: %v", - err) - delDownConn(c, down.id) - c.error(group.UserError( - "Negotiation failed", - )) - continue - } + err = replaceTracks(down, tracks, a.conn) + if err != nil { + return err + } + err = negotiate( + c, down, false, a.replace, + ) + if err != nil { + log.Printf( + "Negotiation failed: %v", + err) + delDownConn(c, down.id) + c.error(group.UserError( + "Negotiation failed", + )) + continue } case pushConnsAction: g := c.group @@ -1011,7 +1023,11 @@ func leaveGroup(c *webClient) { c.group = nil } -func failDownConnection(c *webClient, id string, message string) error { +func closeDownConnection(c *webClient, id string, message string) error { + err := delDownConn(c, id) + if err != nil && !os.IsNotExist(err) { + log.Printf("Close down connection: %v", err) + } if id != "" { err := c.write(clientMessage{ Type: "close", @@ -1207,7 +1223,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { if err != ErrUnknownId { message = "negotiation failed" } - return failDownConnection(c, m.Id, message) + return closeDownConnection(c, m.Id, message) } down := getDownConn(c, m.Id) if down.negotiationNeeded > negotiationUnneeded { @@ -1217,7 +1233,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { "", ) if err != nil { - return failDownConnection( + return closeDownConnection( c, m.Id, "negotiation failed", ) } @@ -1227,7 +1243,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { if down != nil { err := negotiate(c, down, true, "") if err != nil { - return failDownConnection( + return closeDownConnection( c, m.Id, "renegotiation failed", ) } @@ -1242,14 +1258,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { return nil } case "abort": - err := delDownConn(c, m.Id) - if err != nil { - log.Printf("Abort: %v", err) - } - c.write(clientMessage{ - Type: "close", - Id: m.Id, - }) + return closeDownConnection(c, m.Id, "") case "ice": if m.Candidate == nil { return group.ProtocolError("null candidate")