1
Fork 0

Make unbounded channels explicit.

We used to have unbounded channels embedded within rtpconn
and webClient.  Make the structure explicit and testable.
This commit is contained in:
Juliusz Chroboczek 2023-12-09 16:56:33 +01:00
parent dcde4562f5
commit 00fbfafeeb
5 changed files with 128 additions and 61 deletions

View File

@ -23,6 +23,7 @@ import (
"github.com/jech/galene/packetcache"
"github.com/jech/galene/packetmap"
"github.com/jech/galene/rtptime"
"github.com/jech/galene/unbounded"
)
type bitrate struct {
@ -403,7 +404,7 @@ type rtpUpTrack struct {
jitter *jitter.Estimator
cname atomic.Value
actionCh chan struct{}
actions *unbounded.Channel[trackAction]
readerDone chan struct{}
mu sync.Mutex
@ -412,7 +413,6 @@ type rtpUpTrack struct {
srRTPTime uint32
local []conn.DownTrack
bufferedNACKs []uint16
actions []trackAction
}
const (
@ -427,17 +427,7 @@ type trackAction struct {
}
func (up *rtpUpTrack) action(action int, track conn.DownTrack) {
up.mu.Lock()
empty := len(up.actions) == 0
up.actions = append(up.actions, trackAction{action, track})
up.mu.Unlock()
if empty {
select {
case up.actionCh <- struct{}{}:
default:
}
}
up.actions.Put(trackAction{action, track})
}
func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error {
@ -682,7 +672,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon
cache: packetcache.New(minPacketCache(remote)),
rate: estimator.New(time.Second),
jitter: jitter.New(remote.Codec().ClockRate),
actionCh: make(chan struct{}, 1),
actions: unbounded.New[trackAction](),
readerDone: make(chan struct{}),
}
@ -923,7 +913,7 @@ func maxUpBitrate(t *rtpUpTrack) uint64 {
// assume that lower spatial layers take up 1/5 of
// the throughput
if maxsid > 0 {
maxrate = sadd(maxrate, maxrate / 4)
maxrate = sadd(maxrate, maxrate/4)
}
// assume that each layer takes two times less
// throughput than the higher one. Then we've
@ -1003,7 +993,7 @@ func sendUpRTCP(up *rtpUpConnection) error {
}
ssrcs = append(ssrcs, uint32(t.track.SSRC()))
if t.Kind() == webrtc.RTPCodecTypeAudio {
rate = sadd(rate, 100 * 1024)
rate = sadd(rate, 100*1024)
} else if t.Label() == "l" {
rate = sadd(rate, group.LowBitrate)
} else {

View File

@ -31,11 +31,8 @@ func readLoop(track *rtpUpTrack) {
for {
select {
case <-track.actionCh:
track.mu.Lock()
actions := track.actions
track.actions = nil
track.mu.Unlock()
case <-track.actions.Ch:
actions := track.actions.Get()
for _, action := range actions {
switch action.action {
case trackActionAdd, trackActionDel:

View File

@ -20,6 +20,7 @@ import (
"github.com/jech/galene/group"
"github.com/jech/galene/ice"
"github.com/jech/galene/token"
"github.com/jech/galene/unbounded"
)
func errorToWSCloseMessage(id string, err error) (*clientMessage, []byte) {
@ -65,16 +66,11 @@ type webClient struct {
done chan struct{}
writeCh chan interface{}
writerDone chan struct{}
actionCh chan struct{}
actions *unbounded.Channel[any]
mu sync.Mutex
down map[string]*rtpDownConnection
up map[string]*rtpUpConnection
// action may be called with the group mutex taken, and therefore
// actions needs to use its own mutex.
actionMu sync.Mutex
actions []interface{}
}
func (c *webClient) Group() *group.Group {
@ -106,9 +102,10 @@ func (c *webClient) SetPermissions(perms []string) {
}
func (c *webClient) PushClient(group, kind, id string, username string, perms []string, data map[string]interface{}) error {
return c.action(pushClientAction{
c.action(pushClientAction{
group, kind, id, username, perms, data,
})
return nil
}
type clientMessage struct {
@ -733,7 +730,8 @@ func (c *webClient) setRequestedStream(down *rtpDownConnection, requested []stri
}
func (c *webClient) RequestConns(target group.Client, g *group.Group, id string) error {
return c.action(requestConnsAction{g, target, id})
c.action(requestConnsAction{g, target, id})
return nil
}
func requestConns(target group.Client, g *group.Group, id string) {
@ -804,10 +802,7 @@ func requestedTracks(c *webClient, requested []string, tracks []conn.UpTrack) ([
}
func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error {
err := c.action(pushConnAction{g, id, up, tracks, replace})
if err != nil {
return err
}
c.action(pushConnAction{g, id, up, tracks, replace})
return nil
}
@ -854,9 +849,9 @@ func StartClient(conn *websocket.Conn) (err error) {
}
c := &webClient{
id: m.Id,
actionCh: make(chan struct{}, 1),
done: make(chan struct{}),
id: m.Id,
actions: unbounded.New[any](),
done: make(chan struct{}),
}
defer close(c.done)
@ -996,11 +991,8 @@ func clientLoop(c *webClient, ws *websocket.Conn, versionError bool) error {
case error:
return m
}
case <-c.actionCh:
c.actionMu.Lock()
actions := c.actions
c.actions = nil
c.actionMu.Unlock()
case <-c.actions.Ch:
actions := c.actions.Get()
for _, a := range actions {
err := handleAction(c, a)
if err != nil {
@ -1090,7 +1082,7 @@ func pushDownConn(c *webClient, id string, up conn.Up, tracks []conn.UpTrack, re
return nil
}
func handleAction(c *webClient, a interface{}) error {
func handleAction(c *webClient, a any) error {
switch a := a.(type) {
case pushConnAction:
if c.group == nil || c.group != a.group {
@ -1353,15 +1345,18 @@ func setPermissions(g *group.Group, id string, perm string) error {
default:
return group.UserError("unknown permission")
}
return c.action(permissionsChangedAction{})
c.action(permissionsChangedAction{})
return nil
}
func (c *webClient) Kick(id string, user *string, message string) error {
return c.action(kickAction{id, user, message})
c.action(kickAction{id, user, message})
return nil
}
func (c *webClient) Joined(group, kind string) error {
return c.action(joinedAction{group, kind})
c.action(joinedAction{group, kind})
return nil
}
func kickClient(g *group.Group, id string, user *string, dest string, message string) error {
@ -2087,22 +2082,8 @@ func (c *webClient) Warn(oponly bool, message string) error {
var ErrClientDead = errors.New("client is dead")
func (c *webClient) action(a interface{}) error {
c.actionMu.Lock()
empty := len(c.actions) == 0
c.actions = append(c.actions, a)
c.actionMu.Unlock()
if empty {
select {
case c.actionCh <- struct{}{}:
return nil
case <-c.done:
return ErrClientDead
default:
}
}
return nil
func (c *webClient) action(a interface{}) {
c.actions.Put(a)
}
func (c *webClient) write(m clientMessage) error {

47
unbounded/unbounded.go Normal file
View File

@ -0,0 +1,47 @@
package unbounded
import (
"sync"
)
// Type Channel implements an unbounded channel
type Channel[T any] struct {
// Ch triggers whenever the channel becomes non-empty
Ch chan struct{}
mu sync.Mutex
queue []T
}
// New creates a new unbounded channel
func New[T any]() *Channel[T] {
return &Channel[T]{
Ch: make(chan struct{}, 1),
}
}
// Put inserts a new element into ch.
// If ch was previously empty, it triggers ch.Ch.
func (ch *Channel[T]) Put(v T) {
ch.mu.Lock()
empty := len(ch.queue) == 0
ch.queue = append(ch.queue, v)
ch.mu.Unlock()
if empty {
select {
case ch.Ch <- struct{}{}:
default:
}
}
}
// Get removes all the elements of ch.
// It is usually called when ch.Ch triggers, but may be called at any time.
func (ch *Channel[T]) Get() []T {
ch.mu.Lock()
defer ch.mu.Unlock()
queue := ch.queue
ch.queue = nil
return queue
}

View File

@ -0,0 +1,52 @@
package unbounded
import (
"testing"
"time"
)
func TestUnbounded(t *testing.T) {
ch := New[int]()
go func() {
for i := 0; i < 1000; i++ {
ch.Put(i)
}
}()
n := 0
for n < 1000 {
<-ch.Ch
vs := ch.Get()
for _, v := range vs {
if n != v {
t.Errorf("Expected %v, got %v", n, v)
}
n++
}
}
go func() {
for i := 0; i < 1000; i++ {
ch.Put(i)
time.Sleep(time.Microsecond)
}
}()
n = 0
for n < 1000 {
<-ch.Ch
vs := ch.Get()
for _, v := range vs {
if n != v {
t.Errorf("Expected %v, got %v", n, v)
}
n++
}
}
vs := ch.Get()
if len(vs) != 0 {
t.Errorf("Channel is not empty (%v)", len(vs))
}
}