From 89780b866b93e6e6a57d783120c0c2181a39d918 Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Thu, 29 Jul 2021 21:30:12 +0200 Subject: [PATCH] Move packet parsing code into its own package. --- rtpconn/codec.go => codecs/codecs.go | 64 ++++++++++++++-------------- codecs/codecs_test.go | 46 ++++++++++++++++++++ rtpconn/codec_test.go | 46 -------------------- rtpconn/rtpconn.go | 41 +++++++++--------- rtpconn/rtpreader.go | 3 +- 5 files changed, 101 insertions(+), 99 deletions(-) rename rtpconn/codec.go => codecs/codecs.go (85%) create mode 100644 codecs/codecs_test.go delete mode 100644 rtpconn/codec_test.go diff --git a/rtpconn/codec.go b/codecs/codecs.go similarity index 85% rename from rtpconn/codec.go rename to codecs/codecs.go index 93dd465..6ffbec3 100644 --- a/rtpconn/codec.go +++ b/codecs/codecs.go @@ -1,4 +1,4 @@ -package rtpconn +package codecs import ( "errors" @@ -8,11 +8,14 @@ import ( "github.com/pion/rtp/codecs" ) -// isKeyframe determines if packet is the start of a keyframe. +var errTruncated = errors.New("truncated packet") +var errUnsupportedCodec = errors.New("unsupported codec") + +// Keyframe determines if packet is the start of a keyframe. // It returns (true, true) if that is the case, (false, true) if that is // definitely not the case, and (false, false) if the information cannot // be determined. -func isKeyframe(codec string, packet *rtp.Packet) (bool, bool) { +func Keyframe(codec string, packet *rtp.Packet) (bool, bool) { if strings.EqualFold(codec, "video/vp8") { var vp8 codecs.VP8Packet _, err := vp8.Unmarshal(packet.Payload) @@ -179,29 +182,26 @@ func isKeyframe(codec string, packet *rtp.Packet) (bool, bool) { return false, false } -var errTruncated = errors.New("truncated packet") -var errUnsupportedCodec = errors.New("unsupported codec") - -type packetFlags struct { - seqno uint16 - start bool - pid uint16 // only if it needs rewriting - tid uint8 - sid uint8 - tidupsync bool - sidsync bool - sidnonreference bool - discardable bool +type Flags struct { + Seqno uint16 + Start bool + Pid uint16 // only if it needs rewriting + Tid uint8 + Sid uint8 + TidUpSync bool + SidSync bool + SidNonReference bool + Discardable bool } -func getPacketFlags(codec string, buf []byte) (packetFlags, error) { +func PacketFlags(codec string, buf []byte) (Flags, error) { if len(buf) < 12 { - return packetFlags{}, errTruncated + return Flags{}, errTruncated } - var flags packetFlags + var flags Flags - flags.seqno = (uint16(buf[2]) << 8) | uint16(buf[3]) + flags.Seqno = (uint16(buf[2]) << 8) | uint16(buf[3]) if strings.EqualFold(codec, "video/vp8") { var packet rtp.Packet @@ -215,11 +215,11 @@ func getPacketFlags(codec string, buf []byte) (packetFlags, error) { return flags, err } - flags.start = vp8.S == 1 && vp8.PID == 0 - flags.pid = vp8.PictureID - flags.tid = vp8.TID - flags.tidupsync = vp8.Y == 1 - flags.discardable = vp8.N == 1 + flags.Start = vp8.S == 1 && vp8.PID == 0 + flags.Pid = vp8.PictureID + flags.Tid = vp8.TID + flags.TidUpSync = vp8.Y == 1 + flags.Discardable = vp8.N == 1 return flags, nil } else if strings.EqualFold(codec, "video/vp9") { var packet rtp.Packet @@ -232,19 +232,19 @@ func getPacketFlags(codec string, buf []byte) (packetFlags, error) { if err != nil { return flags, err } - flags.start = vp9.B - flags.tid = vp9.TID - flags.sid = vp9.SID - flags.tidupsync = vp9.U - flags.sidsync = vp9.P + flags.Start = vp9.B + flags.Tid = vp9.TID + flags.Sid = vp9.SID + flags.TidUpSync = vp9.U + flags.SidSync = vp9.P // not yet in pion/rtp - flags.sidnonreference = (packet.Payload[0] & 0x01) != 0 + flags.SidNonReference = (packet.Payload[0] & 0x01) != 0 return flags, nil } return flags, nil } -func rewritePacket(codec string, data []byte, seqno uint16, delta uint16) error { +func RewritePacket(codec string, data []byte, seqno uint16, delta uint16) error { if len(data) < 12 { return errTruncated } diff --git a/codecs/codecs_test.go b/codecs/codecs_test.go new file mode 100644 index 0000000..96faec9 --- /dev/null +++ b/codecs/codecs_test.go @@ -0,0 +1,46 @@ +package codecs + +import ( + "testing" +) + +var vp8 = []byte{ + 0x80, 0, 0, 42, + 0, 0, 0, 0, + 0, 0, 0, 0, + + 0x90, 0x80, 0x80, 57, + + 0, 0, 0, 0, +} + +func TestPacketFlags(t *testing.T) { + buf := append([]byte{}, vp8...) + flags, err := PacketFlags("video/vp8", buf) + if flags.Seqno != 42 || !flags.Start || flags.Pid != 57 || + flags.Sid != 0 || flags.Tid != 0 || + flags.TidUpSync || flags.Discardable || err != nil { + t.Errorf("Got %v, %v, %v, %v, %v, %v (%v)", + flags.Seqno, flags.Start, flags.Pid, flags.Sid, + flags.TidUpSync, flags.Discardable, err, + ) + } +} + +func TestRewrite(t *testing.T) { + for i := uint16(0); i < 0x7fff; i++ { + buf := append([]byte{}, vp8...) + err := RewritePacket("video/vp8", buf, i, i) + if err != nil { + t.Errorf("rewrite: %v", err) + continue + } + flags, err := PacketFlags("video/vp8", buf) + if err != nil || flags.Seqno != i || + flags.Pid != (57+i)&0x7FFF { + t.Errorf("Expected %v %v, got %v %v (%v)", + i, (57+i)&0x7FFF, + flags.Seqno, flags.Pid, err) + } + } +} diff --git a/rtpconn/codec_test.go b/rtpconn/codec_test.go deleted file mode 100644 index 358f132..0000000 --- a/rtpconn/codec_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package rtpconn - -import ( - "testing" -) - -var vp8 = []byte{ - 0x80, 0, 0, 42, - 0, 0, 0, 0, - 0, 0, 0, 0, - - 0x90, 0x80, 0x80, 57, - - 0, 0, 0, 0, -} - -func TestPacketFlags(t *testing.T) { - buf := append([]byte{}, vp8...) - flags, err := getPacketFlags("video/vp8", buf) - if flags.seqno != 42 || !flags.start || flags.pid != 57 || - flags.sid != 0 || flags.tid != 0 || - flags.tidupsync || flags.discardable || err != nil { - t.Errorf("Got %v, %v, %v, %v, %v, %v (%v)", - flags.seqno, flags.start, flags.pid, flags.sid, - flags.tidupsync, flags.discardable, err, - ) - } -} - -func TestRewrite(t *testing.T) { - for i := uint16(0); i < 0x7fff; i++ { - buf := append([]byte{}, vp8...) - err := rewritePacket("video/vp8", buf, i, i) - if err != nil { - t.Errorf("rewrite: %v", err) - continue - } - flags, err := getPacketFlags("video/vp8", buf) - if err != nil || flags.seqno != i || - flags.pid != (57 + i) & 0x7FFF { - t.Errorf("Expected %v %v, got %v %v (%v)", - i, (57 + i) & 0x7FFF, - flags.seqno, flags.pid, err) - } - } -} diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go index 87c3394..d497493 100644 --- a/rtpconn/rtpconn.go +++ b/rtpconn/rtpconn.go @@ -14,6 +14,7 @@ import ( "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3" + "github.com/jech/galene/codecs" "github.com/jech/galene/conn" "github.com/jech/galene/estimator" "github.com/jech/galene/group" @@ -213,59 +214,59 @@ var packetBufPool = sync.Pool{ func (down *rtpDownTrack) Write(buf []byte) (int, error) { codec := down.remote.Codec().MimeType - flags, err := getPacketFlags(codec, buf) + flags, err := codecs.PacketFlags(codec, buf) if err != nil { return 0, err } layer := down.getLayerInfo() - if flags.tid > layer.maxTid || flags.sid > layer.maxSid { - if flags.tid > layer.maxTid { + if flags.Tid > layer.maxTid || flags.Sid > layer.maxSid { + if flags.Tid > layer.maxTid { if layer.tid == layer.maxTid { - layer.wantedTid = flags.tid - layer.tid = flags.tid + layer.wantedTid = flags.Tid + layer.tid = flags.Tid } - layer.maxTid = flags.tid + layer.maxTid = flags.Tid } - if flags.sid > layer.maxSid { + if flags.Sid > layer.maxSid { if layer.sid == layer.maxSid { - layer.wantedSid = flags.sid - layer.sid = flags.sid + layer.wantedSid = flags.Sid + layer.sid = flags.Sid } - layer.maxSid = flags.sid + layer.maxSid = flags.Sid } down.setLayerInfo(layer) down.adjustLayer() } - if flags.start && (layer.tid != layer.wantedTid) { - if layer.wantedTid < layer.tid || flags.tidupsync { + if flags.Start && (layer.tid != layer.wantedTid) { + if layer.wantedTid < layer.tid || flags.TidUpSync { layer.tid = layer.wantedTid down.setLayerInfo(layer) } } - if flags.start && (layer.sid != layer.wantedSid) { - if flags.sidsync { + if flags.Start && (layer.sid != layer.wantedSid) { + if flags.SidSync { layer.sid = layer.wantedTid down.setLayerInfo(layer) } } - if flags.tid > layer.tid || flags.sid > layer.sid || - (flags.sid < layer.sid && flags.sidnonreference) { - ok := down.packetmap.Drop(flags.seqno, flags.pid) + if flags.Tid > layer.tid || flags.Sid > layer.sid || + (flags.Sid < layer.sid && flags.SidNonReference) { + ok := down.packetmap.Drop(flags.Seqno, flags.Pid) if ok { return 0, nil } } - ok, newseqno, piddelta := down.packetmap.Map(flags.seqno, flags.pid) + ok, newseqno, piddelta := down.packetmap.Map(flags.Seqno, flags.Pid) if !ok { return 0, nil } - if newseqno == flags.seqno && piddelta == 0 { + if newseqno == flags.Seqno && piddelta == 0 { return down.write(buf) } @@ -274,7 +275,7 @@ func (down *rtpDownTrack) Write(buf []byte) (int, error) { buf2 := ibuf2.([]byte) n := copy(buf2, buf) - err = rewritePacket(codec, buf2[:n], newseqno, piddelta) + err = codecs.RewritePacket(codec, buf2[:n], newseqno, piddelta) if err != nil { return 0, err } diff --git a/rtpconn/rtpreader.go b/rtpconn/rtpreader.go index d86b548..6a501df 100644 --- a/rtpconn/rtpreader.go +++ b/rtpconn/rtpreader.go @@ -8,6 +8,7 @@ import ( "github.com/pion/rtp" "github.com/pion/webrtc/v3" + "github.com/jech/galene/codecs" "github.com/jech/galene/packetcache" "github.com/jech/galene/rtptime" ) @@ -74,7 +75,7 @@ func readLoop(track *rtpUpTrack) { track.jitter.Accumulate(packet.Timestamp) - kf, kfKnown := isKeyframe(codec.MimeType, &packet) + kf, kfKnown := codecs.Keyframe(codec.MimeType, &packet) if kf || !kfKnown { kfNeeded = false }