1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-12-22 07:15:47 +01:00

Avoid overflow in bitrate computation.

This commit is contained in:
Juliusz Chroboczek 2022-01-25 19:45:49 +01:00
parent 0d8fdc5c20
commit e8fbfcb9ba
3 changed files with 36 additions and 8 deletions

View file

@ -64,6 +64,7 @@ type ChatHistoryEntry struct {
const (
LowBitrate = 100 * 1024
MinBitrate = LowBitrate * 2
MaxBitrate = 1024 * 1024 * 1024
)
type Group struct {

View file

@ -881,6 +881,15 @@ func rtcpUpListener(track *rtpUpTrack) {
}
}
// saturating addition
func sadd(x, y uint64) uint64 {
s, c := bits.Add64(x, y, 0)
if c != 0 {
return ^uint64(0)
}
return s
}
func maxUpBitrate(t *rtpUpTrack) uint64 {
minrate := ^uint64(0)
maxrate := uint64(group.MinBitrate)
@ -908,15 +917,13 @@ func maxUpBitrate(t *rtpUpTrack) uint64 {
// assume that lower spatial layers take up 1/5 of
// the throughput
if maxsid > 0 {
maxrate = maxrate * 5 / 4
maxrate = sadd(maxrate, maxrate / 4)
}
// assume that each layer takes two times less
// throughput than the higher one. Then we've
// got enough slack for a factor of 2^(layers-1).
for i := 0; i < maxtid; i++ {
if minrate < ^uint64(0)/2 {
minrate *= 2
}
minrate = sadd(minrate, minrate)
}
if minrate < maxrate {
return minrate
@ -990,15 +997,18 @@ func sendUpRTCP(up *rtpUpConnection) error {
}
ssrcs = append(ssrcs, uint32(t.track.SSRC()))
if t.Kind() == webrtc.RTPCodecTypeAudio {
rate += 100 * 1024
rate = sadd(rate, 100 * 1024)
} else if t.Label() == "l" {
rate += group.LowBitrate
rate = sadd(rate, group.LowBitrate)
} else {
rate += maxUpBitrate(t)
rate = sadd(rate, maxUpBitrate(t))
}
}
if rate < ^uint64(0) && len(ssrcs) > 0 {
if rate > group.MaxBitrate {
rate = group.MaxBitrate
}
if len(ssrcs) > 0 {
packets = append(packets,
&rtcp.ReceiverEstimatedMaximumBitrate{
Bitrate: float32(rate),

View file

@ -37,3 +37,20 @@ func TestDownTrackAtomics(t *testing.T) {
t.Errorf("Expected %v, got %v", info, info2)
}
}
func TestSadd(t *testing.T) {
ts := []struct{ x, y, z uint64 }{
{0, 0, 0},
{1, 2, 3},
{^uint64(0) - 10, 5, ^uint64(0) - 5},
{^uint64(0) - 10, 15, ^uint64(0)},
}
for _, tt := range ts {
z := sadd(tt.x, tt.y)
if z != tt.z {
t.Errorf("%v + %v: expected %v, got %v",
tt.x, tt.y, tt.z, z,
)
}
}
}