diff --git a/bitarray/bitarray.go b/bitarray/bitarray.go index 73526c9..37e49ff 100644 --- a/bitarray/bitarray.go +++ b/bitarray/bitarray.go @@ -20,6 +20,8 @@ efficient way. This is *NOT* a threadsafe package. */ package bitarray +import "math/bits" + // bitArray is a struct that maintains state of a bit array. type bitArray struct { blocks []block @@ -116,7 +118,74 @@ func (ba *bitArray) GetBit(k uint64) (bool, error) { return result, nil } -//ClearBit will unset a bit at the given index if it is set. +// GetSetBits gets the position of bits set in the array. +func (ba *bitArray) GetSetBits(from uint64, buffer []uint64) []uint64 { + fromBlockIndex, fromOffset := getIndexAndRemainder(from) + return getSetBitsInBlocks( + fromBlockIndex, + fromOffset, + ba.blocks[fromBlockIndex:], + nil, + buffer, + ) +} + +// getSetBitsInBlocks fills a buffer with positions of set bits in the provided blocks. Optionally, indices may be +// provided for sparse/non-consecutive blocks. +func getSetBitsInBlocks( + fromBlockIndex, fromOffset uint64, + blocks []block, + indices []uint64, + buffer []uint64, +) []uint64 { + bufferCapacity := cap(buffer) + if bufferCapacity == 0 { + return buffer[:0] + } + + results := buffer[:bufferCapacity] + resultSize := 0 + + for i, block := range blocks { + blockIndex := fromBlockIndex + uint64(i) + if indices != nil { + blockIndex = indices[i] + } + + isFirstBlock := blockIndex == fromBlockIndex + if isFirstBlock { + block >>= fromOffset + } + + for block != 0 { + trailing := bits.TrailingZeros64(uint64(block)) + + if isFirstBlock { + results[resultSize] = uint64(trailing) + (blockIndex << 6) + fromOffset + } else { + results[resultSize] = uint64(trailing) + (blockIndex << 6) + } + resultSize++ + + if resultSize == cap(results) { + return results[:resultSize] + } + + // Clear the bit we just added to the result, which is the last bit set in the block. Ex.: + // block 01001100 + // ^block 10110011 + // (^block) + 1 10110100 + // block & (^block) + 1 00000100 + // block ^ mask 01001000 + mask := block & ((^block) + 1) + block = block ^ mask + } + } + + return results[:resultSize] +} + +// ClearBit will unset a bit at the given index if it is set. func (ba *bitArray) ClearBit(k uint64) error { if k >= ba.Capacity() { return OutOfRangeError(k) @@ -137,6 +206,15 @@ func (ba *bitArray) ClearBit(k uint64) error { return nil } +// Count returns the number of set bits in this array. +func (ba *bitArray) Count() int { + count := 0 + for _, block := range ba.blocks { + count += bits.OnesCount64(uint64(block)) + } + return count +} + // Or will bitwise or two bit arrays and return a new bit array // representing the result. func (ba *bitArray) Or(other BitArray) BitArray { diff --git a/bitarray/bitarray_test.go b/bitarray/bitarray_test.go index 9172369..b005757 100644 --- a/bitarray/bitarray_test.go +++ b/bitarray/bitarray_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBitOperations(t *testing.T) { @@ -142,6 +143,28 @@ func TestIsEmpty(t *testing.T) { assert.False(t, ba.IsEmpty()) } +func TestCount(t *testing.T) { + ba := newBitArray(500) + assert.Equal(t, 0, ba.Count()) + + require.NoError(t, ba.SetBit(0)) + assert.Equal(t, 1, ba.Count()) + + require.NoError(t, ba.SetBit(40)) + require.NoError(t, ba.SetBit(64)) + require.NoError(t, ba.SetBit(100)) + require.NoError(t, ba.SetBit(200)) + require.NoError(t, ba.SetBit(469)) + require.NoError(t, ba.SetBit(500)) + assert.Equal(t, 7, ba.Count()) + + require.NoError(t, ba.ClearBit(200)) + assert.Equal(t, 6, ba.Count()) + + ba.Reset() + assert.Equal(t, 0, ba.Count()) +} + func TestClear(t *testing.T) { ba := newBitArray(10) @@ -195,6 +218,53 @@ func BenchmarkGetBit(b *testing.B) { } } +func TestGetSetBits(t *testing.T) { + ba := newBitArray(1000) + buf := make([]uint64, 0, 5) + + require.NoError(t, ba.SetBit(1)) + require.NoError(t, ba.SetBit(4)) + require.NoError(t, ba.SetBit(8)) + require.NoError(t, ba.SetBit(63)) + require.NoError(t, ba.SetBit(64)) + require.NoError(t, ba.SetBit(200)) + require.NoError(t, ba.SetBit(1000)) + + assert.Equal(t, []uint64(nil), ba.GetSetBits(0, nil)) + assert.Equal(t, []uint64{}, ba.GetSetBits(0, []uint64{})) + + assert.Equal(t, []uint64{1, 4, 8, 63, 64}, ba.GetSetBits(0, buf)) + assert.Equal(t, []uint64{63, 64, 200, 1000}, ba.GetSetBits(10, buf)) + assert.Equal(t, []uint64{63, 64, 200, 1000}, ba.GetSetBits(63, buf)) + assert.Equal(t, []uint64{200, 1000}, ba.GetSetBits(128, buf)) + + require.NoError(t, ba.ClearBit(4)) + require.NoError(t, ba.ClearBit(64)) + assert.Equal(t, []uint64{1, 8, 63, 200, 1000}, ba.GetSetBits(0, buf)) + assert.Empty(t, ba.GetSetBits(1001, buf)) + + ba.Reset() + assert.Empty(t, ba.GetSetBits(0, buf)) +} + +func BenchmarkGetSetBits(b *testing.B) { + numItems := uint64(168000) + + ba := newBitArray(numItems) + for i := uint64(0); i < numItems; i++ { + if i%13 == 0 || i%5 == 0 { + require.NoError(b, ba.SetBit(i)) + } + } + + buf := make([]uint64, 0, ba.Capacity()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ba.GetSetBits(0, buf) + } +} + func TestEquality(t *testing.T) { ba := newBitArray(s + 1) other := newBitArray(s + 1) diff --git a/bitarray/interface.go b/bitarray/interface.go index bb4057a..fb22493 100644 --- a/bitarray/interface.go +++ b/bitarray/interface.go @@ -36,6 +36,10 @@ type BitArray interface { // function returns an error if the position is out // of range. A sparse bit array never returns an error. GetBit(k uint64) (bool, error) + // GetSetBits gets the position of bits set in the array. Will + // return as many set bits as can fit in the provided buffer + // starting from the specified position in the array. + GetSetBits(from uint64, buffer []uint64) []uint64 // ClearBit clears the bit at the given position. This // function returns an error if the position is out // of range. A sparse bit array never returns an error. @@ -55,6 +59,8 @@ type BitArray interface { // in the case of a dense bit array or the highest possible // seen capacity of the sparse array. Capacity() uint64 + // Count returns the number of set bits in this array. + Count() int // Or will bitwise or the two bitarrays and return a new bitarray // representing the result. Or(other BitArray) BitArray diff --git a/bitarray/sparse_bitarray.go b/bitarray/sparse_bitarray.go index ca68e9f..b9e3dbc 100644 --- a/bitarray/sparse_bitarray.go +++ b/bitarray/sparse_bitarray.go @@ -16,7 +16,10 @@ limitations under the License. package bitarray -import "sort" +import ( + "math/bits" + "sort" +) // uintSlice is an alias for a slice of ints. Len, Swap, and Less // are exported to fulfill an interface needed for the search @@ -127,6 +130,24 @@ func (sba *sparseBitArray) GetBit(k uint64) (bool, error) { return sba.blocks[i].get(position), nil } +// GetSetBits gets the position of bits set in the array. +func (sba *sparseBitArray) GetSetBits(from uint64, buffer []uint64) []uint64 { + fromBlockIndex, fromOffset := getIndexAndRemainder(from) + + fromBlockLocation := sba.indices.search(fromBlockIndex) + if int(fromBlockLocation) == len(sba.indices) { + return buffer[:0] + } + + return getSetBitsInBlocks( + fromBlockIndex, + fromOffset, + sba.blocks[fromBlockLocation:], + sba.indices[fromBlockLocation:], + buffer, + ) +} + // ToNums converts this sparse bitarray to a list of numbers contained // within it. func (sba *sparseBitArray) ToNums() []uint64 { @@ -225,6 +246,15 @@ func (sba *sparseBitArray) Equals(other BitArray) bool { return true } +// Count returns the number of set bits in this array. +func (sba *sparseBitArray) Count() int { + count := 0 + for _, block := range sba.blocks { + count += bits.OnesCount64(uint64(block)) + } + return count +} + // Or will perform a bitwise or operation with the provided bitarray and // return a new result bitarray. func (sba *sparseBitArray) Or(other BitArray) BitArray { diff --git a/bitarray/sparse_bitarray_test.go b/bitarray/sparse_bitarray_test.go index 4f58d3e..0944347 100644 --- a/bitarray/sparse_bitarray_test.go +++ b/bitarray/sparse_bitarray_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetCompressedBit(t *testing.T) { @@ -76,6 +77,73 @@ func BenchmarkSetCompressedBit(b *testing.B) { } } +func TestGetSetCompressedBits(t *testing.T) { + ba := newSparseBitArray() + buf := make([]uint64, 0, 5) + + require.NoError(t, ba.SetBit(1)) + require.NoError(t, ba.SetBit(4)) + require.NoError(t, ba.SetBit(8)) + require.NoError(t, ba.SetBit(63)) + require.NoError(t, ba.SetBit(64)) + require.NoError(t, ba.SetBit(200)) + require.NoError(t, ba.SetBit(1000)) + + assert.Equal(t, []uint64(nil), ba.GetSetBits(0, nil)) + assert.Equal(t, []uint64{}, ba.GetSetBits(0, []uint64{})) + + assert.Equal(t, []uint64{1, 4, 8, 63, 64}, ba.GetSetBits(0, buf)) + assert.Equal(t, []uint64{63, 64, 200, 1000}, ba.GetSetBits(10, buf)) + assert.Equal(t, []uint64{63, 64, 200, 1000}, ba.GetSetBits(63, buf)) + assert.Equal(t, []uint64{200, 1000}, ba.GetSetBits(128, buf)) + + require.NoError(t, ba.ClearBit(4)) + require.NoError(t, ba.ClearBit(64)) + assert.Equal(t, []uint64{1, 8, 63, 200, 1000}, ba.GetSetBits(0, buf)) + assert.Empty(t, ba.GetSetBits(1001, buf)) + + ba.Reset() + assert.Empty(t, ba.GetSetBits(0, buf)) +} + +func BenchmarkGetSetCompressedBits(b *testing.B) { + ba := newSparseBitArray() + for i := uint64(0); i < 168000; i++ { + if i%13 == 0 || i%5 == 0 { + require.NoError(b, ba.SetBit(i)) + } + } + + buf := make([]uint64, 0, ba.Capacity()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ba.GetSetBits(0, buf) + } +} + +func TestCompressedCount(t *testing.T) { + ba := newSparseBitArray() + assert.Equal(t, 0, ba.Count()) + + require.NoError(t, ba.SetBit(0)) + assert.Equal(t, 1, ba.Count()) + + require.NoError(t, ba.SetBit(40)) + require.NoError(t, ba.SetBit(64)) + require.NoError(t, ba.SetBit(100)) + require.NoError(t, ba.SetBit(200)) + require.NoError(t, ba.SetBit(469)) + require.NoError(t, ba.SetBit(500)) + assert.Equal(t, 7, ba.Count()) + + require.NoError(t, ba.ClearBit(200)) + assert.Equal(t, 6, ba.Count()) + + ba.Reset() + assert.Equal(t, 0, ba.Count()) +} + func TestClearCompressedBit(t *testing.T) { ba := newSparseBitArray() ba.SetBit(5)