diff --git a/codecs/codecs.go b/codecs/codecs.go index 7bb8e98..e7ac943 100644 --- a/codecs/codecs.go +++ b/codecs/codecs.go @@ -230,12 +230,15 @@ func KeyframeDimensions(codec string, packet *rtp.Packet) (uint32, uint32) { type Flags struct { Seqno uint16 + Marker bool Start bool - Pid uint16 // only if it needs rewriting + End bool + Keyframe bool + Pid uint16 Tid uint8 Sid uint8 TidUpSync bool - SidSync bool + SidUpSync bool SidNonReference bool Discardable bool } @@ -248,6 +251,7 @@ func PacketFlags(codec string, buf []byte) (Flags, error) { var flags Flags flags.Seqno = (uint16(buf[2]) << 8) | uint16(buf[3]) + flags.Marker = (buf[1] & 0x80) != 0 if strings.EqualFold(codec, "video/vp8") { var packet rtp.Packet @@ -261,10 +265,13 @@ func PacketFlags(codec string, buf []byte) (Flags, error) { return flags, err } - flags.Start = vp8.S == 1 && vp8.PID == 0 + flags.Start = vp8.S != 0 && vp8.PID == 0 + flags.End = packet.Marker + flags.Keyframe = vp8.S != 0 && (vp8.Payload[0]&0x1) == 0 flags.Pid = vp8.PictureID flags.Tid = vp8.TID - flags.TidUpSync = vp8.Y == 1 + flags.TidUpSync = flags.Keyframe || vp8.Y == 1 + flags.SidUpSync = flags.Keyframe flags.Discardable = vp8.N == 1 return flags, nil } else if strings.EqualFold(codec, "video/vp9") { @@ -279,22 +286,35 @@ func PacketFlags(codec string, buf []byte) (Flags, error) { return flags, err } flags.Start = vp9.B + flags.End = vp9.E + if (vp9.Payload[0] & 0xc0) == 0x80 { + profile := (vp9.Payload[0] >> 4) & 0x3 + if profile != 3 { + flags.Keyframe = (vp9.Payload[0] & 0xC) == 0 + } else { + flags.Keyframe = (vp9.Payload[0] & 0x6) == 0 + } + } + flags.Pid = vp9.PictureID flags.Tid = vp9.TID flags.Sid = vp9.SID - flags.TidUpSync = vp9.U - flags.SidSync = vp9.P - // not yet in pion/rtp + flags.TidUpSync = flags.Keyframe || vp9.U + flags.SidUpSync = flags.Keyframe || !vp9.P flags.SidNonReference = (packet.Payload[0] & 0x01) != 0 return flags, nil } return flags, nil } -func RewritePacket(codec string, data []byte, seqno uint16, delta uint16) error { +func RewritePacket(codec string, data []byte, setMarker bool, seqno uint16, delta uint16) error { if len(data) < 12 { return errTruncated } + if(setMarker) { + data[1] |= 0x80 + } + data[2] = uint8(seqno >> 8) data[3] = uint8(seqno) if delta == 0 { @@ -335,6 +355,23 @@ func RewritePacket(codec string, data []byte, seqno uint16, delta uint16) error data[offset+2] = (data[offset+2] + uint8(delta)) & 0x7F } return nil + } else if strings.EqualFold(codec, "video/vp9") { + i := (data[offset] & 0x80) != 0 + if !i { + return nil + } + m := (data[offset+1] & 0x80) != 0 + if m { + pid := (uint16(data[offset+1]&0x7F) << 8) | + uint16(data[offset+2]) + pid = (pid + delta) & 0x7FFF + data[offset+1] = 0x80 | byte((pid>>8)&0x7F) + data[offset+2] = byte(pid & 0xFF) + } else { + data[offset+1] = (data[offset+1] + uint8(delta)) & 0x7F + } + return nil } + return errUnsupportedCodec } diff --git a/codecs/codecs_test.go b/codecs/codecs_test.go index 8aef8bd..eb46598 100644 --- a/codecs/codecs_test.go +++ b/codecs/codecs_test.go @@ -138,7 +138,7 @@ var vp8 = []byte{ 0, 0, 0, 0, } -func TestPacketFlags(t *testing.T) { +func TestPacketFlagsVP8(t *testing.T) { buf := append([]byte{}, vp8...) flags, err := PacketFlags("video/vp8", buf) if flags.Seqno != 42 || !flags.Start || flags.Pid != 57 || @@ -151,17 +151,56 @@ func TestPacketFlags(t *testing.T) { } } -func TestRewrite(t *testing.T) { +func TestRewriteVP8(t *testing.T) { for i := uint16(0); i < 0x7fff; i++ { buf := append([]byte{}, vp8...) - err := RewritePacket("video/vp8", buf, i, i) + err := RewritePacket("video/vp8", buf, true, i, i) if err != nil { t.Errorf("rewrite: %v", err) continue } flags, err := PacketFlags("video/vp8", buf) if err != nil || flags.Seqno != i || - flags.Pid != (57+i)&0x7FFF { + flags.Pid != (57+i)&0x7FFF || !flags.Marker { + t.Errorf("Expected %v %v, got %v %v (%v)", + i, (57+i)&0x7FFF, + flags.Seqno, flags.Pid, err) + } + } +} + +var vp9 = []byte{ + 0x80, 0, 0, 42, + 0, 0, 0, 0, + 0, 0, 0, 0, + + 0x88, 0x80, 57, 0, +} + +func TestPacketFlagsVP9(t *testing.T) { + buf := append([]byte{}, vp9...) + flags, err := PacketFlags("video/vp9", buf) + if flags.Seqno != 42 || !flags.Start || flags.Pid != 57 || + flags.Sid != 0 || flags.Tid != 0 || + flags.TidUpSync || flags.Discardable || err != nil { + t.Errorf("Got %v, %v, %v, %v, %v, %v (%v)", + flags.Seqno, flags.Start, flags.Pid, flags.Sid, + flags.TidUpSync, flags.Discardable, err, + ) + } +} + +func TestRewriteVP9(t *testing.T) { + for i := uint16(0); i < 0x7fff; i++ { + buf := append([]byte{}, vp9...) + err := RewritePacket("video/vp9", buf, true, i, i) + if err != nil { + t.Errorf("rewrite: %v", err) + continue + } + flags, err := PacketFlags("video/vp9", buf) + if err != nil || flags.Seqno != i || + flags.Pid != (57+i)&0x7FFF || !flags.Marker { t.Errorf("Expected %v %v, got %v %v (%v)", i, (57+i)&0x7FFF, flags.Seqno, flags.Pid, err) diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go index 138e619..83f0605 100644 --- a/rtpconn/rtpconn.go +++ b/rtpconn/rtpconn.go @@ -248,9 +248,23 @@ func (down *rtpDownTrack) Write(buf []byte) (int, error) { } if flags.Start && (layer.sid != layer.wantedSid) { - if flags.SidSync { - layer.sid = layer.wantedSid - down.setLayerInfo(layer) + if layer.wantedSid < layer.sid { + if flags.Keyframe { + layer.sid = layer.wantedSid + down.setLayerInfo(layer) + } else { + down.remote.RequestKeyframe() + } + } else if layer.wantedSid > layer.sid { + if flags.Keyframe { + layer.sid = layer.wantedSid + down.setLayerInfo(layer) + } else if flags.SidUpSync { + layer.sid = layer.sid + 1 + down.setLayerInfo(layer) + } else { + down.remote.RequestKeyframe() + } } } @@ -267,7 +281,9 @@ func (down *rtpDownTrack) Write(buf []byte) (int, error) { return 0, nil } - if newseqno == flags.Seqno && piddelta == 0 { + setMarker := flags.Sid == layer.sid && flags.End && !flags.Marker + + if !setMarker && newseqno == flags.Seqno && piddelta == 0 { return down.write(buf) } @@ -276,7 +292,7 @@ func (down *rtpDownTrack) Write(buf []byte) (int, error) { buf2 := ibuf2.([]byte) n := copy(buf2, buf) - err = codecs.RewritePacket(codec, buf2[:n], newseqno, piddelta) + err = codecs.RewritePacket(codec, buf2[:n], setMarker, newseqno, piddelta) if err != nil { return 0, err }