1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-14 12:45:58 +01:00
galene/rtpconn/rtpwriter.go

435 lines
8.4 KiB
Go
Raw Normal View History

package rtpconn
import (
"errors"
"log"
"sort"
2020-12-04 01:15:52 +01:00
"strings"
"time"
"github.com/pion/rtp"
2020-12-19 17:37:48 +01:00
"github.com/jech/galene/conn"
"github.com/jech/galene/packetcache"
"github.com/jech/galene/rtptime"
)
// packetIndex names a packet that an rtpWriter should fetch from the
// packet cache and send: seqno is the packet's sequence number and index
// is the slot it occupies in the cache.
type packetIndex struct {
	seqno, index uint16
}
// An rtpWriterPool is a set of rtpWriters that together forward the
// cached packets of a single up track to that track's local down tracks.
type rtpWriterPool struct {
	// conn is handed to every writer the pool creates; the writer loop
	// uses it to request keyframes (sendFIR/sendPLI).
	conn *rtpUpConnection
	// track is the up track whose packet cache the writers read from.
	track *rtpUpTrack
	// writers are the writers currently owned by the pool.
	writers []*rtpWriter
	// count is the number of down tracks registered through add
	// (incremented on a successful add, decremented on a removal).
	count int
}
// sqrt computes the integer square root of n, i.e. the largest s such
// that s*s <= n.  Values of n below 2 (including negative values) are
// returned unchanged.
//
// The recursion relies on the identity
//
//	isqrt(n) ∈ {2·isqrt(n/4), 2·isqrt(n/4)+1}
//
// Since the root is halved at each step, the argument must be divided by
// four.  Dividing by two, as an earlier version did, returns values close
// to n itself (e.g. sqrt(4) == 4).
func sqrt(n int) int {
	if n < 2 {
		return n
	}
	s := sqrt(n/4) * 2
	l := s + 1
	if l*l > n {
		return s
	}
	return l
}
// add adds or removes a track from a writer pool
2020-09-13 11:04:16 +02:00
func (wp *rtpWriterPool) add(track conn.DownTrack, add bool) error {
n := 4
if wp.count > 16 {
n = sqrt(wp.count)
}
i := 0
for i < len(wp.writers) {
w := wp.writers[i]
err := w.add(track, add, n)
if err == nil {
if add {
wp.count++
} else {
if wp.count > 0 {
wp.count--
} else {
log.Printf("Negative writer count!")
}
}
return nil
} else if err == ErrWriterDead {
wp.del(wp.writers[i])
} else {
i++
}
}
if add {
writer := newRtpWriter(wp.conn, wp.track)
wp.writers = append(wp.writers, writer)
err := writer.add(track, true, n)
if err == nil {
wp.count++
}
return err
} else {
return errors.New("deleting unknown track")
}
}
// del removes w from the pool and closes its packet channel, which makes
// the writer's loop return.  It reports whether w belonged to the pool.
func (wp *rtpWriterPool) del(w *rtpWriter) bool {
	pos := -1
	for i := range wp.writers {
		if wp.writers[i] == w {
			pos = i
			break
		}
	}
	if pos < 0 {
		return false
	}
	close(w.ch)
	wp.writers = append(wp.writers[:pos], wp.writers[pos+1:]...)
	return true
}
// close shuts down every writer in the pool by closing its packet
// channel, then empties the pool and resets its track count.
func (wp *rtpWriterPool) close() {
	for i := range wp.writers {
		close(wp.writers[i].ch)
	}
	wp.writers, wp.count = nil, 0
}
// write writes a packet stored in the packet cache to all local tracks
func (wp *rtpWriterPool) write(seqno uint16, index uint16, delay uint32, isvideo bool, marker bool) {
pi := packetIndex{seqno, index}
var dead []*rtpWriter
for _, w := range wp.writers {
if w.drop > 0 {
// currently dropping
if marker {
// last packet in frame
w.drop = 0
} else {
w.drop--
}
continue
}
select {
case w.ch <- pi:
// all is well
case <-w.done:
// the writer is dead.
dead = append(dead, w)
default:
// the writer is congested
if isvideo {
// drop until the end of the frame
if !marker {
w.drop = 7
}
continue
}
// audio, try again with a delay
2020-10-03 12:54:17 +02:00
d := delay / uint32(2*len(wp.writers))
timer := time.NewTimer(rtptime.ToDuration(
uint64(d), rtptime.JiffiesPerSec,
))
select {
case w.ch <- pi:
timer.Stop()
case <-w.done:
dead = append(dead, w)
case <-timer.C:
}
}
}
if dead != nil {
for _, d := range dead {
wp.del(d)
}
dead = nil
}
}
// Errors reported when adding or removing tracks on an rtpWriter.
var (
	// ErrWriterDead is returned once the writer's loop has exited.
	ErrWriterDead = errors.New("writer is dead")
	// ErrWriterBusy is returned when the writer already holds the
	// maximum number of tracks it was allowed.
	ErrWriterBusy = errors.New("writer is busy")
	// ErrUnknownTrack is returned when asked to remove a track the
	// writer does not hold.
	ErrUnknownTrack = errors.New("unknown track")
)
type writerAction struct {
add bool
2020-09-13 11:04:16 +02:00
track conn.DownTrack
maxTracks int
ch chan error
}
// an rtpWriter is a thread writing to a set of tracks: newRtpWriter runs
// rtpWriterLoop in its own goroutine, and the fields below are the
// channels used to talk to it.
type rtpWriter struct {
	// ch carries requests to send cached packets; closing it makes the
	// loop return.
	ch chan packetIndex
	// done is closed when the loop returns.
	done chan struct{}
	// action carries requests to add or remove down tracks.
	action chan writerAction
	// drop is not touched by the writer loop; it is used by the caller
	// (rtpWriterPool.write) to count further packets to discard while
	// this writer is congested.
	drop int
}
// newRtpWriter creates a writer for the given up track and starts its
// loop in a new goroutine.  Up to 32 packet requests and one action
// request can be queued without blocking.
func newRtpWriter(up *rtpUpConnection, track *rtpUpTrack) *rtpWriter {
	w := new(rtpWriter)
	w.ch = make(chan packetIndex, 32)
	w.done = make(chan struct{})
	w.action = make(chan writerAction, 1)
	go rtpWriterLoop(w, up, track)
	return w
}
// add adds or removes a track from a writer.
2020-09-13 11:04:16 +02:00
func (writer *rtpWriter) add(track conn.DownTrack, add bool, max int) error {
ch := make(chan error, 1)
select {
case writer.action <- writerAction{add, track, max, ch}:
select {
case err := <-ch:
return err
case <-writer.done:
return ErrWriterDead
}
case <-writer.done:
return ErrWriterDead
}
}
func sendKeyframe(kf []uint16, track conn.DownTrack, cache *packetcache.Cache) {
2020-10-03 12:54:17 +02:00
buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet
for _, seqno := range kf {
bytes := cache.Get(seqno, buf)
if bytes == 0 {
2020-10-03 12:54:17 +02:00
return
}
err := packet.Unmarshal(buf[:bytes])
if err != nil {
return
}
err = track.WriteRTP(&packet)
if err != nil && err != conn.ErrKeyframeNeeded {
return
}
track.Accumulate(uint32(bytes))
}
}
// Keyframe-request state of an rtpWriter loop, in increasing order of
// urgency.  The loop compares these values with > and >=, so their
// relative order matters.
const (
	kfUnneeded     = 0 // no keyframe is needed
	kfNeededPLI    = 1 // keyframe needed, ask with a PLI
	kfNeededFIR    = 2 // keyframe needed, ask with a FIR (second sendFIR argument false)
	kfNeededNewFIR = 3 // keyframe needed, ask with a FIR (second sendFIR argument true)
)
// rtpWriterLoop is the main loop of an rtpWriter.
2020-09-13 11:04:16 +02:00
func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
defer close(writer.done)
2020-12-04 01:15:52 +01:00
codec := track.track.Codec()
buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet
2020-09-13 11:04:16 +02:00
local := make([]conn.DownTrack, 0)
kfNeeded := kfUnneeded
for {
select {
case action := <-writer.action:
if action.add {
if len(local) >= action.maxTracks {
action.ch <- ErrWriterBusy
close(action.ch)
continue
}
local = append(local, action.track)
action.ch <- nil
close(action.ch)
track.mu.Lock()
ntp := track.srNTPTime
rtp := track.srRTPTime
cname := track.cname
track.mu.Unlock()
if ntp != 0 {
2020-09-13 11:04:16 +02:00
action.track.SetTimeOffset(ntp, rtp)
}
if cname != "" {
2020-09-13 11:04:16 +02:00
action.track.SetCname(cname)
}
found, _, lts := track.cache.Last()
kts, _, kf := track.cache.Keyframe()
2020-12-04 01:15:52 +01:00
if strings.ToLower(codec.MimeType) == "video/vp8" &&
found && len(kf) > 0 {
if ((lts-kts)&0x80000000) != 0 ||
lts-kts < 2*90000 {
// we got a recent keyframe
go sendKeyframe(
kf,
action.track,
track.cache,
)
} else {
// Request a new keyframe
kfNeeded = kfNeededNewFIR
}
} else {
// no keyframe yet, one should
// arrive soon. Do nothing.
}
} else {
found := false
for i, t := range local {
if t == action.track {
local = append(local[:i],
local[i+1:]...)
found = true
break
}
}
if !found {
action.ch <- ErrUnknownTrack
} else {
action.ch <- nil
}
close(action.ch)
if len(local) == 0 {
return
}
}
case pi, ok := <-writer.ch:
if !ok {
return
}
bytes := track.cache.GetAt(pi.seqno, pi.index, buf)
if bytes == 0 {
continue
}
err := packet.Unmarshal(buf[:bytes])
if err != nil {
continue
}
for _, l := range local {
err := l.WriteRTP(&packet)
if err != nil {
2020-09-13 11:04:16 +02:00
if err == conn.ErrKeyframeNeeded {
kfNeeded = kfNeededPLI
2020-10-03 12:54:17 +02:00
} else {
continue
}
}
l.Accumulate(uint32(bytes))
}
if kfNeeded > kfUnneeded {
2020-12-25 16:39:12 +01:00
kf, kfKnown :=
isKeyframe(codec.MimeType, &packet)
if kf {
kfNeeded = kfUnneeded
}
if kfNeeded >= kfNeededFIR {
err := up.sendFIR(
track,
kfNeeded >= kfNeededNewFIR,
)
if err == ErrUnsupportedFeedback {
kfNeeded = kfNeededPLI
} else {
kfNeeded = kfNeededFIR
}
}
if kfNeeded == kfNeededPLI {
2020-09-13 11:04:16 +02:00
up.sendPLI(track)
}
2020-12-25 16:39:12 +01:00
if !kfKnown {
// we cannot detect keyframes for
// this codec, reset our state
kfNeeded = kfUnneeded
}
}
}
}
}
// nackWriter is called when bufferedNACKs becomes non-empty. It decides
// which nacks to ship out.
func nackWriter(conn *rtpUpConnection, track *rtpUpTrack) {
// a client might send us a NACK for a packet that has already
// been nacked by the reader loop. Give recovery a chance.
2020-10-31 21:31:05 +01:00
time.Sleep(50 * time.Millisecond)
track.mu.Lock()
nacks := track.bufferedNACKs
track.bufferedNACKs = nil
track.mu.Unlock()
2020-10-31 21:31:05 +01:00
if len(nacks) == 0 || !track.hasRtcpFb("nack", "") {
return
}
2020-10-31 21:31:05 +01:00
time.Sleep(50 * time.Millisecond)
// drop any nacks before the last keyframe
var cutoff uint16
found, seqno, _ := track.cache.KeyframeSeqno()
if found {
cutoff = seqno
} else {
last, lastSeqno, _ := track.cache.Last()
if !last {
// NACK on a fresh track? Give up.
return
}
// no keyframe, use an arbitrary cutoff
cutoff = lastSeqno - 256
}
i := 0
for i < len(nacks) {
if ((nacks[i] - cutoff) & 0x8000) != 0 {
// earlier than the cutoff, drop
nacks = append(nacks[:i], nacks[i+1:]...)
continue
}
l := track.cache.Get(nacks[i], nil)
if l > 0 {
// the packet arrived in the meantime
nacks = append(nacks[:i], nacks[i+1:]...)
continue
}
i++
}
sort.Slice(nacks, func(i, j int) bool {
return nacks[i]-cutoff < nacks[j]-cutoff
})
if len(nacks) > 0 {
conn.sendNACKs(track, nacks)
}
}