mirror of
https://github.com/jech/galene.git
synced 2024-12-18 05:15:49 +01:00
Make unbounded channels explicit.
We used to have unbounded channels embedded within rtpconn and webClient. Make the structure explicit and testable.
This commit is contained in:
parent
dcde4562f5
commit
00fbfafeeb
5 changed files with 128 additions and 61 deletions
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/jech/galene/packetcache"
|
"github.com/jech/galene/packetcache"
|
||||||
"github.com/jech/galene/packetmap"
|
"github.com/jech/galene/packetmap"
|
||||||
"github.com/jech/galene/rtptime"
|
"github.com/jech/galene/rtptime"
|
||||||
|
"github.com/jech/galene/unbounded"
|
||||||
)
|
)
|
||||||
|
|
||||||
type bitrate struct {
|
type bitrate struct {
|
||||||
|
@ -403,7 +404,7 @@ type rtpUpTrack struct {
|
||||||
jitter *jitter.Estimator
|
jitter *jitter.Estimator
|
||||||
cname atomic.Value
|
cname atomic.Value
|
||||||
|
|
||||||
actionCh chan struct{}
|
actions *unbounded.Channel[trackAction]
|
||||||
readerDone chan struct{}
|
readerDone chan struct{}
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
@ -412,7 +413,6 @@ type rtpUpTrack struct {
|
||||||
srRTPTime uint32
|
srRTPTime uint32
|
||||||
local []conn.DownTrack
|
local []conn.DownTrack
|
||||||
bufferedNACKs []uint16
|
bufferedNACKs []uint16
|
||||||
actions []trackAction
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -427,17 +427,7 @@ type trackAction struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (up *rtpUpTrack) action(action int, track conn.DownTrack) {
|
func (up *rtpUpTrack) action(action int, track conn.DownTrack) {
|
||||||
up.mu.Lock()
|
up.actions.Put(trackAction{action, track})
|
||||||
empty := len(up.actions) == 0
|
|
||||||
up.actions = append(up.actions, trackAction{action, track})
|
|
||||||
up.mu.Unlock()
|
|
||||||
|
|
||||||
if empty {
|
|
||||||
select {
|
|
||||||
case up.actionCh <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error {
|
func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error {
|
||||||
|
@ -682,7 +672,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon
|
||||||
cache: packetcache.New(minPacketCache(remote)),
|
cache: packetcache.New(minPacketCache(remote)),
|
||||||
rate: estimator.New(time.Second),
|
rate: estimator.New(time.Second),
|
||||||
jitter: jitter.New(remote.Codec().ClockRate),
|
jitter: jitter.New(remote.Codec().ClockRate),
|
||||||
actionCh: make(chan struct{}, 1),
|
actions: unbounded.New[trackAction](),
|
||||||
readerDone: make(chan struct{}),
|
readerDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -923,7 +913,7 @@ func maxUpBitrate(t *rtpUpTrack) uint64 {
|
||||||
// assume that lower spatial layers take up 1/5 of
|
// assume that lower spatial layers take up 1/5 of
|
||||||
// the throughput
|
// the throughput
|
||||||
if maxsid > 0 {
|
if maxsid > 0 {
|
||||||
maxrate = sadd(maxrate, maxrate / 4)
|
maxrate = sadd(maxrate, maxrate/4)
|
||||||
}
|
}
|
||||||
// assume that each layer takes two times less
|
// assume that each layer takes two times less
|
||||||
// throughput than the higher one. Then we've
|
// throughput than the higher one. Then we've
|
||||||
|
@ -1003,7 +993,7 @@ func sendUpRTCP(up *rtpUpConnection) error {
|
||||||
}
|
}
|
||||||
ssrcs = append(ssrcs, uint32(t.track.SSRC()))
|
ssrcs = append(ssrcs, uint32(t.track.SSRC()))
|
||||||
if t.Kind() == webrtc.RTPCodecTypeAudio {
|
if t.Kind() == webrtc.RTPCodecTypeAudio {
|
||||||
rate = sadd(rate, 100 * 1024)
|
rate = sadd(rate, 100*1024)
|
||||||
} else if t.Label() == "l" {
|
} else if t.Label() == "l" {
|
||||||
rate = sadd(rate, group.LowBitrate)
|
rate = sadd(rate, group.LowBitrate)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -31,11 +31,8 @@ func readLoop(track *rtpUpTrack) {
|
||||||
for {
|
for {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-track.actionCh:
|
case <-track.actions.Ch:
|
||||||
track.mu.Lock()
|
actions := track.actions.Get()
|
||||||
actions := track.actions
|
|
||||||
track.actions = nil
|
|
||||||
track.mu.Unlock()
|
|
||||||
for _, action := range actions {
|
for _, action := range actions {
|
||||||
switch action.action {
|
switch action.action {
|
||||||
case trackActionAdd, trackActionDel:
|
case trackActionAdd, trackActionDel:
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"github.com/jech/galene/group"
|
"github.com/jech/galene/group"
|
||||||
"github.com/jech/galene/ice"
|
"github.com/jech/galene/ice"
|
||||||
"github.com/jech/galene/token"
|
"github.com/jech/galene/token"
|
||||||
|
"github.com/jech/galene/unbounded"
|
||||||
)
|
)
|
||||||
|
|
||||||
func errorToWSCloseMessage(id string, err error) (*clientMessage, []byte) {
|
func errorToWSCloseMessage(id string, err error) (*clientMessage, []byte) {
|
||||||
|
@ -65,16 +66,11 @@ type webClient struct {
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
writeCh chan interface{}
|
writeCh chan interface{}
|
||||||
writerDone chan struct{}
|
writerDone chan struct{}
|
||||||
actionCh chan struct{}
|
actions *unbounded.Channel[any]
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
down map[string]*rtpDownConnection
|
down map[string]*rtpDownConnection
|
||||||
up map[string]*rtpUpConnection
|
up map[string]*rtpUpConnection
|
||||||
|
|
||||||
// action may be called with the group mutex taken, and therefore
|
|
||||||
// actions needs to use its own mutex.
|
|
||||||
actionMu sync.Mutex
|
|
||||||
actions []interface{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *webClient) Group() *group.Group {
|
func (c *webClient) Group() *group.Group {
|
||||||
|
@ -106,9 +102,10 @@ func (c *webClient) SetPermissions(perms []string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *webClient) PushClient(group, kind, id string, username string, perms []string, data map[string]interface{}) error {
|
func (c *webClient) PushClient(group, kind, id string, username string, perms []string, data map[string]interface{}) error {
|
||||||
return c.action(pushClientAction{
|
c.action(pushClientAction{
|
||||||
group, kind, id, username, perms, data,
|
group, kind, id, username, perms, data,
|
||||||
})
|
})
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientMessage struct {
|
type clientMessage struct {
|
||||||
|
@ -733,7 +730,8 @@ func (c *webClient) setRequestedStream(down *rtpDownConnection, requested []stri
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *webClient) RequestConns(target group.Client, g *group.Group, id string) error {
|
func (c *webClient) RequestConns(target group.Client, g *group.Group, id string) error {
|
||||||
return c.action(requestConnsAction{g, target, id})
|
c.action(requestConnsAction{g, target, id})
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestConns(target group.Client, g *group.Group, id string) {
|
func requestConns(target group.Client, g *group.Group, id string) {
|
||||||
|
@ -804,10 +802,7 @@ func requestedTracks(c *webClient, requested []string, tracks []conn.UpTrack) ([
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error {
|
func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error {
|
||||||
err := c.action(pushConnAction{g, id, up, tracks, replace})
|
c.action(pushConnAction{g, id, up, tracks, replace})
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -854,9 +849,9 @@ func StartClient(conn *websocket.Conn) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &webClient{
|
c := &webClient{
|
||||||
id: m.Id,
|
id: m.Id,
|
||||||
actionCh: make(chan struct{}, 1),
|
actions: unbounded.New[any](),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
defer close(c.done)
|
defer close(c.done)
|
||||||
|
@ -996,11 +991,8 @@ func clientLoop(c *webClient, ws *websocket.Conn, versionError bool) error {
|
||||||
case error:
|
case error:
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
case <-c.actionCh:
|
case <-c.actions.Ch:
|
||||||
c.actionMu.Lock()
|
actions := c.actions.Get()
|
||||||
actions := c.actions
|
|
||||||
c.actions = nil
|
|
||||||
c.actionMu.Unlock()
|
|
||||||
for _, a := range actions {
|
for _, a := range actions {
|
||||||
err := handleAction(c, a)
|
err := handleAction(c, a)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1090,7 +1082,7 @@ func pushDownConn(c *webClient, id string, up conn.Up, tracks []conn.UpTrack, re
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleAction(c *webClient, a interface{}) error {
|
func handleAction(c *webClient, a any) error {
|
||||||
switch a := a.(type) {
|
switch a := a.(type) {
|
||||||
case pushConnAction:
|
case pushConnAction:
|
||||||
if c.group == nil || c.group != a.group {
|
if c.group == nil || c.group != a.group {
|
||||||
|
@ -1353,15 +1345,18 @@ func setPermissions(g *group.Group, id string, perm string) error {
|
||||||
default:
|
default:
|
||||||
return group.UserError("unknown permission")
|
return group.UserError("unknown permission")
|
||||||
}
|
}
|
||||||
return c.action(permissionsChangedAction{})
|
c.action(permissionsChangedAction{})
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *webClient) Kick(id string, user *string, message string) error {
|
func (c *webClient) Kick(id string, user *string, message string) error {
|
||||||
return c.action(kickAction{id, user, message})
|
c.action(kickAction{id, user, message})
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *webClient) Joined(group, kind string) error {
|
func (c *webClient) Joined(group, kind string) error {
|
||||||
return c.action(joinedAction{group, kind})
|
c.action(joinedAction{group, kind})
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func kickClient(g *group.Group, id string, user *string, dest string, message string) error {
|
func kickClient(g *group.Group, id string, user *string, dest string, message string) error {
|
||||||
|
@ -2087,22 +2082,8 @@ func (c *webClient) Warn(oponly bool, message string) error {
|
||||||
|
|
||||||
var ErrClientDead = errors.New("client is dead")
|
var ErrClientDead = errors.New("client is dead")
|
||||||
|
|
||||||
func (c *webClient) action(a interface{}) error {
|
func (c *webClient) action(a interface{}) {
|
||||||
c.actionMu.Lock()
|
c.actions.Put(a)
|
||||||
empty := len(c.actions) == 0
|
|
||||||
c.actions = append(c.actions, a)
|
|
||||||
c.actionMu.Unlock()
|
|
||||||
|
|
||||||
if empty {
|
|
||||||
select {
|
|
||||||
case c.actionCh <- struct{}{}:
|
|
||||||
return nil
|
|
||||||
case <-c.done:
|
|
||||||
return ErrClientDead
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *webClient) write(m clientMessage) error {
|
func (c *webClient) write(m clientMessage) error {
|
||||||
|
|
47
unbounded/unbounded.go
Normal file
47
unbounded/unbounded.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package unbounded
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Type Channel implements an unbounded channel
|
||||||
|
type Channel[T any] struct {
|
||||||
|
// Ch triggers whenever the channel becomes non-empty
|
||||||
|
Ch chan struct{}
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
queue []T
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new unbounded channel
|
||||||
|
func New[T any]() *Channel[T] {
|
||||||
|
return &Channel[T]{
|
||||||
|
Ch: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put inserts a new element into ch.
|
||||||
|
// If ch was previously empty, it triggers ch.Ch.
|
||||||
|
func (ch *Channel[T]) Put(v T) {
|
||||||
|
ch.mu.Lock()
|
||||||
|
empty := len(ch.queue) == 0
|
||||||
|
ch.queue = append(ch.queue, v)
|
||||||
|
ch.mu.Unlock()
|
||||||
|
|
||||||
|
if empty {
|
||||||
|
select {
|
||||||
|
case ch.Ch <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get removes all the elements of ch.
|
||||||
|
// It is usually called when ch.Ch triggers, but may be called at any time.
|
||||||
|
func (ch *Channel[T]) Get() []T {
|
||||||
|
ch.mu.Lock()
|
||||||
|
defer ch.mu.Unlock()
|
||||||
|
queue := ch.queue
|
||||||
|
ch.queue = nil
|
||||||
|
return queue
|
||||||
|
}
|
52
unbounded/unbounded_test.go
Normal file
52
unbounded/unbounded_test.go
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
package unbounded
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUnbounded(t *testing.T) {
|
||||||
|
ch := New[int]()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
ch.Put(i)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
n := 0
|
||||||
|
for n < 1000 {
|
||||||
|
<-ch.Ch
|
||||||
|
vs := ch.Get()
|
||||||
|
for _, v := range vs {
|
||||||
|
if n != v {
|
||||||
|
t.Errorf("Expected %v, got %v", n, v)
|
||||||
|
}
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
ch.Put(i)
|
||||||
|
time.Sleep(time.Microsecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
n = 0
|
||||||
|
for n < 1000 {
|
||||||
|
<-ch.Ch
|
||||||
|
vs := ch.Get()
|
||||||
|
for _, v := range vs {
|
||||||
|
if n != v {
|
||||||
|
t.Errorf("Expected %v, got %v", n, v)
|
||||||
|
}
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vs := ch.Get()
|
||||||
|
if len(vs) != 0 {
|
||||||
|
t.Errorf("Channel is not empty (%v)", len(vs))
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue