diff --git a/client.go b/client.go index a472455..9d86b74 100644 --- a/client.go +++ b/client.go @@ -409,7 +409,7 @@ func upLoop(conn *upConnection, track *upTrack) { track.jitter.Accumulate(packet.Timestamp) - first := track.cache.Store(packet.SequenceNumber, buf[:bytes]) + first, _ := track.cache.Store(packet.SequenceNumber, buf[:bytes]) if packet.SequenceNumber-first > 24 { found, first, bitmap := track.cache.BitmapGet() if found { diff --git a/packetcache/packetcache.go b/packetcache/packetcache.go index e52b9e9..1103db0 100644 --- a/packetcache/packetcache.go +++ b/packetcache/packetcache.go @@ -26,11 +26,14 @@ type Cache struct { first uint16 bitmap uint32 // packet cache - tail int + tail uint16 entries []entry } func New(capacity int) *Cache { + if capacity > int(^uint16(0)) { + return nil + } return &Cache{ entries: make([]entry, capacity), } @@ -73,7 +76,7 @@ func (cache *Cache) set(seqno uint16) { } // Store a packet, setting bitmap at the same time -func (cache *Cache) Store(seqno uint16, buf []byte) uint16 { +func (cache *Cache) Store(seqno uint16, buf []byte) (uint16, uint16) { cache.mu.Lock() defer cache.mu.Unlock() @@ -98,12 +101,13 @@ func (cache *Cache) Store(seqno uint16, buf []byte) uint16 { cache.set(seqno) - cache.entries[cache.tail].seqno = seqno - copy(cache.entries[cache.tail].buf[:], buf) - cache.entries[cache.tail].length = uint16(len(buf)) - cache.tail = (cache.tail + 1) % len(cache.entries) + i := cache.tail + cache.entries[i].seqno = seqno + copy(cache.entries[i].buf[:], buf) + cache.entries[i].length = uint16(len(buf)) + cache.tail = (i + 1) % uint16(len(cache.entries)) - return cache.first + return cache.first, i } func (cache *Cache) Expect(n int) { @@ -132,6 +136,19 @@ func (cache *Cache) Get(seqno uint16, result []byte) uint16 { return 0 } +func (cache *Cache) GetAt(seqno uint16, index uint16, result []byte) uint16 { + cache.mu.Lock() + defer cache.mu.Unlock() + + if cache.entries[index].seqno != seqno { + return 0 + } + return uint16(copy( + result[:cache.entries[index].length], + cache.entries[index].buf[:]), + ) +} + // Shift 17 bits out of the bitmap. Return a boolean indicating if any // were 0, the index of the first 0 bit, and a bitmap indicating any // 0 bits after the first one. diff --git a/packetcache/packetcache_test.go b/packetcache/packetcache_test.go index 1163906..4dbac3e 100644 --- a/packetcache/packetcache_test.go +++ b/packetcache/packetcache_test.go @@ -20,8 +20,8 @@ func TestCache(t *testing.T) { buf1 := randomBuf() buf2 := randomBuf() cache := New(16) - cache.Store(13, buf1) - cache.Store(17, buf2) + _, i1 := cache.Store(13, buf1) + _, i2 := cache.Store(17, buf2) buf := make([]byte, BufSize) @@ -29,16 +29,34 @@ func TestCache(t *testing.T) { if bytes.Compare(buf[:l], buf1) != 0 { t.Errorf("Couldn't get 13") } + l = cache.GetAt(13, i1, buf) + if bytes.Compare(buf[:l], buf1) != 0 { + t.Errorf("Couldn't get 13 at %v", i1) + } l = cache.Get(17, buf) if bytes.Compare(buf[:l], buf2) != 0 { t.Errorf("Couldn't get 17") } + l = cache.GetAt(17, i2, buf) + if bytes.Compare(buf[:l], buf2) != 0 { + t.Errorf("Couldn't get 17 at %v", i2) + } l = cache.Get(42, buf) if l != 0 { t.Errorf("Creation ex nihilo") } + + l = cache.GetAt(17, i1, buf) + if l != 0 { + t.Errorf("Got 17 at %v", i1) + } + + l = cache.GetAt(42, i2, buf) + if l != 0 { + t.Errorf("Got 42 at %v", i2) + } } func TestCacheOverflow(t *testing.T) { @@ -82,7 +100,7 @@ func TestBitmap(t *testing.T) { var first uint16 for i := 0; i < 64; i++ { if (value & (1 << i)) != 0 { - first = cache.Store(uint16(42 + i), packet) + first, _ = cache.Store(uint16(42 + i), packet) } } @@ -104,7 +122,7 @@ func TestBitmapWrap(t *testing.T) { var first uint16 for i := 0; i < 64; i++ { if (value & (1 << i)) != 0 { - first = cache.Store(uint16(42 + i), packet) + first, _ = cache.Store(uint16(42 + i), packet) } }