1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-09 18:25:58 +01:00

Avoid overflow in bitrate computation.

This commit is contained in:
Juliusz Chroboczek 2022-01-25 19:45:49 +01:00
parent 0d8fdc5c20
commit e8fbfcb9ba
3 changed files with 36 additions and 8 deletions

View file

@ -64,6 +64,7 @@ type ChatHistoryEntry struct {
const ( const (
LowBitrate = 100 * 1024 LowBitrate = 100 * 1024
MinBitrate = LowBitrate * 2 MinBitrate = LowBitrate * 2
MaxBitrate = 1024 * 1024 * 1024
) )
type Group struct { type Group struct {

View file

@ -881,6 +881,15 @@ func rtcpUpListener(track *rtpUpTrack) {
} }
} }
// saturating addition
func sadd(x, y uint64) uint64 {
s, c := bits.Add64(x, y, 0)
if c != 0 {
return ^uint64(0)
}
return s
}
func maxUpBitrate(t *rtpUpTrack) uint64 { func maxUpBitrate(t *rtpUpTrack) uint64 {
minrate := ^uint64(0) minrate := ^uint64(0)
maxrate := uint64(group.MinBitrate) maxrate := uint64(group.MinBitrate)
@ -908,15 +917,13 @@ func maxUpBitrate(t *rtpUpTrack) uint64 {
// assume that lower spatial layers take up 1/5 of // assume that lower spatial layers take up 1/5 of
// the throughput // the throughput
if maxsid > 0 { if maxsid > 0 {
maxrate = maxrate * 5 / 4 maxrate = sadd(maxrate, maxrate / 4)
} }
// assume that each layer takes two times less // assume that each layer takes two times less
// throughput than the higher one. Then we've // throughput than the higher one. Then we've
// got enough slack for a factor of 2^(layers-1). // got enough slack for a factor of 2^(layers-1).
for i := 0; i < maxtid; i++ { for i := 0; i < maxtid; i++ {
if minrate < ^uint64(0)/2 { minrate = sadd(minrate, minrate)
minrate *= 2
}
} }
if minrate < maxrate { if minrate < maxrate {
return minrate return minrate
@ -990,15 +997,18 @@ func sendUpRTCP(up *rtpUpConnection) error {
} }
ssrcs = append(ssrcs, uint32(t.track.SSRC())) ssrcs = append(ssrcs, uint32(t.track.SSRC()))
if t.Kind() == webrtc.RTPCodecTypeAudio { if t.Kind() == webrtc.RTPCodecTypeAudio {
rate += 100 * 1024 rate = sadd(rate, 100 * 1024)
} else if t.Label() == "l" { } else if t.Label() == "l" {
rate += group.LowBitrate rate = sadd(rate, group.LowBitrate)
} else { } else {
rate += maxUpBitrate(t) rate = sadd(rate, maxUpBitrate(t))
} }
} }
if rate < ^uint64(0) && len(ssrcs) > 0 { if rate > group.MaxBitrate {
rate = group.MaxBitrate
}
if len(ssrcs) > 0 {
packets = append(packets, packets = append(packets,
&rtcp.ReceiverEstimatedMaximumBitrate{ &rtcp.ReceiverEstimatedMaximumBitrate{
Bitrate: float32(rate), Bitrate: float32(rate),

View file

@ -37,3 +37,20 @@ func TestDownTrackAtomics(t *testing.T) {
t.Errorf("Expected %v, got %v", info, info2) t.Errorf("Expected %v, got %v", info, info2)
} }
} }
func TestSadd(t *testing.T) {
ts := []struct{ x, y, z uint64 }{
{0, 0, 0},
{1, 2, 3},
{^uint64(0) - 10, 5, ^uint64(0) - 5},
{^uint64(0) - 10, 15, ^uint64(0)},
}
for _, tt := range ts {
z := sadd(tt.x, tt.y)
if z != tt.z {
t.Errorf("%v + %v: expected %v, got %v",
tt.x, tt.y, tt.z, z,
)
}
}
}