From bfeeeb4bcdeb9669ab1ce1c7fc4e9f51bb8c5a94 Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Mon, 27 Apr 2020 21:43:29 +0200 Subject: [PATCH] Handle NACKs arriving on down connections. --- client.go | 29 ++++++++++++++++++- group.go | 4 +++ packetlist/packetlist.go | 51 +++++++++++++++++++++++++++++++++ packetlist/packetlist_test.go | 53 +++++++++++++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 packetlist/packetlist.go create mode 100644 packetlist/packetlist_test.go diff --git a/client.go b/client.go index 8241e65..c5a5331 100644 --- a/client.go +++ b/client.go @@ -16,6 +16,8 @@ import ( "sync/atomic" "time" + "sfu/packetlist" + "github.com/gorilla/websocket" "github.com/pion/rtcp" "github.com/pion/rtp" @@ -269,8 +271,10 @@ func addUpConn(c *client, id string) (*upConnection, error) { c.mu.Unlock() return } + list := packetlist.New(32) track := &upTrack{ track: remote, + list: list, maxBitrate: ^uint64(0), } u.tracks = append(u.tracks, track) @@ -286,7 +290,7 @@ func addUpConn(c *client, id string) (*upConnection, error) { } go func() { - buf := make([]byte, 1500) + buf := make([]byte, packetlist.BufSize) var packet rtp.Packet var local []*downTrack var localTime time.Time @@ -311,6 +315,8 @@ func addUpConn(c *client, id string) (*upConnection, error) { continue } + list.Store(packet.SequenceNumber, buf[:i]) + for _, l := range local { if l.muted() { continue @@ -523,6 +529,8 @@ func rtcpListener(g *group, conn *downConnection, track *downTrack, s *webrtc.RT uint64(ms), ) case *rtcp.ReceiverReport: + case *rtcp.TransportLayerNack: + sendRecovery(p, track) default: log.Printf("RTCP: %T", p) } @@ -592,6 +600,25 @@ func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error { }) } +func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) { + var packet rtp.Packet + for _, nack := range p.Nacks { + for _, seqno := range nack.PacketList() { + raw := track.remote.list.Get(seqno) + if raw != nil { + err := packet.Unmarshal(raw) + if err != nil { + continue + } + err = track.track.WriteRTP(&packet) + if err != nil { + log.Printf("%v", err) + } + } + } + } +} + func countMediaStreams(data string) (int, error) { desc := sdp.NewJSEPSessionDescription(false) err := desc.Unmarshal(data) diff --git a/group.go b/group.go index d40618a..89eb56e 100644 --- a/group.go +++ b/group.go @@ -15,11 +15,14 @@ import ( "sync/atomic" "time" + "sfu/packetlist" + "github.com/pion/webrtc/v2" ) type upTrack struct { track *webrtc.Track + list *packetlist.List maxBitrate uint64 mu sync.Mutex @@ -172,6 +175,7 @@ func addGroup(name string, desc *groupDescription) (*group, error) { webrtc.DefaultPayloadTypeVP8, 90000, []webrtc.RTCPFeedback{ {"goog-remb", ""}, + {"nack", ""}, {"nack", "pli"}, }, "", diff --git a/packetlist/packetlist.go b/packetlist/packetlist.go new file mode 100644 index 0000000..fe21e0a --- /dev/null +++ b/packetlist/packetlist.go @@ -0,0 +1,51 @@ +package packetlist + +import ( + "sync" +) + +const BufSize = 1500 + +type entry struct { + seqno uint16 + length int + buf [BufSize]byte +} + +type List struct { + mu sync.Mutex + tail int + entries []entry +} + +func New(capacity int) *List { + return &List{ + entries: make([]entry, capacity), + } +} + +func (list *List) Store(seqno uint16, buf []byte) { + list.mu.Lock() + defer list.mu.Unlock() + list.entries[list.tail].seqno = seqno + copy(list.entries[list.tail].buf[:], buf) + list.entries[list.tail].length = len(buf) + list.tail = (list.tail + 1) % len(list.entries) + +} + +func (list *List) Get(seqno uint16) []byte { + list.mu.Lock() + defer list.mu.Unlock() + + for i := range list.entries { + if list.entries[i].length == 0 || + list.entries[i].seqno != seqno { + continue + } + buf := make([]byte, list.entries[i].length) + copy(buf, list.entries[i].buf[:]) + return buf + } + return nil +} diff --git a/packetlist/packetlist_test.go b/packetlist/packetlist_test.go new file mode 100644 index 0000000..aaf4cae --- /dev/null +++ b/packetlist/packetlist_test.go @@ -0,0 +1,53 @@ +package packetlist + +import ( + "bytes" + "math/rand" + "testing" +) + +func randomBuf() []byte { + length := rand.Int31n(BufSize-1) + 1 + buf := make([]byte, length) + rand.Read(buf) + return buf +} + +func TestList(t *testing.T) { + buf1 := randomBuf() + buf2 := randomBuf() + list := New(16) + list.Store(13, buf1) + list.Store(17, buf2) + + if bytes.Compare(list.Get(13), buf1) != 0 { + t.Errorf("Couldn't get 13") + } + if bytes.Compare(list.Get(17), buf2) != 0 { + t.Errorf("Couldn't get 17") + } + if list.Get(42) != nil { + t.Errorf("Creation ex nihilo") + } +} + +func TestOverflow(t *testing.T) { + list := New(16) + + for i := 0; i < 32; i++ { + list.Store(uint16(i), []byte{uint8(i)}) + } + + for i := 0; i < 32; i++ { + buf := list.Get(uint16(i)) + if i < 16 { + if buf != nil { + t.Errorf("Creation ex nihilo: %v", i) + } + } else { + if len(buf) != 1 || buf[0] != uint8(i) { + t.Errorf("Expected [%v], got %v", i, buf) + } + } + } +}