1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-12-18 05:15:49 +01:00

Make unbounded channels explicit.

We used to have unbounded channels embedded within rtpconn
and webClient.  Make the structure explicit and testable.
This commit is contained in:
Juliusz Chroboczek 2023-12-09 16:56:33 +01:00
parent dcde4562f5
commit 00fbfafeeb
5 changed files with 128 additions and 61 deletions

View file

@ -23,6 +23,7 @@ import (
"github.com/jech/galene/packetcache" "github.com/jech/galene/packetcache"
"github.com/jech/galene/packetmap" "github.com/jech/galene/packetmap"
"github.com/jech/galene/rtptime" "github.com/jech/galene/rtptime"
"github.com/jech/galene/unbounded"
) )
type bitrate struct { type bitrate struct {
@ -403,7 +404,7 @@ type rtpUpTrack struct {
jitter *jitter.Estimator jitter *jitter.Estimator
cname atomic.Value cname atomic.Value
actionCh chan struct{} actions *unbounded.Channel[trackAction]
readerDone chan struct{} readerDone chan struct{}
mu sync.Mutex mu sync.Mutex
@ -412,7 +413,6 @@ type rtpUpTrack struct {
srRTPTime uint32 srRTPTime uint32
local []conn.DownTrack local []conn.DownTrack
bufferedNACKs []uint16 bufferedNACKs []uint16
actions []trackAction
} }
const ( const (
@ -427,17 +427,7 @@ type trackAction struct {
} }
func (up *rtpUpTrack) action(action int, track conn.DownTrack) { func (up *rtpUpTrack) action(action int, track conn.DownTrack) {
up.mu.Lock() up.actions.Put(trackAction{action, track})
empty := len(up.actions) == 0
up.actions = append(up.actions, trackAction{action, track})
up.mu.Unlock()
if empty {
select {
case up.actionCh <- struct{}{}:
default:
}
}
} }
func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error { func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error {
@ -682,7 +672,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon
cache: packetcache.New(minPacketCache(remote)), cache: packetcache.New(minPacketCache(remote)),
rate: estimator.New(time.Second), rate: estimator.New(time.Second),
jitter: jitter.New(remote.Codec().ClockRate), jitter: jitter.New(remote.Codec().ClockRate),
actionCh: make(chan struct{}, 1), actions: unbounded.New[trackAction](),
readerDone: make(chan struct{}), readerDone: make(chan struct{}),
} }
@ -923,7 +913,7 @@ func maxUpBitrate(t *rtpUpTrack) uint64 {
// assume that lower spatial layers take up 1/5 of // assume that lower spatial layers take up 1/5 of
// the throughput // the throughput
if maxsid > 0 { if maxsid > 0 {
maxrate = sadd(maxrate, maxrate / 4) maxrate = sadd(maxrate, maxrate/4)
} }
// assume that each layer takes two times less // assume that each layer takes two times less
// throughput than the higher one. Then we've // throughput than the higher one. Then we've
@ -1003,7 +993,7 @@ func sendUpRTCP(up *rtpUpConnection) error {
} }
ssrcs = append(ssrcs, uint32(t.track.SSRC())) ssrcs = append(ssrcs, uint32(t.track.SSRC()))
if t.Kind() == webrtc.RTPCodecTypeAudio { if t.Kind() == webrtc.RTPCodecTypeAudio {
rate = sadd(rate, 100 * 1024) rate = sadd(rate, 100*1024)
} else if t.Label() == "l" { } else if t.Label() == "l" {
rate = sadd(rate, group.LowBitrate) rate = sadd(rate, group.LowBitrate)
} else { } else {

View file

@ -31,11 +31,8 @@ func readLoop(track *rtpUpTrack) {
for { for {
select { select {
case <-track.actionCh: case <-track.actions.Ch:
track.mu.Lock() actions := track.actions.Get()
actions := track.actions
track.actions = nil
track.mu.Unlock()
for _, action := range actions { for _, action := range actions {
switch action.action { switch action.action {
case trackActionAdd, trackActionDel: case trackActionAdd, trackActionDel:

View file

@ -20,6 +20,7 @@ import (
"github.com/jech/galene/group" "github.com/jech/galene/group"
"github.com/jech/galene/ice" "github.com/jech/galene/ice"
"github.com/jech/galene/token" "github.com/jech/galene/token"
"github.com/jech/galene/unbounded"
) )
func errorToWSCloseMessage(id string, err error) (*clientMessage, []byte) { func errorToWSCloseMessage(id string, err error) (*clientMessage, []byte) {
@ -65,16 +66,11 @@ type webClient struct {
done chan struct{} done chan struct{}
writeCh chan interface{} writeCh chan interface{}
writerDone chan struct{} writerDone chan struct{}
actionCh chan struct{} actions *unbounded.Channel[any]
mu sync.Mutex mu sync.Mutex
down map[string]*rtpDownConnection down map[string]*rtpDownConnection
up map[string]*rtpUpConnection up map[string]*rtpUpConnection
// action may be called with the group mutex taken, and therefore
// actions needs to use its own mutex.
actionMu sync.Mutex
actions []interface{}
} }
func (c *webClient) Group() *group.Group { func (c *webClient) Group() *group.Group {
@ -106,9 +102,10 @@ func (c *webClient) SetPermissions(perms []string) {
} }
func (c *webClient) PushClient(group, kind, id string, username string, perms []string, data map[string]interface{}) error { func (c *webClient) PushClient(group, kind, id string, username string, perms []string, data map[string]interface{}) error {
return c.action(pushClientAction{ c.action(pushClientAction{
group, kind, id, username, perms, data, group, kind, id, username, perms, data,
}) })
return nil
} }
type clientMessage struct { type clientMessage struct {
@ -733,7 +730,8 @@ func (c *webClient) setRequestedStream(down *rtpDownConnection, requested []stri
} }
func (c *webClient) RequestConns(target group.Client, g *group.Group, id string) error { func (c *webClient) RequestConns(target group.Client, g *group.Group, id string) error {
return c.action(requestConnsAction{g, target, id}) c.action(requestConnsAction{g, target, id})
return nil
} }
func requestConns(target group.Client, g *group.Group, id string) { func requestConns(target group.Client, g *group.Group, id string) {
@ -804,10 +802,7 @@ func requestedTracks(c *webClient, requested []string, tracks []conn.UpTrack) ([
} }
func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error { func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error {
err := c.action(pushConnAction{g, id, up, tracks, replace}) c.action(pushConnAction{g, id, up, tracks, replace})
if err != nil {
return err
}
return nil return nil
} }
@ -855,7 +850,7 @@ func StartClient(conn *websocket.Conn) (err error) {
c := &webClient{ c := &webClient{
id: m.Id, id: m.Id,
actionCh: make(chan struct{}, 1), actions: unbounded.New[any](),
done: make(chan struct{}), done: make(chan struct{}),
} }
@ -996,11 +991,8 @@ func clientLoop(c *webClient, ws *websocket.Conn, versionError bool) error {
case error: case error:
return m return m
} }
case <-c.actionCh: case <-c.actions.Ch:
c.actionMu.Lock() actions := c.actions.Get()
actions := c.actions
c.actions = nil
c.actionMu.Unlock()
for _, a := range actions { for _, a := range actions {
err := handleAction(c, a) err := handleAction(c, a)
if err != nil { if err != nil {
@ -1090,7 +1082,7 @@ func pushDownConn(c *webClient, id string, up conn.Up, tracks []conn.UpTrack, re
return nil return nil
} }
func handleAction(c *webClient, a interface{}) error { func handleAction(c *webClient, a any) error {
switch a := a.(type) { switch a := a.(type) {
case pushConnAction: case pushConnAction:
if c.group == nil || c.group != a.group { if c.group == nil || c.group != a.group {
@ -1353,15 +1345,18 @@ func setPermissions(g *group.Group, id string, perm string) error {
default: default:
return group.UserError("unknown permission") return group.UserError("unknown permission")
} }
return c.action(permissionsChangedAction{}) c.action(permissionsChangedAction{})
return nil
} }
func (c *webClient) Kick(id string, user *string, message string) error { func (c *webClient) Kick(id string, user *string, message string) error {
return c.action(kickAction{id, user, message}) c.action(kickAction{id, user, message})
return nil
} }
func (c *webClient) Joined(group, kind string) error { func (c *webClient) Joined(group, kind string) error {
return c.action(joinedAction{group, kind}) c.action(joinedAction{group, kind})
return nil
} }
func kickClient(g *group.Group, id string, user *string, dest string, message string) error { func kickClient(g *group.Group, id string, user *string, dest string, message string) error {
@ -2087,22 +2082,8 @@ func (c *webClient) Warn(oponly bool, message string) error {
var ErrClientDead = errors.New("client is dead") var ErrClientDead = errors.New("client is dead")
func (c *webClient) action(a interface{}) error { func (c *webClient) action(a interface{}) {
c.actionMu.Lock() c.actions.Put(a)
empty := len(c.actions) == 0
c.actions = append(c.actions, a)
c.actionMu.Unlock()
if empty {
select {
case c.actionCh <- struct{}{}:
return nil
case <-c.done:
return ErrClientDead
default:
}
}
return nil
} }
func (c *webClient) write(m clientMessage) error { func (c *webClient) write(m clientMessage) error {

47
unbounded/unbounded.go Normal file
View file

@ -0,0 +1,47 @@
package unbounded
import (
"sync"
)
// Type Channel implements an unbounded channel
type Channel[T any] struct {
// Ch triggers whenever the channel becomes non-empty
Ch chan struct{}
mu sync.Mutex
queue []T
}
// New creates a new unbounded channel
func New[T any]() *Channel[T] {
return &Channel[T]{
Ch: make(chan struct{}, 1),
}
}
// Put inserts a new element into ch.
// If ch was previously empty, it triggers ch.Ch.
func (ch *Channel[T]) Put(v T) {
ch.mu.Lock()
empty := len(ch.queue) == 0
ch.queue = append(ch.queue, v)
ch.mu.Unlock()
if empty {
select {
case ch.Ch <- struct{}{}:
default:
}
}
}
// Get removes all the elements of ch.
// It is usually called when ch.Ch triggers, but may be called at any time.
func (ch *Channel[T]) Get() []T {
ch.mu.Lock()
defer ch.mu.Unlock()
queue := ch.queue
ch.queue = nil
return queue
}

View file

@ -0,0 +1,52 @@
package unbounded
import (
"testing"
"time"
)
func TestUnbounded(t *testing.T) {
ch := New[int]()
go func() {
for i := 0; i < 1000; i++ {
ch.Put(i)
}
}()
n := 0
for n < 1000 {
<-ch.Ch
vs := ch.Get()
for _, v := range vs {
if n != v {
t.Errorf("Expected %v, got %v", n, v)
}
n++
}
}
go func() {
for i := 0; i < 1000; i++ {
ch.Put(i)
time.Sleep(time.Microsecond)
}
}()
n = 0
for n < 1000 {
<-ch.Ch
vs := ch.Get()
for _, v := range vs {
if n != v {
t.Errorf("Expected %v, got %v", n, v)
}
n++
}
}
vs := ch.Get()
if len(vs) != 0 {
t.Errorf("Channel is not empty (%v)", len(vs))
}
}