diff --git a/conn/conn.go b/conn/conn.go index c4de7e6..45ccf9f 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -32,6 +32,7 @@ type UpTrack interface { Codec() *webrtc.RTPCodec // get a recent packet. Returns 0 if the packet is not in cache. GetRTP(seqno uint16, result []byte) uint16 + Nack(conn Up, seqnos []uint16) error } // Type Down represents a connection in the server to client direction. diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go index c100fce..6fc30ab 100644 --- a/rtpconn/rtpconn.go +++ b/rtpconn/rtpconn.go @@ -188,12 +188,13 @@ type rtpUpTrack struct { localCh chan localTrackAction readerDone chan struct{} - mu sync.Mutex - cname string - local []conn.DownTrack - srTime uint64 - srNTPTime uint64 - srRTPTime uint32 + mu sync.Mutex + cname string + srTime uint64 + srNTPTime uint64 + srRTPTime uint32 + local []conn.DownTrack + bufferedNACKs []uint16 } type localTrackAction struct { @@ -538,13 +539,15 @@ func sendNACK(pc *webrtc.PeerConnection, ssrc uint32, first uint16, bitmap uint1 return pc.WriteRTCP([]rtcp.Packet{packet}) } -func sendRecovery(p *rtcp.TransportLayerNack, track *rtpDownTrack) { +func gotNACK(conn *rtpDownConnection, track *rtpDownTrack, p *rtcp.TransportLayerNack) { + var unhandled []uint16 var packet rtp.Packet buf := make([]byte, packetcache.BufSize) for _, nack := range p.Nacks { for _, seqno := range nack.PacketList() { l := track.remote.GetRTP(seqno, buf) if l == 0 { + unhandled = append(unhandled, seqno) continue } err := packet.Unmarshal(buf[:l]) @@ -559,6 +562,38 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *rtpDownTrack) { track.rate.Accumulate(uint32(l)) } } + if len(unhandled) == 0 { + return + } + + track.remote.Nack(conn.remote, unhandled) +} + +func (track *rtpUpTrack) Nack(conn conn.Up, nacks []uint16) error { + track.mu.Lock() + defer track.mu.Unlock() + + doit := len(track.bufferedNACKs) == 0 + +outer: + for _, nack := range nacks { + for _, seqno := range track.bufferedNACKs { + if seqno == nack { + continue outer + } + } + track.bufferedNACKs = append(track.bufferedNACKs, nack) + } + + if doit { + up, ok := conn.(*rtpUpConnection) + if !ok { + log.Printf("Nack: unexpected type %T", conn) + return errors.New("unexpected connection type") + } + go nackWriter(up, track) + } + return nil } func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPReceiver) { @@ -938,7 +973,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT } } case *rtcp.TransportLayerNack: - sendRecovery(p, track) + gotNACK(conn, track, p) } } } diff --git a/rtpconn/rtpwriter.go b/rtpconn/rtpwriter.go index fc30e5c..8e8e956 100644 --- a/rtpconn/rtpwriter.go +++ b/rtpconn/rtpwriter.go @@ -3,6 +3,7 @@ package rtpconn import ( "errors" "log" + "sort" "time" "github.com/pion/rtp" @@ -360,3 +361,55 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { } } } + +// nackWriter is called when bufferedNACKs becomes non-empty. It decides +// which nacks to ship out. +func nackWriter(conn *rtpUpConnection, track *rtpUpTrack) { + // a client might send us a NACK for a packet that has already + // been nacked by the reader loop. Give recovery a chance. + time.Sleep(100 * time.Millisecond) + + track.mu.Lock() + nacks := track.bufferedNACKs + track.bufferedNACKs = nil + track.mu.Unlock() + + // drop any nacks before the last keyframe + var cutoff uint16 + found, seqno, _ := track.cache.KeyframeSeqno() + if found { + cutoff = seqno + } else { + last, lastSeqno, _ := track.cache.Last() + if !last { + // NACK on a fresh track? Give up. + return + } + // no keyframe, use an arbitrary cutoff + cutoff = lastSeqno - 256 + } + + i := 0 + for i < len(nacks) { + if ((nacks[i] - cutoff) & 0x8000) != 0 { + // earlier than the cutoff, drop + nacks = append(nacks[:i], nacks[i+1:]...) + continue + } + l := track.cache.Get(nacks[i], nil) + if l > 0 { + // the packet arrived in the meantime + nacks = append(nacks[:i], nacks[i+1:]...) + continue + } + i++ + } + + sort.Slice(nacks, func(i, j int) bool { + return nacks[i]-cutoff < nacks[j]-cutoff + }) + + for _, nack := range nacks { + conn.sendNACK(track, nack, 0) + } +}