1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-22 00:25:58 +01:00

Maintain a reference to the connection associated to each track.

This avoids carrying an extra parameter in many places.
This commit is contained in:
Juliusz Chroboczek 2021-07-14 14:05:23 +02:00
parent bcd62f190b
commit 36d31f0db8
6 changed files with 30 additions and 32 deletions

View file

@ -27,7 +27,7 @@ type UpTrack interface {
Codec() webrtc.RTPCodecCapability Codec() webrtc.RTPCodecCapability
// get a recent packet. Returns 0 if the packet is not in cache. // get a recent packet. Returns 0 if the packet is not in cache.
GetRTP(seqno uint16, result []byte) uint16 GetRTP(seqno uint16, result []byte) uint16
Nack(conn Up, seqnos []uint16) error Nack(seqnos []uint16) error
RequestKeyframe() error RequestKeyframe() error
} }

View file

@ -510,7 +510,7 @@ func (t *diskTrack) Write(buf []byte) (int, error) {
} }
} }
if len(nacks) > 0 { if len(nacks) > 0 {
t.remote.Nack(t.conn.remote, nacks) t.remote.Nack(nacks)
} }
} }
} }

View file

@ -80,6 +80,7 @@ type downTrackAtomics struct {
type rtpDownTrack struct { type rtpDownTrack struct {
track *webrtc.TrackLocalStaticRTP track *webrtc.TrackLocalStaticRTP
sender *webrtc.RTPSender sender *webrtc.RTPSender
conn *rtpDownConnection
remote conn.UpTrack remote conn.UpTrack
ssrc webrtc.SSRC ssrc webrtc.SSRC
packetmap packetmap.Map packetmap packetmap.Map
@ -351,6 +352,7 @@ func (down *rtpDownConnection) flushICECandidates() error {
type rtpUpTrack struct { type rtpUpTrack struct {
track *webrtc.TrackRemote track *webrtc.TrackRemote
conn *rtpUpConnection
rate *estimator.Estimator rate *estimator.Estimator
cache *packetcache.Cache cache *packetcache.Cache
jitter *jitter.Estimator jitter *jitter.Estimator
@ -616,6 +618,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon
track := &rtpUpTrack{ track := &rtpUpTrack{
track: remote, track: remote,
conn: up,
cache: packetcache.New(minPacketCache(remote)), cache: packetcache.New(minPacketCache(remote)),
rate: estimator.New(time.Second), rate: estimator.New(time.Second),
jitter: jitter.New(remote.Codec().ClockRate), jitter: jitter.New(remote.Codec().ClockRate),
@ -625,7 +628,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon
up.tracks = append(up.tracks, track) up.tracks = append(up.tracks, track)
go readLoop(up, track) go readLoop(track)
go rtcpUpListener(up, track, receiver) go rtcpUpListener(up, track, receiver)
@ -643,11 +646,11 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon
var ErrUnsupportedFeedback = errors.New("unsupported feedback type") var ErrUnsupportedFeedback = errors.New("unsupported feedback type")
var ErrRateLimited = errors.New("rate limited") var ErrRateLimited = errors.New("rate limited")
func (up *rtpUpConnection) sendPLI(track *rtpUpTrack) error { func (track *rtpUpTrack) sendPLI() error {
if !track.hasRtcpFb("nack", "pli") { if !track.hasRtcpFb("nack", "pli") {
return ErrUnsupportedFeedback return ErrUnsupportedFeedback
} }
return sendPLI(up.pc, track.track.SSRC()) return sendPLI(track.conn.pc, track.track.SSRC())
} }
func sendPLI(pc *webrtc.PeerConnection, ssrc webrtc.SSRC) error { func sendPLI(pc *webrtc.PeerConnection, ssrc webrtc.SSRC) error {
@ -656,12 +659,12 @@ func sendPLI(pc *webrtc.PeerConnection, ssrc webrtc.SSRC) error {
}) })
} }
func (up *rtpUpConnection) sendNACK(track *rtpUpTrack, first uint16, bitmap uint16) error { func (track *rtpUpTrack) sendNACK(first uint16, bitmap uint16) error {
if !track.hasRtcpFb("nack", "") { if !track.hasRtcpFb("nack", "") {
return ErrUnsupportedFeedback return ErrUnsupportedFeedback
} }
err := sendNACKs(up.pc, track.track.SSRC(), err := sendNACKs(track.conn.pc, track.track.SSRC(),
[]rtcp.NackPair{{first, rtcp.PacketBitmap(bitmap)}}, []rtcp.NackPair{{first, rtcp.PacketBitmap(bitmap)}},
) )
if err == nil { if err == nil {
@ -670,7 +673,7 @@ func (up *rtpUpConnection) sendNACK(track *rtpUpTrack, first uint16, bitmap uint
return err return err
} }
func (up *rtpUpConnection) sendNACKs(track *rtpUpTrack, seqnos []uint16) error { func (track *rtpUpTrack) sendNACKs(seqnos []uint16) error {
count := len(seqnos) count := len(seqnos)
if count == 0 { if count == 0 {
return nil return nil
@ -691,7 +694,7 @@ func (up *rtpUpConnection) sendNACKs(track *rtpUpTrack, seqnos []uint16) error {
f, b, seqnos = packetcache.ToBitmap(seqnos) f, b, seqnos = packetcache.ToBitmap(seqnos)
nacks = append(nacks, rtcp.NackPair{f, rtcp.PacketBitmap(b)}) nacks = append(nacks, rtcp.NackPair{f, rtcp.PacketBitmap(b)})
} }
err := sendNACKs(up.pc, track.track.SSRC(), nacks) err := sendNACKs(track.conn.pc, track.track.SSRC(), nacks)
if err == nil { if err == nil {
track.cache.Expect(count) track.cache.Expect(count)
} }
@ -708,7 +711,7 @@ func sendNACKs(pc *webrtc.PeerConnection, ssrc webrtc.SSRC, nacks []rtcp.NackPai
return pc.WriteRTCP([]rtcp.Packet{packet}) return pc.WriteRTCP([]rtcp.Packet{packet})
} }
func gotNACK(conn *rtpDownConnection, track *rtpDownTrack, p *rtcp.TransportLayerNack) { func gotNACK(track *rtpDownTrack, p *rtcp.TransportLayerNack) {
var unhandled []uint16 var unhandled []uint16
buf := make([]byte, packetcache.BufSize) buf := make([]byte, packetcache.BufSize)
for _, nack := range p.Nacks { for _, nack := range p.Nacks {
@ -734,10 +737,10 @@ func gotNACK(conn *rtpDownConnection, track *rtpDownTrack, p *rtcp.TransportLaye
return return
} }
track.remote.Nack(conn.remote, unhandled) track.remote.Nack(unhandled)
} }
func (track *rtpUpTrack) Nack(conn conn.Up, nacks []uint16) error { func (track *rtpUpTrack) Nack(nacks []uint16) error {
track.mu.Lock() track.mu.Lock()
defer track.mu.Unlock() defer track.mu.Unlock()
@ -754,12 +757,7 @@ outer:
} }
if doit { if doit {
up, ok := conn.(*rtpUpConnection) go nackWriter(track)
if !ok {
log.Printf("Nack: unexpected type %T", conn)
return errors.New("unexpected connection type")
}
go nackWriter(up, track)
} }
return nil return nil
} }
@ -1088,7 +1086,7 @@ func (track *rtpDownTrack) updateRate(loss uint8, now uint64) {
track.maxBitrate.Set(rate, now) track.maxBitrate.Set(rate, now)
} }
func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RTPSender) { func rtcpDownListener(track *rtpDownTrack, s *webrtc.RTPSender) {
lastFirSeqno := uint8(0) lastFirSeqno := uint8(0)
buf := make([]byte, 1500) buf := make([]byte, 1500)
@ -1149,7 +1147,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT
} }
} }
case *rtcp.TransportLayerNack: case *rtcp.TransportLayerNack:
gotNACK(conn, track, p) gotNACK(track, p)
} }
} }
if adjust { if adjust {

View file

@ -12,8 +12,8 @@ import (
"github.com/jech/galene/rtptime" "github.com/jech/galene/rtptime"
) )
func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { func readLoop(track *rtpUpTrack) {
writers := rtpWriterPool{conn: conn, track: track} writers := rtpWriterPool{track: track}
defer func() { defer func() {
writers.close() writers.close()
close(track.readerDone) close(track.readerDone)
@ -118,7 +118,7 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
packet.SequenceNumber - unnacked, packet.SequenceNumber - unnacked,
) )
if found && sendNACK { if found && sendNACK {
err := conn.sendNACK(track, first, bitmap) err := track.sendNACK(first, bitmap)
if err != nil { if err != nil {
log.Printf("%v", err) log.Printf("%v", err)
} }
@ -136,7 +136,7 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
now := time.Now() now := time.Now()
if kfNeeded && now.Sub(kfRequested) > time.Second/2 { if kfNeeded && now.Sub(kfRequested) > time.Second/2 {
if sendPLI { if sendPLI {
err := conn.sendPLI(track) err := track.sendPLI()
if err != nil { if err != nil {
log.Printf("sendPLI: %v", err) log.Printf("sendPLI: %v", err)
kfNeeded = false kfNeeded = false

View file

@ -21,7 +21,6 @@ type packetIndex struct {
// An rtpWriterPool is a set of rtpWriters // An rtpWriterPool is a set of rtpWriters
type rtpWriterPool struct { type rtpWriterPool struct {
conn *rtpUpConnection
track *rtpUpTrack track *rtpUpTrack
writers []*rtpWriter writers []*rtpWriter
count int count int
@ -72,7 +71,7 @@ func (wp *rtpWriterPool) add(track conn.DownTrack, add bool) error {
} }
if add { if add {
writer := newRtpWriter(wp.conn, wp.track) writer := newRtpWriter(wp.track)
wp.writers = append(wp.writers, writer) wp.writers = append(wp.writers, writer)
err := writer.add(track, true, n) err := writer.add(track, true, n)
if err == nil { if err == nil {
@ -181,13 +180,13 @@ type rtpWriter struct {
drop int drop int
} }
func newRtpWriter(conn *rtpUpConnection, track *rtpUpTrack) *rtpWriter { func newRtpWriter(track *rtpUpTrack) *rtpWriter {
writer := &rtpWriter{ writer := &rtpWriter{
ch: make(chan packetIndex, 32), ch: make(chan packetIndex, 32),
done: make(chan struct{}), done: make(chan struct{}),
action: make(chan writerAction, 1), action: make(chan writerAction, 1),
} }
go rtpWriterLoop(writer, conn, track) go rtpWriterLoop(writer, track)
return writer return writer
} }
@ -223,7 +222,7 @@ func sendKeyframe(kf []uint16, track conn.DownTrack, cache *packetcache.Cache) {
} }
// rtpWriterLoop is the main loop of an rtpWriter. // rtpWriterLoop is the main loop of an rtpWriter.
func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { func rtpWriterLoop(writer *rtpWriter, track *rtpUpTrack) {
defer close(writer.done) defer close(writer.done)
buf := make([]byte, packetcache.BufSize) buf := make([]byte, packetcache.BufSize)
@ -314,7 +313,7 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
// nackWriter is called when bufferedNACKs becomes non-empty. It decides // nackWriter is called when bufferedNACKs becomes non-empty. It decides
// which nacks to ship out. // which nacks to ship out.
func nackWriter(conn *rtpUpConnection, track *rtpUpTrack) { func nackWriter(track *rtpUpTrack) {
// a client might send us a NACK for a packet that has already // a client might send us a NACK for a packet that has already
// been nacked by the reader loop. Give recovery a chance. // been nacked by the reader loop. Give recovery a chance.
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@ -366,6 +365,6 @@ func nackWriter(conn *rtpUpConnection, track *rtpUpTrack) {
}) })
if len(nacks) > 0 { if len(nacks) > 0 {
conn.sendNACKs(track, nacks) track.sendNACKs(nacks)
} }
} }

View file

@ -399,6 +399,7 @@ func addDownTrackUnlocked(conn *rtpDownConnection, remoteTrack *rtpUpTrack, remo
track: local, track: local,
sender: sender, sender: sender,
ssrc: parms.Encodings[0].SSRC, ssrc: parms.Encodings[0].SSRC,
conn: conn,
remote: remoteTrack, remote: remoteTrack,
maxBitrate: new(bitrate), maxBitrate: new(bitrate),
maxREMBBitrate: new(bitrate), maxREMBBitrate: new(bitrate),
@ -409,7 +410,7 @@ func addDownTrackUnlocked(conn *rtpDownConnection, remoteTrack *rtpUpTrack, remo
conn.tracks = append(conn.tracks, track) conn.tracks = append(conn.tracks, track)
go rtcpDownListener(conn, track, sender) go rtcpDownListener(track, sender)
return nil return nil
} }