From b4edb7c0c7a780d0483ea7634fa309f86aecd157 Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Fri, 22 Jul 2022 21:24:36 +0200 Subject: [PATCH] Avoid overflow in FromDuration and ToDuration. Thanks to lamhai1401. --- rtptime/rtptime.go | 25 +++++++++++++++---------- rtptime/rtptime_test.go | 19 +++++++++++++++++++ 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/rtptime/rtptime.go b/rtptime/rtptime.go index 45aacec..a824708 100644 --- a/rtptime/rtptime.go +++ b/rtptime/rtptime.go @@ -2,6 +2,7 @@ package rtptime import ( + "math/bits" "time" ) @@ -9,25 +10,29 @@ import ( var epoch = time.Now() // FromDuration converts a time.Duration into units of 1/hz. +// Negative values are clamped to zero. func FromDuration(d time.Duration, hz uint32) int64 { - return int64(d) * int64(hz) / int64(time.Second) + if d < 0 { + return -FromDuration(-d, hz) + } + hi, lo := bits.Mul64(uint64(d), uint64(hz)) + q, _ := bits.Div64(hi, lo, uint64(time.Second)) + return int64(q) } // ToDuration converts units of 1/hz into a time.Duration. func ToDuration(tm int64, hz uint32) time.Duration { - return time.Duration(tm * int64(time.Second) / int64(hz)) -} - -func sat(a int64) uint64 { - if a < 0 { - return 0 + if tm < 0 { + return -ToDuration(-tm, hz) } - return uint64(a) + hi, lo := bits.Mul64(uint64(tm), uint64(time.Second)) + q, _ := bits.Div64(hi, lo, uint64(hz)) + return time.Duration(q) } // Now returns the current time in units of 1/hz from an arbitrary origin. func Now(hz uint32) uint64 { - return sat(FromDuration(time.Since(epoch), hz)) + return uint64(FromDuration(time.Since(epoch), hz)) } // Microseconds is like Now, but uses microseconds. @@ -46,7 +51,7 @@ func Jiffies() uint64 { // TimeToJiffies converts a time.Time into jiffies. func TimeToJiffies(tm time.Time) uint64 { - return sat(FromDuration(tm.Sub(epoch), JiffiesPerSec)) + return uint64(FromDuration(tm.Sub(epoch), JiffiesPerSec)) } // The origin of NTP time. diff --git a/rtptime/rtptime_test.go b/rtptime/rtptime_test.go index 0b7e447..23dc6fb 100644 --- a/rtptime/rtptime_test.go +++ b/rtptime/rtptime_test.go @@ -27,6 +27,25 @@ func TestDuration(t *testing.T) { } } +func TestDurationOverflow(t *testing.T) { + delta := 10 * time.Minute + dj := FromDuration(delta, JiffiesPerSec) + var prev int64 + for d := time.Duration(0); d < time.Duration(1000*time.Hour); d += delta { + jiffies := FromDuration(d, JiffiesPerSec) + if d != 0 { + if jiffies != prev+dj { + t.Errorf("%v: %v, %v", d, jiffies, prev) + } + } + d2 := ToDuration(jiffies, JiffiesPerSec) + if d2 != d { + t.Errorf("%v != %v (%v)", d2, d, jiffies) + } + prev = jiffies + } +} + func differs(a, b, delta uint64) bool { if a < b { a, b = b, a