diff --git a/conn/conn.go b/conn/conn.go index 33c893d..2129dae 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -9,7 +9,6 @@ import ( ) var ErrConnectionClosed = errors.New("connection is closed") -var ErrKeyframeNeeded = errors.New("keyframe needed") // Type Up represents a connection in the client to server direction. type Up interface { @@ -30,6 +29,7 @@ type UpTrack interface { // get a recent packet. Returns 0 if the packet is not in cache. GetRTP(seqno uint16, result []byte) uint16 Nack(conn Up, seqnos []uint16) error + RequestKeyframe() error } // Type Down represents a connection in the server to client direction. diff --git a/diskwriter/diskwriter.go b/diskwriter/diskwriter.go index 0491e7e..80dc5b4 100644 --- a/diskwriter/diskwriter.go +++ b/diskwriter/diskwriter.go @@ -458,7 +458,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { sample, ts := t.builder.PopWithTimestamp() if sample == nil { if kfNeeded { - return conn.ErrKeyframeNeeded + t.remote.RequestKeyframe() } return nil } @@ -506,7 +506,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { if t.writer == nil { if !keyframe { - return conn.ErrKeyframeNeeded + t.remote.RequestKeyframe() } return nil } diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go index d512793..9828e87 100644 --- a/rtpconn/rtpconn.go +++ b/rtpconn/rtpconn.go @@ -171,9 +171,9 @@ func newDownConn(c group.Client, id string, remote conn.Up) (*rtpDownConnection, }) conn := &rtpDownConnection{ - id: id, - pc: pc, - remote: remote, + id: id, + pc: pc, + remote: remote, } return conn, nil @@ -235,7 +235,7 @@ type rtpUpTrack struct { atomics *upTrackAtomics cname atomic.Value - localCh chan localTrackAction + localCh chan trackAction readerDone chan struct{} mu sync.Mutex @@ -246,14 +246,20 @@ type rtpUpTrack struct { bufferedNACKs []uint16 } -type localTrackAction struct { - add bool - track conn.DownTrack +const ( + trackActionAdd = iota + trackActionDel + trackActionKeyframe +) + +type trackAction struct { + action int + track conn.DownTrack } -func (up *rtpUpTrack) notifyLocal(add bool, track conn.DownTrack) { +func (up *rtpUpTrack) action(action int, track conn.DownTrack) { select { - case up.localCh <- localTrackAction{add, track}: + case up.localCh <- trackAction{action, track}: case <-up.readerDone: } } @@ -271,7 +277,12 @@ func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error { // do this asynchronously, to avoid deadlocks when multiple // clients call this simultaneously. - go up.notifyLocal(true, local) + go up.action(trackActionAdd, local) + return nil +} + +func (up *rtpUpTrack) RequestKeyframe() error { + go up.action(trackActionKeyframe, nil) return nil } @@ -283,7 +294,7 @@ func (up *rtpUpTrack) DelLocal(local conn.DownTrack) bool { up.local = append(up.local[:i], up.local[i+1:]...) // do this asynchronously, to avoid deadlocking when // multiple clients call this simultaneously. - go up.notifyLocal(false, l) + go up.action(trackActionDel, l) return true } } @@ -489,7 +500,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon rate: estimator.New(time.Second), jitter: jitter.New(remote.Codec().ClockRate), atomics: &upTrackAtomics{}, - localCh: make(chan localTrackAction, 2), + localCh: make(chan trackAction, 2), readerDone: make(chan struct{}), } @@ -977,7 +988,6 @@ func (track *rtpDownTrack) updateRate(loss uint8, now uint64) { } func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RTPSender) { - var gotFir bool lastFirSeqno := uint8(0) buf := make([]byte, 1500) @@ -1001,18 +1011,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT for _, p := range ps { switch p := p.(type) { case *rtcp.PictureLossIndication: - remote, ok := conn.remote.(*rtpUpConnection) - if !ok { - continue - } - rt, ok := track.remote.(*rtpUpTrack) - if !ok { - continue - } - err := remote.sendPLI(rt) - if err != nil && err != ErrRateLimited { - log.Printf("sendPLI: %v", err) - } + track.remote.RequestKeyframe() case *rtcp.FullIntraRequest: found := false var seqno uint8 @@ -1028,29 +1027,8 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT continue } - increment := true - if gotFir { - increment = seqno != lastFirSeqno - } - gotFir = true - lastFirSeqno = seqno - - remote, ok := conn.remote.(*rtpUpConnection) - if !ok { - continue - } - rt, ok := track.remote.(*rtpUpTrack) - if !ok { - continue - } - err := remote.sendFIR(rt, increment) - if err == ErrUnsupportedFeedback { - err := remote.sendPLI(rt) - if err != nil && err != ErrRateLimited { - log.Printf("sendPLI: %v", err) - } - } else if err != nil && err != ErrRateLimited { - log.Printf("sendFIR: %v", err) + if seqno != lastFirSeqno { + track.remote.RequestKeyframe() } case *rtcp.ReceiverEstimatedMaximumBitrate: track.maxREMBBitrate.Set(p.Bitrate, jiffies) diff --git a/rtpconn/rtpreader.go b/rtpconn/rtpreader.go index 107fdd4..43ef4e4 100644 --- a/rtpconn/rtpreader.go +++ b/rtpconn/rtpreader.go @@ -21,6 +21,7 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo codec := track.track.Codec() sendNACK := track.hasRtcpFb("nack", "") + var kfNeeded bool buf := make([]byte, packetcache.BufSize) var packet rtp.Packet for { @@ -41,8 +42,10 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { track.jitter.Accumulate(packet.Timestamp) - kf, _ := isKeyframe(codec.MimeType, &packet) - + kf, kfKnown := isKeyframe(codec.MimeType, &packet) + if kf || !kfKnown { + kfNeeded = false + } if packet.Extension { packet.Extension = false packet.Extensions = nil @@ -102,11 +105,29 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { select { case action := <-track.localCh: - err := writers.add(action.track, action.add) - if err != nil { - log.Printf("add/remove track: %v", err) + switch action.action { + case trackActionAdd, trackActionDel: + err := writers.add( + action.track, + action.action == trackActionAdd, + ) + if err != nil { + log.Printf("add/remove track: %v", err) + } + case trackActionKeyframe: + kfNeeded = true + default: + log.Printf("Unknown action %v", action.action) } default: } + + if kfNeeded { + err := conn.sendPLI(track) + if err != nil && err != ErrRateLimited { + log.Printf("sendPLI: %v", err) + kfNeeded = false + } + } } } diff --git a/rtpconn/rtpwriter.go b/rtpconn/rtpwriter.go index 7008cc2..98de4ff 100644 --- a/rtpconn/rtpwriter.go +++ b/rtpconn/rtpwriter.go @@ -4,7 +4,6 @@ import ( "errors" "log" "sort" - "strings" "time" "github.com/pion/rtp" @@ -223,33 +222,22 @@ func sendKeyframe(kf []uint16, track conn.DownTrack, cache *packetcache.Cache) { return } err = track.WriteRTP(&packet) - if err != nil && err != conn.ErrKeyframeNeeded { + if err != nil { return } track.Accumulate(uint32(bytes)) } } -const ( - kfUnneeded = iota - kfNeededPLI - kfNeededFIR - kfNeededNewFIR -) - // rtpWriterLoop is the main loop of an rtpWriter. func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { defer close(writer.done) - codec := track.track.Codec() - buf := make([]byte, packetcache.BufSize) var packet rtp.Packet local := make([]conn.DownTrack, 0) - kfNeeded := kfUnneeded - for { select { case action := <-writer.action: @@ -277,8 +265,7 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { found, _, lts := track.cache.Last() kts, _, kf := track.cache.Keyframe() - if strings.ToLower(codec.MimeType) == "video/vp8" && - found && len(kf) > 0 { + if found && len(kf) > 0 { if ((lts-kts)&0x80000000) != 0 || lts-kts < 2*90000 { // we got a recent keyframe @@ -288,8 +275,7 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { track.cache, ) } else { - // Request a new keyframe - kfNeeded = kfNeededNewFIR + track.RequestKeyframe() } } else { // no keyframe yet, one should @@ -333,44 +319,10 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { for _, l := range local { err := l.WriteRTP(&packet) if err != nil { - if err == conn.ErrKeyframeNeeded { - kfNeeded = kfNeededPLI - } else { - continue - } + continue } l.Accumulate(uint32(bytes)) } - - if kfNeeded > kfUnneeded { - kf, kfKnown := - isKeyframe(codec.MimeType, &packet) - if kf { - kfNeeded = kfUnneeded - } - - if kfNeeded >= kfNeededFIR { - err := up.sendFIR( - track, - kfNeeded >= kfNeededNewFIR, - ) - if err == ErrUnsupportedFeedback { - kfNeeded = kfNeededPLI - } else { - kfNeeded = kfNeededFIR - } - } - - if kfNeeded == kfNeededPLI { - up.sendPLI(track) - } - - if !kfKnown { - // we cannot detect keyframes for - // this codec, reset our state - kfNeeded = kfUnneeded - } - } } } }