From e8fbfcb9ba532f733405b1c5846f4443e5464c4a Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Tue, 25 Jan 2022 19:45:49 +0100 Subject: [PATCH] Avoid overflow in bitrate computation. --- group/group.go | 1 + rtpconn/rtpconn.go | 26 ++++++++++++++++++-------- rtpconn/rtpconn_test.go | 17 +++++++++++++++++ 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/group/group.go b/group/group.go index 4144832..e54e1c5 100644 --- a/group/group.go +++ b/group/group.go @@ -64,6 +64,7 @@ type ChatHistoryEntry struct { const ( LowBitrate = 100 * 1024 MinBitrate = LowBitrate * 2 + MaxBitrate = 1024 * 1024 * 1024 ) type Group struct { diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go index b705ea0..ec29355 100644 --- a/rtpconn/rtpconn.go +++ b/rtpconn/rtpconn.go @@ -881,6 +881,15 @@ func rtcpUpListener(track *rtpUpTrack) { } } +// saturating addition +func sadd(x, y uint64) uint64 { + s, c := bits.Add64(x, y, 0) + if c != 0 { + return ^uint64(0) + } + return s +} + func maxUpBitrate(t *rtpUpTrack) uint64 { minrate := ^uint64(0) maxrate := uint64(group.MinBitrate) @@ -908,15 +917,13 @@ func maxUpBitrate(t *rtpUpTrack) uint64 { // assume that lower spatial layers take up 1/5 of // the throughput if maxsid > 0 { - maxrate = maxrate * 5 / 4 + maxrate = sadd(maxrate, maxrate / 4) } // assume that each layer takes two times less // throughput than the higher one. Then we've // got enough slack for a factor of 2^(layers-1). for i := 0; i < maxtid; i++ { - if minrate < ^uint64(0)/2 { - minrate *= 2 - } + minrate = sadd(minrate, minrate) } if minrate < maxrate { return minrate @@ -990,15 +997,18 @@ func sendUpRTCP(up *rtpUpConnection) error { } ssrcs = append(ssrcs, uint32(t.track.SSRC())) if t.Kind() == webrtc.RTPCodecTypeAudio { - rate += 100 * 1024 + rate = sadd(rate, 100 * 1024) } else if t.Label() == "l" { - rate += group.LowBitrate + rate = sadd(rate, group.LowBitrate) } else { - rate += maxUpBitrate(t) + rate = sadd(rate, maxUpBitrate(t)) } } - if rate < ^uint64(0) && len(ssrcs) > 0 { + if rate > group.MaxBitrate { + rate = group.MaxBitrate + } + if len(ssrcs) > 0 { packets = append(packets, &rtcp.ReceiverEstimatedMaximumBitrate{ Bitrate: float32(rate), diff --git a/rtpconn/rtpconn_test.go b/rtpconn/rtpconn_test.go index 36bd644..a1bed66 100644 --- a/rtpconn/rtpconn_test.go +++ b/rtpconn/rtpconn_test.go @@ -37,3 +37,20 @@ func TestDownTrackAtomics(t *testing.T) { t.Errorf("Expected %v, got %v", info, info2) } } + +func TestSadd(t *testing.T) { + ts := []struct{ x, y, z uint64 }{ + {0, 0, 0}, + {1, 2, 3}, + {^uint64(0) - 10, 5, ^uint64(0) - 5}, + {^uint64(0) - 10, 15, ^uint64(0)}, + } + for _, tt := range ts { + z := sadd(tt.x, tt.y) + if z != tt.z { + t.Errorf("%v + %v: expected %v, got %v", + tt.x, tt.y, tt.z, z, + ) + } + } +}