mirror of
https://github.com/jech/galene.git
synced 2024-11-25 10:05:58 +01:00
438 lines
8.4 KiB
Go
438 lines
8.4 KiB
Go
package rtpconn
|
|
|
|
import (
|
|
"errors"
|
|
"log"
|
|
"sort"
|
|
"time"
|
|
|
|
"github.com/pion/rtp"
|
|
"github.com/pion/webrtc/v3"
|
|
|
|
"sfu/conn"
|
|
"sfu/packetcache"
|
|
"sfu/rtptime"
|
|
)
|
|
|
|
// packetIndex is a request to send a packet from the cache.
|
|
type packetIndex struct {
|
|
// the packet's seqno
|
|
seqno uint16
|
|
// the index in the cache
|
|
index uint16
|
|
}
|
|
|
|
// An rtpWriterPool is a set of rtpWriters
|
|
type rtpWriterPool struct {
|
|
conn *rtpUpConnection
|
|
track *rtpUpTrack
|
|
writers []*rtpWriter
|
|
count int
|
|
}
|
|
|
|
// sqrt computes the integer square root
|
|
func sqrt(n int) int {
|
|
if n < 2 {
|
|
return n
|
|
}
|
|
|
|
s := sqrt(n/2) * 2
|
|
l := s + 1
|
|
if l*l > n {
|
|
return s
|
|
} else {
|
|
return l
|
|
}
|
|
}
|
|
|
|
// add adds or removes a track from a writer pool
|
|
func (wp *rtpWriterPool) add(track conn.DownTrack, add bool) error {
|
|
n := 4
|
|
if wp.count > 16 {
|
|
n = sqrt(wp.count)
|
|
}
|
|
|
|
i := 0
|
|
for i < len(wp.writers) {
|
|
w := wp.writers[i]
|
|
err := w.add(track, add, n)
|
|
if err == nil {
|
|
if add {
|
|
wp.count++
|
|
} else {
|
|
if wp.count > 0 {
|
|
wp.count--
|
|
} else {
|
|
log.Printf("Negative writer count!")
|
|
}
|
|
}
|
|
return nil
|
|
} else if err == ErrWriterDead {
|
|
wp.del(wp.writers[i])
|
|
} else {
|
|
i++
|
|
}
|
|
}
|
|
|
|
if add {
|
|
writer := newRtpWriter(wp.conn, wp.track)
|
|
wp.writers = append(wp.writers, writer)
|
|
err := writer.add(track, true, n)
|
|
if err == nil {
|
|
wp.count++
|
|
}
|
|
return err
|
|
} else {
|
|
return errors.New("deleting unknown track")
|
|
}
|
|
}
|
|
|
|
// del deletes a writer.
|
|
func (wp *rtpWriterPool) del(w *rtpWriter) bool {
|
|
for i, ww := range wp.writers {
|
|
if ww == w {
|
|
close(w.ch)
|
|
wp.writers = append(wp.writers[:i], wp.writers[i+1:]...)
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// close deletes all writers.
|
|
func (wp *rtpWriterPool) close() {
|
|
for _, w := range wp.writers {
|
|
close(w.ch)
|
|
}
|
|
wp.writers = nil
|
|
wp.count = 0
|
|
}
|
|
|
|
// write writes a packet stored in the packet cache to all local tracks
|
|
func (wp *rtpWriterPool) write(seqno uint16, index uint16, delay uint32, isvideo bool, marker bool) {
|
|
pi := packetIndex{seqno, index}
|
|
|
|
var dead []*rtpWriter
|
|
for _, w := range wp.writers {
|
|
if w.drop > 0 {
|
|
// currently dropping
|
|
if marker {
|
|
// last packet in frame
|
|
w.drop = 0
|
|
} else {
|
|
w.drop--
|
|
}
|
|
continue
|
|
}
|
|
select {
|
|
case w.ch <- pi:
|
|
// all is well
|
|
case <-w.done:
|
|
// the writer is dead.
|
|
dead = append(dead, w)
|
|
default:
|
|
// the writer is congested
|
|
if isvideo {
|
|
// drop until the end of the frame
|
|
if !marker {
|
|
w.drop = 7
|
|
}
|
|
continue
|
|
}
|
|
// audio, try again with a delay
|
|
d := delay / uint32(2*len(wp.writers))
|
|
timer := time.NewTimer(rtptime.ToDuration(
|
|
uint64(d), rtptime.JiffiesPerSec,
|
|
))
|
|
|
|
select {
|
|
case w.ch <- pi:
|
|
timer.Stop()
|
|
case <-w.done:
|
|
dead = append(dead, w)
|
|
case <-timer.C:
|
|
}
|
|
}
|
|
}
|
|
|
|
if dead != nil {
|
|
for _, d := range dead {
|
|
wp.del(d)
|
|
}
|
|
dead = nil
|
|
}
|
|
}
|
|
|
|
var ErrWriterDead = errors.New("writer is dead")
|
|
var ErrWriterBusy = errors.New("writer is busy")
|
|
var ErrUnknownTrack = errors.New("unknown track")
|
|
|
|
type writerAction struct {
|
|
add bool
|
|
track conn.DownTrack
|
|
maxTracks int
|
|
ch chan error
|
|
}
|
|
|
|
// an rtpWriter is a thread writing to a set of tracks.
|
|
type rtpWriter struct {
|
|
ch chan packetIndex
|
|
done chan struct{}
|
|
action chan writerAction
|
|
|
|
// this is not touched by the writer loop, used by the caller
|
|
drop int
|
|
}
|
|
|
|
func newRtpWriter(conn *rtpUpConnection, track *rtpUpTrack) *rtpWriter {
|
|
writer := &rtpWriter{
|
|
ch: make(chan packetIndex, 32),
|
|
done: make(chan struct{}),
|
|
action: make(chan writerAction, 1),
|
|
}
|
|
go rtpWriterLoop(writer, conn, track)
|
|
return writer
|
|
}
|
|
|
|
// add adds or removes a track from a writer.
|
|
func (writer *rtpWriter) add(track conn.DownTrack, add bool, max int) error {
|
|
ch := make(chan error, 1)
|
|
select {
|
|
case writer.action <- writerAction{add, track, max, ch}:
|
|
select {
|
|
case err := <-ch:
|
|
return err
|
|
case <-writer.done:
|
|
return ErrWriterDead
|
|
}
|
|
case <-writer.done:
|
|
return ErrWriterDead
|
|
}
|
|
}
|
|
|
|
func sendKeyframe(kf []uint16, track conn.DownTrack, cache *packetcache.Cache) {
|
|
buf := make([]byte, packetcache.BufSize)
|
|
var packet rtp.Packet
|
|
for _, seqno := range kf {
|
|
bytes := cache.Get(seqno, buf)
|
|
if bytes == 0 {
|
|
return
|
|
}
|
|
err := packet.Unmarshal(buf[:bytes])
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = track.WriteRTP(&packet)
|
|
if err != nil && err != conn.ErrKeyframeNeeded {
|
|
return
|
|
}
|
|
track.Accumulate(uint32(bytes))
|
|
}
|
|
}
|
|
|
|
const (
|
|
kfUnneeded = iota
|
|
kfNeededPLI
|
|
kfNeededFIR
|
|
kfNeededNewFIR
|
|
)
|
|
|
|
// rtpWriterLoop is the main loop of an rtpWriter.
|
|
func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
|
|
defer close(writer.done)
|
|
|
|
codec := track.track.Codec().Name
|
|
|
|
buf := make([]byte, packetcache.BufSize)
|
|
var packet rtp.Packet
|
|
|
|
local := make([]conn.DownTrack, 0)
|
|
|
|
kfNeeded := kfUnneeded
|
|
|
|
for {
|
|
select {
|
|
case action := <-writer.action:
|
|
if action.add {
|
|
if len(local) >= action.maxTracks {
|
|
action.ch <- ErrWriterBusy
|
|
close(action.ch)
|
|
continue
|
|
}
|
|
local = append(local, action.track)
|
|
action.ch <- nil
|
|
close(action.ch)
|
|
|
|
track.mu.Lock()
|
|
ntp := track.srNTPTime
|
|
rtp := track.srRTPTime
|
|
cname := track.cname
|
|
track.mu.Unlock()
|
|
if ntp != 0 {
|
|
action.track.SetTimeOffset(ntp, rtp)
|
|
}
|
|
if cname != "" {
|
|
action.track.SetCname(cname)
|
|
}
|
|
|
|
found, _, lts := track.cache.Last()
|
|
kts, _, kf := track.cache.Keyframe()
|
|
if codec == webrtc.VP8 && found && len(kf) > 0 {
|
|
if ((lts-kts)&0x80000000) != 0 ||
|
|
lts-kts < 2*90000 {
|
|
// we got a recent keyframe
|
|
go sendKeyframe(
|
|
kf,
|
|
action.track,
|
|
track.cache,
|
|
)
|
|
} else {
|
|
// Request a new keyframe
|
|
kfNeeded = kfNeededNewFIR
|
|
}
|
|
} else {
|
|
// no keyframe yet, one should
|
|
// arrive soon. Do nothing.
|
|
}
|
|
} else {
|
|
found := false
|
|
for i, t := range local {
|
|
if t == action.track {
|
|
local = append(local[:i],
|
|
local[i+1:]...)
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
action.ch <- ErrUnknownTrack
|
|
} else {
|
|
action.ch <- nil
|
|
}
|
|
close(action.ch)
|
|
if len(local) == 0 {
|
|
return
|
|
}
|
|
}
|
|
case pi, ok := <-writer.ch:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
bytes := track.cache.GetAt(pi.seqno, pi.index, buf)
|
|
if bytes == 0 {
|
|
continue
|
|
}
|
|
|
|
err := packet.Unmarshal(buf[:bytes])
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
for _, l := range local {
|
|
err := l.WriteRTP(&packet)
|
|
if err != nil {
|
|
if err == conn.ErrKeyframeNeeded {
|
|
kfNeeded = kfNeededPLI
|
|
} else {
|
|
continue
|
|
}
|
|
}
|
|
l.Accumulate(uint32(bytes))
|
|
}
|
|
|
|
if kfNeeded > kfUnneeded {
|
|
kf := false
|
|
kfValid := false
|
|
switch codec {
|
|
case webrtc.VP8:
|
|
kf = isVP8Keyframe(&packet)
|
|
kfValid = true
|
|
}
|
|
if kf {
|
|
kfNeeded = kfUnneeded
|
|
}
|
|
|
|
if kfNeeded >= kfNeededFIR {
|
|
err := up.sendFIR(
|
|
track,
|
|
kfNeeded >= kfNeededNewFIR,
|
|
)
|
|
if err == ErrUnsupportedFeedback {
|
|
kfNeeded = kfNeededPLI
|
|
} else {
|
|
kfNeeded = kfNeededFIR
|
|
}
|
|
}
|
|
|
|
if kfNeeded == kfNeededPLI {
|
|
up.sendPLI(track)
|
|
}
|
|
|
|
if !kfValid {
|
|
// we cannot detect keyframes for
|
|
// this codec, reset our state
|
|
kfNeeded = kfUnneeded
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// nackWriter is called when bufferedNACKs becomes non-empty. It decides
|
|
// which nacks to ship out.
|
|
func nackWriter(conn *rtpUpConnection, track *rtpUpTrack) {
|
|
// a client might send us a NACK for a packet that has already
|
|
// been nacked by the reader loop. Give recovery a chance.
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
track.mu.Lock()
|
|
nacks := track.bufferedNACKs
|
|
track.bufferedNACKs = nil
|
|
track.mu.Unlock()
|
|
|
|
if len(nacks) == 0 || !track.hasRtcpFb("nack", "") {
|
|
return
|
|
}
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// drop any nacks before the last keyframe
|
|
var cutoff uint16
|
|
found, seqno, _ := track.cache.KeyframeSeqno()
|
|
if found {
|
|
cutoff = seqno
|
|
} else {
|
|
last, lastSeqno, _ := track.cache.Last()
|
|
if !last {
|
|
// NACK on a fresh track? Give up.
|
|
return
|
|
}
|
|
// no keyframe, use an arbitrary cutoff
|
|
cutoff = lastSeqno - 256
|
|
}
|
|
|
|
i := 0
|
|
for i < len(nacks) {
|
|
if ((nacks[i] - cutoff) & 0x8000) != 0 {
|
|
// earlier than the cutoff, drop
|
|
nacks = append(nacks[:i], nacks[i+1:]...)
|
|
continue
|
|
}
|
|
l := track.cache.Get(nacks[i], nil)
|
|
if l > 0 {
|
|
// the packet arrived in the meantime
|
|
nacks = append(nacks[:i], nacks[i+1:]...)
|
|
continue
|
|
}
|
|
i++
|
|
}
|
|
|
|
sort.Slice(nacks, func(i, j int) bool {
|
|
return nacks[i]-cutoff < nacks[j]-cutoff
|
|
})
|
|
|
|
if len(nacks) > 0 {
|
|
conn.sendNACKs(track, nacks)
|
|
}
|
|
}
|