diff --git a/diskwriter/diskwriter.go b/diskwriter/diskwriter.go index ad67491..986a986 100644 --- a/diskwriter/diskwriter.go +++ b/diskwriter/diskwriter.go @@ -87,7 +87,11 @@ func (client *Client) Kick(id, user, message string) error { return err } -func (client *Client) PushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error { +func (client *Client) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, label string) error { + if client.group != g { + return nil + } + client.mu.Lock() defer client.mu.Unlock() diff --git a/group/client.go b/group/client.go index e0e2887..275c33f 100644 --- a/group/client.go +++ b/group/client.go @@ -97,7 +97,7 @@ type Client interface { Challengeable SetPermissions(ClientPermissions) OverridePermissions(*Group) bool - PushConn(id string, conn conn.Up, tracks []conn.UpTrack, label string) error + PushConn(g *Group, id string, conn conn.Up, tracks []conn.UpTrack, label string) error PushClient(id, username string, add bool) error } diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go index e399e26..338e6c2 100644 --- a/rtpconn/rtpconn.go +++ b/rtpconn/rtpconn.go @@ -450,7 +450,7 @@ func newUpConn(c group.Client, id string) (*rtpUpConnection, error) { if complete { clients := c.Group().GetClients(c) for _, cc := range clients { - cc.PushConn(up.id, up, tracks, up.label) + cc.PushConn(c.Group(), up.id, up, tracks, up.label) } go rtcpUpSender(up) } diff --git a/rtpconn/webclient.go b/rtpconn/webclient.go index 9b748b3..e90a9d5 100644 --- a/rtpconn/webclient.go +++ b/rtpconn/webclient.go @@ -258,7 +258,7 @@ func delUpConn(c *webClient, id string) bool { if g != nil { go func(clients []group.Client) { for _, c := range clients { - err := c.PushConn(conn.id, nil, nil, "") + err := c.PushConn(g, conn.id, nil, nil, "") if err != nil { log.Printf("PushConn: %v", err) } @@ -582,21 +582,16 @@ func (c *webClient) setRequested(requested map[string]uint32) error { c.requested = requested - go pushConns(c) + go pushConns(c, c.group) return nil } -func pushConns(c group.Client) { - group := c.Group() - if group == nil { - log.Printf("Pushing connections to unjoined client") - return - } - clients := group.GetClients(c) +func pushConns(c group.Client, g *group.Group) { + clients := g.GetClients(c) for _, cc := range clients { ccc, ok := cc.(*webClient) if ok { - ccc.action(pushConnsAction{c}) + ccc.action(pushConnsAction{g, c}) } } } @@ -638,8 +633,8 @@ func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rt return down, nil } -func (c *webClient) PushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error { - err := c.action(pushConnAction{id, up, tracks}) +func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, label string) error { + err := c.action(pushConnAction{g, id, up, tracks}) if err != nil { return err } @@ -709,6 +704,7 @@ func StartClient(conn *websocket.Conn) error { } type pushConnAction struct { + group *group.Group id string conn conn.Up tracks []conn.UpTrack @@ -720,7 +716,8 @@ type addLabelAction struct { } type pushConnsAction struct { - c group.Client + group *group.Group + client group.Client } type connectionFailedAction struct { @@ -736,24 +733,10 @@ type kickAction struct { } func clientLoop(c *webClient, ws *websocket.Conn) error { - defer func() { - if c.group != nil { - group.DelClient(c) - c.group = nil - } - }() - read := make(chan interface{}, 1) go clientReader(ws, read, c.done) - defer func() { - c.setRequested(map[string]uint32{}) - if c.up != nil { - for id := range c.up { - delUpConn(c, id) - } - } - }() + defer leaveGroup(c) readTime := time.Now() @@ -779,6 +762,10 @@ func clientLoop(c *webClient, ws *websocket.Conn) error { case a := <-c.actionCh: switch a := a.(type) { case pushConnAction: + g := c.group + if g == nil || a.group != g { + return nil + } if a.conn == nil { found := delDownConn(c, a.id) if found { @@ -821,6 +808,10 @@ func clientLoop(c *webClient, ws *websocket.Conn) error { Value: &label, }) case pushConnsAction: + g := c.group + if g == nil || a.group != g { + return nil + } for _, u := range c.up { if !u.complete() { continue @@ -831,8 +822,8 @@ func clientLoop(c *webClient, ws *websocket.Conn) error { ts[i] = t } go func() { - err := a.c.PushConn( - u.id, u, ts, u.label, + err := a.client.PushConn( + g, u.id, u, ts, u.label, ) if err != nil { log.Printf( @@ -855,6 +846,7 @@ func clientLoop(c *webClient, ws *websocket.Conn) error { tracks[i] = t.remote } go c.PushConn( + c.group, down.remote.Id(), down.remote, tracks, down.remote.Label(), ) @@ -935,6 +927,24 @@ func failUpConnection(c *webClient, id string, message string) error { return nil } +func leaveGroup(c *webClient) { + if c.group == nil { + return + } + + c.setRequested(map[string]uint32{}) + if c.up != nil { + for id := range c.up { + delUpConn(c, id) + } + } + + group.DelClient(c) + c.permissions = group.ClientPermissions{} + c.group = nil +} + + func failDownConnection(c *webClient, id string, message string) error { if id != "" { err := c.write(clientMessage{ @@ -1009,8 +1019,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { if c.group == nil || c.group.Name() != m.Group { return group.ProtocolError("you are not joined") } - c.group = nil - c.permissions = group.ClientPermissions{} + leaveGroup(c) perms := c.permissions return c.write(clientMessage{ Type: "joined", @@ -1245,7 +1254,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { disk.Close() return c.error(err) } - go pushConns(disk) + go pushConns(disk, c.group) case "unrecord": if !c.permissions.Record { return c.error(group.UserError("not authorised"))