From 00fbfafeeb60d3fbfe0817bf575ecf6e7b211531 Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Sat, 9 Dec 2023 16:56:33 +0100 Subject: [PATCH] Make unbounded channels explicit. We used to have unbounded channels embedded within rtpconn and webClient. Make the structure explicit and testable. --- rtpconn/rtpconn.go | 22 ++++--------- rtpconn/rtpreader.go | 7 ++--- rtpconn/webclient.go | 61 +++++++++++++------------------------ unbounded/unbounded.go | 47 ++++++++++++++++++++++++++++ unbounded/unbounded_test.go | 52 +++++++++++++++++++++++++++++++ 5 files changed, 128 insertions(+), 61 deletions(-) create mode 100644 unbounded/unbounded.go create mode 100644 unbounded/unbounded_test.go diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go index 9f5739e..5cf1ee9 100644 --- a/rtpconn/rtpconn.go +++ b/rtpconn/rtpconn.go @@ -23,6 +23,7 @@ import ( "github.com/jech/galene/packetcache" "github.com/jech/galene/packetmap" "github.com/jech/galene/rtptime" + "github.com/jech/galene/unbounded" ) type bitrate struct { @@ -403,7 +404,7 @@ type rtpUpTrack struct { jitter *jitter.Estimator cname atomic.Value - actionCh chan struct{} + actions *unbounded.Channel[trackAction] readerDone chan struct{} mu sync.Mutex @@ -412,7 +413,6 @@ type rtpUpTrack struct { srRTPTime uint32 local []conn.DownTrack bufferedNACKs []uint16 - actions []trackAction } const ( @@ -427,17 +427,7 @@ type trackAction struct { } func (up *rtpUpTrack) action(action int, track conn.DownTrack) { - up.mu.Lock() - empty := len(up.actions) == 0 - up.actions = append(up.actions, trackAction{action, track}) - up.mu.Unlock() - - if empty { - select { - case up.actionCh <- struct{}{}: - default: - } - } + up.actions.Put(trackAction{action, track}) } func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error { @@ -682,7 +672,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon cache: packetcache.New(minPacketCache(remote)), rate: estimator.New(time.Second), jitter: jitter.New(remote.Codec().ClockRate), - actionCh: make(chan struct{}, 1), + actions: unbounded.New[trackAction](), readerDone: make(chan struct{}), } @@ -923,7 +913,7 @@ func maxUpBitrate(t *rtpUpTrack) uint64 { // assume that lower spatial layers take up 1/5 of // the throughput if maxsid > 0 { - maxrate = sadd(maxrate, maxrate / 4) + maxrate = sadd(maxrate, maxrate/4) } // assume that each layer takes two times less // throughput than the higher one. Then we've @@ -1003,7 +993,7 @@ func sendUpRTCP(up *rtpUpConnection) error { } ssrcs = append(ssrcs, uint32(t.track.SSRC())) if t.Kind() == webrtc.RTPCodecTypeAudio { - rate = sadd(rate, 100 * 1024) + rate = sadd(rate, 100*1024) } else if t.Label() == "l" { rate = sadd(rate, group.LowBitrate) } else { diff --git a/rtpconn/rtpreader.go b/rtpconn/rtpreader.go index 6a501df..89ced80 100644 --- a/rtpconn/rtpreader.go +++ b/rtpconn/rtpreader.go @@ -31,11 +31,8 @@ func readLoop(track *rtpUpTrack) { for { select { - case <-track.actionCh: - track.mu.Lock() - actions := track.actions - track.actions = nil - track.mu.Unlock() + case <-track.actions.Ch: + actions := track.actions.Get() for _, action := range actions { switch action.action { case trackActionAdd, trackActionDel: diff --git a/rtpconn/webclient.go b/rtpconn/webclient.go index 0b282d1..52ab38c 100644 --- a/rtpconn/webclient.go +++ b/rtpconn/webclient.go @@ -20,6 +20,7 @@ import ( "github.com/jech/galene/group" "github.com/jech/galene/ice" "github.com/jech/galene/token" + "github.com/jech/galene/unbounded" ) func errorToWSCloseMessage(id string, err error) (*clientMessage, []byte) { @@ -65,16 +66,11 @@ type webClient struct { done chan struct{} writeCh chan interface{} writerDone chan struct{} - actionCh chan struct{} + actions *unbounded.Channel[any] mu sync.Mutex down map[string]*rtpDownConnection up map[string]*rtpUpConnection - - // action may be called with the group mutex taken, and therefore - // actions needs to use its own mutex. - actionMu sync.Mutex - actions []interface{} } func (c *webClient) Group() *group.Group { @@ -106,9 +102,10 @@ func (c *webClient) SetPermissions(perms []string) { } func (c *webClient) PushClient(group, kind, id string, username string, perms []string, data map[string]interface{}) error { - return c.action(pushClientAction{ + c.action(pushClientAction{ group, kind, id, username, perms, data, }) + return nil } type clientMessage struct { @@ -733,7 +730,8 @@ func (c *webClient) setRequestedStream(down *rtpDownConnection, requested []stri } func (c *webClient) RequestConns(target group.Client, g *group.Group, id string) error { - return c.action(requestConnsAction{g, target, id}) + c.action(requestConnsAction{g, target, id}) + return nil } func requestConns(target group.Client, g *group.Group, id string) { @@ -804,10 +802,7 @@ func requestedTracks(c *webClient, requested []string, tracks []conn.UpTrack) ([ } func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error { - err := c.action(pushConnAction{g, id, up, tracks, replace}) - if err != nil { - return err - } + c.action(pushConnAction{g, id, up, tracks, replace}) return nil } @@ -854,9 +849,9 @@ func StartClient(conn *websocket.Conn) (err error) { } c := &webClient{ - id: m.Id, - actionCh: make(chan struct{}, 1), - done: make(chan struct{}), + id: m.Id, + actions: unbounded.New[any](), + done: make(chan struct{}), } defer close(c.done) @@ -996,11 +991,8 @@ func clientLoop(c *webClient, ws *websocket.Conn, versionError bool) error { case error: return m } - case <-c.actionCh: - c.actionMu.Lock() - actions := c.actions - c.actions = nil - c.actionMu.Unlock() + case <-c.actions.Ch: + actions := c.actions.Get() for _, a := range actions { err := handleAction(c, a) if err != nil { @@ -1090,7 +1082,7 @@ func pushDownConn(c *webClient, id string, up conn.Up, tracks []conn.UpTrack, re return nil } -func handleAction(c *webClient, a interface{}) error { +func handleAction(c *webClient, a any) error { switch a := a.(type) { case pushConnAction: if c.group == nil || c.group != a.group { @@ -1353,15 +1345,18 @@ func setPermissions(g *group.Group, id string, perm string) error { default: return group.UserError("unknown permission") } - return c.action(permissionsChangedAction{}) + c.action(permissionsChangedAction{}) + return nil } func (c *webClient) Kick(id string, user *string, message string) error { - return c.action(kickAction{id, user, message}) + c.action(kickAction{id, user, message}) + return nil } func (c *webClient) Joined(group, kind string) error { - return c.action(joinedAction{group, kind}) + c.action(joinedAction{group, kind}) + return nil } func kickClient(g *group.Group, id string, user *string, dest string, message string) error { @@ -2087,22 +2082,8 @@ func (c *webClient) Warn(oponly bool, message string) error { var ErrClientDead = errors.New("client is dead") -func (c *webClient) action(a interface{}) error { - c.actionMu.Lock() - empty := len(c.actions) == 0 - c.actions = append(c.actions, a) - c.actionMu.Unlock() - - if empty { - select { - case c.actionCh <- struct{}{}: - return nil - case <-c.done: - return ErrClientDead - default: - } - } - return nil +func (c *webClient) action(a interface{}) { + c.actions.Put(a) } func (c *webClient) write(m clientMessage) error { diff --git a/unbounded/unbounded.go b/unbounded/unbounded.go new file mode 100644 index 0000000..58f60b8 --- /dev/null +++ b/unbounded/unbounded.go @@ -0,0 +1,47 @@ +package unbounded + +import ( + "sync" +) + +// Type Channel implements an unbounded channel +type Channel[T any] struct { + // Ch triggers whenever the channel becomes non-empty + Ch chan struct{} + + mu sync.Mutex + queue []T +} + +// New creates a new unbounded channel +func New[T any]() *Channel[T] { + return &Channel[T]{ + Ch: make(chan struct{}, 1), + } +} + +// Put inserts a new element into ch. +// If ch was previously empty, it triggers ch.Ch. +func (ch *Channel[T]) Put(v T) { + ch.mu.Lock() + empty := len(ch.queue) == 0 + ch.queue = append(ch.queue, v) + ch.mu.Unlock() + + if empty { + select { + case ch.Ch <- struct{}{}: + default: + } + } +} + +// Get removes all the elements of ch. +// It is usually called when ch.Ch triggers, but may be called at any time. +func (ch *Channel[T]) Get() []T { + ch.mu.Lock() + defer ch.mu.Unlock() + queue := ch.queue + ch.queue = nil + return queue +} diff --git a/unbounded/unbounded_test.go b/unbounded/unbounded_test.go new file mode 100644 index 0000000..8cbd58f --- /dev/null +++ b/unbounded/unbounded_test.go @@ -0,0 +1,52 @@ +package unbounded + +import ( + "testing" + "time" +) + +func TestUnbounded(t *testing.T) { + ch := New[int]() + + go func() { + for i := 0; i < 1000; i++ { + ch.Put(i) + } + }() + + n := 0 + for n < 1000 { + <-ch.Ch + vs := ch.Get() + for _, v := range vs { + if n != v { + t.Errorf("Expected %v, got %v", n, v) + } + n++ + } + } + + go func() { + for i := 0; i < 1000; i++ { + ch.Put(i) + time.Sleep(time.Microsecond) + } + }() + + n = 0 + for n < 1000 { + <-ch.Ch + vs := ch.Get() + for _, v := range vs { + if n != v { + t.Errorf("Expected %v, got %v", n, v) + } + n++ + } + } + + vs := ch.Get() + if len(vs) != 0 { + t.Errorf("Channel is not empty (%v)", len(vs)) + } +}