1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-22 16:45:58 +01:00

Use a separate track for each down connection.

This commit is contained in:
Juliusz Chroboczek 2020-04-28 14:54:50 +02:00
parent 9c9748b888
commit 038ab46d2b
2 changed files with 143 additions and 103 deletions

193
client.go
View file

@ -208,8 +208,8 @@ func startClient(conn *websocket.Conn) (err error) {
} }
func getUpConn(c *client, id string) *upConnection { func getUpConn(c *client, id string) *upConnection {
c.group.mu.Lock() c.mu.Lock()
defer c.group.mu.Unlock() defer c.mu.Unlock()
if c.up == nil { if c.up == nil {
return nil return nil
@ -222,8 +222,8 @@ func getUpConn(c *client, id string) *upConnection {
} }
func getUpConns(c *client) []string { func getUpConns(c *client) []string {
c.group.mu.Lock() c.mu.Lock()
defer c.group.mu.Unlock() defer c.mu.Unlock()
up := make([]string, 0, len(c.up)) up := make([]string, 0, len(c.up))
for id := range c.up { for id := range c.up {
up = append(up, id) up = append(up, id)
@ -262,34 +262,24 @@ func addUpConn(c *client, id string) (*upConnection, error) {
}) })
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) { pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
local, err := pc.NewTrack( c.mu.Lock()
remote.PayloadType(),
remote.SSRC(),
remote.ID(),
remote.Label())
if err != nil {
log.Printf("%v", err)
return
}
c.group.mu.Lock()
u, ok := c.up[id] u, ok := c.up[id]
if !ok { if !ok {
log.Printf("Unknown connection") log.Printf("Unknown connection")
c.group.mu.Unlock() c.mu.Unlock()
return return
} }
u.pairs = append(u.pairs, trackPair{ track := &upTrack{
remote: remote, track: remote,
local: local,
maxBitrate: ^uint64(0), maxBitrate: ^uint64(0),
}) }
done := len(u.pairs) >= u.trackCount u.tracks = append(u.tracks, track)
c.group.mu.Unlock() done := len(u.tracks) >= u.trackCount
c.mu.Unlock()
clients := c.group.getClients(c) clients := c.group.getClients(c)
for _, cc := range clients { for _, cc := range clients {
cc.action(addTrackAction{id, local, u, done}) cc.action(addTrackAction{track, u, done})
if done && u.label != "" { if done && u.label != "" {
cc.action(addLabelAction{id, u.label}) cc.action(addLabelAction{id, u.label})
} }
@ -313,9 +303,12 @@ func addUpConn(c *client, id string) (*upConnection, error) {
continue continue
} }
err = local.WriteRTP(&packet) local := track.getLocal()
if err != nil && err != io.ErrClosedPipe { for _, l := range local {
log.Printf("%v", err) err := l.track.WriteRTP(&packet)
if err != nil {
log.Printf("%v", err)
}
} }
} }
}() }()
@ -323,8 +316,8 @@ func addUpConn(c *client, id string) (*upConnection, error) {
conn := &upConnection{id: id, pc: pc} conn := &upConnection{id: id, pc: pc}
c.group.mu.Lock() c.mu.Lock()
defer c.group.mu.Unlock() defer c.mu.Unlock()
if c.up == nil { if c.up == nil {
c.up = make(map[string]*upConnection) c.up = make(map[string]*upConnection)
@ -338,8 +331,8 @@ func addUpConn(c *client, id string) (*upConnection, error) {
} }
func delUpConn(c *client, id string) { func delUpConn(c *client, id string) {
c.group.mu.Lock() c.mu.Lock()
defer c.group.mu.Unlock() defer c.mu.Unlock()
if c.up == nil { if c.up == nil {
log.Printf("Deleting unknown connection") log.Printf("Deleting unknown connection")
@ -357,16 +350,20 @@ func delUpConn(c *client, id string) {
id string id string
} }
cids := make([]clientId, 0) cids := make([]clientId, 0)
for _, cc := range c.group.clients {
clients := c.group.getClients(c)
for _, cc := range clients {
cc.mu.Lock()
for _, otherconn := range cc.down { for _, otherconn := range cc.down {
if otherconn.remote == conn { if otherconn.remote == conn {
cids = append(cids, clientId{cc, otherconn.id}) cids = append(cids, clientId{cc, otherconn.id})
} }
} }
cc.mu.Unlock()
} }
for _, cid := range cids { for _, cid := range cids {
cid.client.action(delPCAction{cid.id}) cid.client.action(delConnAction{cid.id})
} }
conn.pc.Close() conn.pc.Close()
@ -378,8 +375,8 @@ func getDownConn(c *client, id string) *downConnection {
return nil return nil
} }
c.group.mu.Lock() c.mu.Lock()
defer c.group.mu.Unlock() defer c.mu.Unlock()
conn := c.down[id] conn := c.down[id]
if conn == nil { if conn == nil {
return nil return nil
@ -406,8 +403,8 @@ func addDownConn(c *client, id string, remote *upConnection) (*downConnection, e
} }
conn := &downConnection{id: id, pc: pc, remote: remote} conn := &downConnection{id: id, pc: pc, remote: remote}
c.group.mu.Lock() c.mu.Lock()
defer c.group.mu.Unlock() defer c.mu.Unlock()
if c.down[id] != nil { if c.down[id] != nil {
conn.pc.Close() conn.pc.Close()
return nil, errors.New("Adding duplicate connection") return nil, errors.New("Adding duplicate connection")
@ -417,8 +414,8 @@ func addDownConn(c *client, id string, remote *upConnection) (*downConnection, e
} }
func delDownConn(c *client, id string) { func delDownConn(c *client, id string) {
c.group.mu.Lock() c.mu.Lock()
defer c.group.mu.Unlock() defer c.mu.Unlock()
if c.down == nil { if c.down == nil {
log.Printf("Deleting unknown connection") log.Printf("Deleting unknown connection")
@ -429,31 +426,49 @@ func delDownConn(c *client, id string) {
log.Printf("Deleting unknown connection") log.Printf("Deleting unknown connection")
return return
} }
for _, track := range conn.tracks {
found := track.remote.delLocal(track)
if !found {
log.Printf("Couldn't find remote track")
}
track.remote = nil
}
conn.pc.Close() conn.pc.Close()
delete(c.down, id) delete(c.down, id)
} }
func addDownTrack(c *client, id string, track *webrtc.Track, remote *upConnection) (*downConnection, *webrtc.RTPSender, error) { func addDownTrack(c *client, id string, remoteTrack *upTrack, remoteConn *upConnection) (*downConnection, *webrtc.RTPSender, error) {
conn := getDownConn(c, id) conn := getDownConn(c, id)
if conn == nil { if conn == nil {
var err error var err error
conn, err = addDownConn(c, id, remote) conn, err = addDownConn(c, id, remoteConn)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} }
s, err := conn.pc.AddTrack(track) local, err := conn.pc.NewTrack(
remoteTrack.track.PayloadType(),
remoteTrack.track.SSRC(),
remoteTrack.track.ID(),
remoteTrack.track.Label(),
)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
conn.tracks = append(conn.tracks, s, err := conn.pc.AddTrack(local)
downTrack{track.SSRC(), new(timeStampedBitrate)}, if err != nil {
) return nil, nil, err
}
go rtcpListener(c.group, conn, s, track := &downTrack{local, remoteTrack, new(timeStampedBitrate)}
conn.tracks[len(conn.tracks)-1].maxBitrate)
conn.tracks = append(conn.tracks, track)
remoteTrack.addLocal(track)
go rtcpListener(c.group, conn, s, track.maxBitrate)
return conn, s, nil return conn, s, nil
} }
@ -545,44 +560,26 @@ func splitBitrate(bitrate uint32, audio, video bool) (uint32, uint32) {
return audioRate, bitrate - audioRate return audioRate, bitrate - audioRate
} }
func updateUpBitrate(g *group, up *upConnection) { func updateUpBitrate(up *upConnection) {
for i := range up.pairs { for _, t := range up.tracks {
up.pairs[i].maxBitrate = ^uint64(0) t.maxBitrate = ^uint64(0)
} }
now := msSinceEpoch() now := msSinceEpoch()
g.Range(func(c *client) bool { for _, track := range up.tracks {
for _, down := range c.down { local := track.getLocal()
if down.remote == up { for _, l := range local {
for _, dt := range down.tracks { ms := atomic.LoadUint64(&l.maxBitrate.timestamp)
ms := atomic.LoadUint64( bitrate := atomic.LoadUint64(&l.maxBitrate.bitrate)
&dt.maxBitrate.timestamp, if now-ms > 5000 || bitrate == 0 {
) continue
bitrate := atomic.LoadUint64( }
&dt.maxBitrate.bitrate, if track.maxBitrate > bitrate {
) track.maxBitrate = bitrate
if bitrate == 0 {
continue
}
if now-ms > 5000 {
continue
}
for i, p := range up.pairs {
if p.local.SSRC() == dt.ssrc {
if p.maxBitrate > bitrate {
up.pairs[i].maxBitrate = bitrate
break
}
}
}
}
} }
} }
return true }
})
} }
func sendPLI(pc *webrtc.PeerConnection, ssrc uint32) error { func sendPLI(pc *webrtc.PeerConnection, ssrc uint32) error {
@ -779,18 +776,18 @@ func clientLoop(c *client, conn *websocket.Conn) error {
case addTrackAction: case addTrackAction:
down, _, err := down, _, err :=
addDownTrack( addDownTrack(
c, a.id, a.track, c, a.remote.id, a.track,
a.remote) a.remote)
if err != nil { if err != nil {
return err return err
} }
if a.done { if a.done {
err = negotiate(c, a.id, down.pc) err = negotiate(c, a.remote.id, down.pc)
if err != nil { if err != nil {
return err return err
} }
} }
case delPCAction: case delConnAction:
c.write(clientMessage{ c.write(clientMessage{
Type: "close", Type: "close",
Id: a.id, Id: a.id,
@ -805,11 +802,10 @@ func clientLoop(c *client, conn *websocket.Conn) error {
case pushTracksAction: case pushTracksAction:
for _, u := range c.up { for _, u := range c.up {
var done bool var done bool
for i, p := range u.pairs { for i, t := range u.tracks {
done = i >= u.trackCount-1 done = i >= u.trackCount-1
a.c.action(addTrackAction{ a.c.action(addTrackAction{
u.id, p.local, u, t, u, done,
done,
}) })
} }
if done && u.label != "" { if done && u.label != "" {
@ -931,22 +927,35 @@ func handleClientMessage(c *client, m clientMessage) error {
} }
func sendRateUpdate(c *client) { func sendRateUpdate(c *client) {
type remb struct {
pc *webrtc.PeerConnection
ssrc uint32
bitrate uint64
}
rembs := make([]remb, 0)
c.mu.Lock()
for _, u := range c.up { for _, u := range c.up {
updateUpBitrate(c.group, u) updateUpBitrate(u)
for _, p := range u.pairs { for _, t := range u.tracks {
bitrate := p.maxBitrate bitrate := t.maxBitrate
if bitrate != ^uint64(0) { if bitrate != ^uint64(0) {
if bitrate < 6000 { if bitrate < 6000 {
bitrate = 6000 bitrate = 6000
} }
err := sendREMB(u.pc, p.remote.SSRC(), rembs = append(rembs,
uint64(bitrate)) remb{u.pc, t.track.SSRC(), bitrate})
if err != nil {
log.Printf("sendREMB: %v", err)
}
} }
} }
} }
c.mu.Unlock()
for _, r := range rembs {
err := sendREMB(r.pc, r.ssrc, r.bitrate)
if err != nil {
log.Printf("sendREMB: %v", err)
}
}
} }
func clientReader(conn *websocket.Conn, read chan<- interface{}, done <-chan struct{}) { func clientReader(conn *websocket.Conn, read chan<- interface{}, done <-chan struct{}) {

View file

@ -17,9 +17,38 @@ import (
"github.com/pion/webrtc/v2" "github.com/pion/webrtc/v2"
) )
type trackPair struct { type upTrack struct {
remote, local *webrtc.Track track *webrtc.Track
maxBitrate uint64 maxBitrate uint64
mu sync.Mutex
local []*downTrack
}
func (up *upTrack) addLocal(local *downTrack) {
up.mu.Lock()
defer up.mu.Unlock()
up.local = append(up.local, local)
}
func (up *upTrack) delLocal(local *downTrack) bool {
up.mu.Lock()
defer up.mu.Unlock()
for i, l := range up.local {
if l == local {
up.local = append(up.local[:i], up.local[i+1:]...)
return true
}
}
return false
}
func (up *upTrack) getLocal() []*downTrack {
up.mu.Lock()
defer up.mu.Unlock()
local := make([]*downTrack, len(up.local))
copy(local, up.local)
return local
} }
type upConnection struct { type upConnection struct {
@ -27,7 +56,7 @@ type upConnection struct {
label string label string
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
trackCount int trackCount int
pairs []trackPair tracks []*upTrack
} }
type timeStampedBitrate struct { type timeStampedBitrate struct {
@ -35,7 +64,8 @@ type timeStampedBitrate struct {
timestamp uint64 timestamp uint64
} }
type downTrack struct { type downTrack struct {
ssrc uint32 track *webrtc.Track
remote *upTrack
maxBitrate *timeStampedBitrate maxBitrate *timeStampedBitrate
} }
@ -43,7 +73,7 @@ type downConnection struct {
id string id string
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
remote *upConnection remote *upConnection
tracks []downTrack tracks []*downTrack
} }
type client struct { type client struct {
@ -55,8 +85,10 @@ type client struct {
writeCh chan interface{} writeCh chan interface{}
writerDone chan struct{} writerDone chan struct{}
actionCh chan interface{} actionCh chan interface{}
down map[string]*downConnection
up map[string]*upConnection mu sync.Mutex
down map[string]*downConnection
up map[string]*upConnection
} }
type chatHistoryEntry struct { type chatHistoryEntry struct {
@ -76,13 +108,12 @@ type group struct {
history []chatHistoryEntry history []chatHistoryEntry
} }
type delPCAction struct { type delConnAction struct {
id string id string
} }
type addTrackAction struct { type addTrackAction struct {
id string track *upTrack
track *webrtc.Track
remote *upConnection remote *upConnection
done bool done bool
} }