Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GetSetBits and Count to BitArray #221

Merged
merged 9 commits into from
May 18, 2023
80 changes: 79 additions & 1 deletion bitarray/bitarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ efficient way. This is *NOT* a threadsafe package.
*/
package bitarray

import "math/bits"

// bitArray is a struct that maintains state of a bit array.
type bitArray struct {
blocks []block
Expand Down Expand Up @@ -116,7 +118,74 @@ func (ba *bitArray) GetBit(k uint64) (bool, error) {
return result, nil
}

//ClearBit will unset a bit at the given index if it is set.
// GetSetBits gets the position of bits set in the array. It fills the
// provided buffer, up to its capacity, with the positions of set bits at or
// after from, in ascending order, and returns the filled portion of the
// buffer. If from lies beyond the last block of the array there is nothing
// to report and an empty slice is returned.
func (ba *bitArray) GetSetBits(from uint64, buffer []uint64) []uint64 {
	fromBlockIndex, fromOffset := getIndexAndRemainder(from)

	// Slicing ba.blocks past its length would panic, so a start position
	// beyond the array yields an empty result, matching what the sparse
	// implementation does when no block is found.
	if fromBlockIndex >= uint64(len(ba.blocks)) {
		return buffer[:0]
	}

	return getSetBitsInBlocks(
		fromBlockIndex,
		fromOffset,
		ba.blocks[fromBlockIndex:],
		nil,
		buffer,
	)
}

// getSetBitsInBlocks fills a buffer with positions of set bits in the provided blocks. Optionally, indices may be
// provided for sparse/non-consecutive blocks.
func getSetBitsInBlocks(
fromBlockIndex, fromOffset uint64,
blocks []block,
indices []uint64,
buffer []uint64,
) []uint64 {
bufferCapacity := cap(buffer)
if bufferCapacity == 0 {
return buffer[:0]
}

results := buffer[:bufferCapacity]
resultSize := 0

for i, block := range blocks {
blockIndex := fromBlockIndex + uint64(i)
if indices != nil {
blockIndex = indices[i]
}

isFirstBlock := blockIndex == fromBlockIndex
if isFirstBlock {
block >>= fromOffset
}

for block != 0 {
trailing := bits.TrailingZeros64(uint64(block))

if isFirstBlock {
results[resultSize] = uint64(trailing) + (blockIndex << 6) + fromOffset
danielway-wk marked this conversation as resolved.
Show resolved Hide resolved
} else {
results[resultSize] = uint64(trailing) + (blockIndex << 6)
}
resultSize++

if resultSize == cap(results) {
return results[:resultSize]
}

// Clear the bit we just added to the result, which is the last bit set in the block. Ex.:
// block 01001100
// ^block 10110011
// (^block) + 1 10110100
// block & (^block) + 1 00000100
// block ^ mask 01001000
danielway-wk marked this conversation as resolved.
Show resolved Hide resolved
mask := block & ((^block) + 1)
block = block ^ mask
}
}

return results[:resultSize]
}

// ClearBit will unset a bit at the given index if it is set.
func (ba *bitArray) ClearBit(k uint64) error {
if k >= ba.Capacity() {
return OutOfRangeError(k)
Expand All @@ -137,6 +206,15 @@ func (ba *bitArray) ClearBit(k uint64) error {
return nil
}

// Count reports how many bits are set in this array by summing the
// population count of every block.
func (ba *bitArray) Count() int {
	total := 0
	for i := range ba.blocks {
		total += bits.OnesCount64(uint64(ba.blocks[i]))
	}
	return total
}

// Or will bitwise or two bit arrays and return a new bit array
// representing the result.
func (ba *bitArray) Or(other BitArray) BitArray {
Expand Down
70 changes: 70 additions & 0 deletions bitarray/bitarray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestBitOperations(t *testing.T) {
Expand Down Expand Up @@ -142,6 +143,28 @@ func TestIsEmpty(t *testing.T) {
assert.False(t, ba.IsEmpty())
}

// TestCount checks that Count follows the number of set bits as bits are
// set, cleared, and the array is reset.
func TestCount(t *testing.T) {
	ba := newBitArray(500)
	assert.Equal(t, 0, ba.Count())

	require.NoError(t, ba.SetBit(0))
	assert.Equal(t, 1, ba.Count())

	// Six more bits, spread over several blocks.
	for _, position := range []uint64{40, 64, 100, 200, 469, 500} {
		require.NoError(t, ba.SetBit(position))
	}
	assert.Equal(t, 7, ba.Count())

	require.NoError(t, ba.ClearBit(200))
	assert.Equal(t, 6, ba.Count())

	ba.Reset()
	assert.Equal(t, 0, ba.Count())
}

func TestClear(t *testing.T) {
ba := newBitArray(10)

Expand Down Expand Up @@ -195,6 +218,53 @@ func BenchmarkGetBit(b *testing.B) {
}
}

// TestGetSetBits exercises GetSetBits on a dense bit array: empty buffers,
// a buffer smaller than the number of set bits, various starting positions,
// and the array after clearing bits and after a reset.
func TestGetSetBits(t *testing.T) {
	ba := newBitArray(1000)
	buf := make([]uint64, 0, 5)

	for _, position := range []uint64{1, 4, 8, 63, 64, 200, 1000} {
		require.NoError(t, ba.SetBit(position))
	}

	// A buffer without capacity comes back unchanged and empty.
	assert.Equal(t, []uint64(nil), ba.GetSetBits(0, nil))
	assert.Equal(t, []uint64{}, ba.GetSetBits(0, []uint64{}))

	queries := []struct {
		from     uint64
		expected []uint64
	}{
		{from: 0, expected: []uint64{1, 4, 8, 63, 64}},
		{from: 10, expected: []uint64{63, 64, 200, 1000}},
		{from: 63, expected: []uint64{63, 64, 200, 1000}},
		{from: 128, expected: []uint64{200, 1000}},
	}
	for _, q := range queries {
		assert.Equal(t, q.expected, ba.GetSetBits(q.from, buf))
	}

	for _, position := range []uint64{4, 64} {
		require.NoError(t, ba.ClearBit(position))
	}
	assert.Equal(t, []uint64{1, 8, 63, 200, 1000}, ba.GetSetBits(0, buf))
	assert.Empty(t, ba.GetSetBits(1001, buf))

	ba.Reset()
	assert.Empty(t, ba.GetSetBits(0, buf))
}

// BenchmarkGetSetBits measures GetSetBits over a dense bit array in which
// every multiple of 5 or 13 is set, using a buffer sized to the array's
// capacity.
func BenchmarkGetSetBits(b *testing.B) {
	const numItems uint64 = 168000

	ba := newBitArray(numItems)
	for i := uint64(0); i < numItems; i++ {
		if i%13 != 0 && i%5 != 0 {
			continue
		}
		require.NoError(b, ba.SetBit(i))
	}

	buf := make([]uint64, 0, ba.Capacity())

	// Exclude the setup above from the measurement.
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		ba.GetSetBits(0, buf)
	}
}

func TestEquality(t *testing.T) {
ba := newBitArray(s + 1)
other := newBitArray(s + 1)
Expand Down
6 changes: 6 additions & 0 deletions bitarray/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ type BitArray interface {
// function returns an error if the position is out
// of range. A sparse bit array never returns an error.
GetBit(k uint64) (bool, error)
// GetSetBits gets the position of bits set in the array. Will
// return as many set bits as can fit in the provided buffer
// starting from the specified position in the array.
GetSetBits(from uint64, buffer []uint64) []uint64
// ClearBit clears the bit at the given position. This
// function returns an error if the position is out
// of range. A sparse bit array never returns an error.
Expand All @@ -55,6 +59,8 @@ type BitArray interface {
// in the case of a dense bit array or the highest possible
// seen capacity of the sparse array.
Capacity() uint64
// Count returns the number of set bits in this array.
Count() int
// Or will bitwise or the two bitarrays and return a new bitarray
// representing the result.
Or(other BitArray) BitArray
Expand Down
32 changes: 31 additions & 1 deletion bitarray/sparse_bitarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ limitations under the License.

package bitarray

import "sort"
import (
"math/bits"
"sort"
)

// uintSlice is an alias for a slice of ints. Len, Swap, and Less
// are exported to fulfill an interface needed for the search
Expand Down Expand Up @@ -127,6 +130,24 @@ func (sba *sparseBitArray) GetBit(k uint64) (bool, error) {
return sba.blocks[i].get(position), nil
}

// GetSetBits gets the position of bits set in the array. Positions are
// written into buffer, starting from the given position, and the filled
// portion of buffer is returned.
func (sba *sparseBitArray) GetSetBits(from uint64, buffer []uint64) []uint64 {
	startBlock, startOffset := getIndexAndRemainder(from)

	// A search result equal to len(indices) is handled as "no stored block to
	// scan". NOTE(review): presumably search returns the insertion point of
	// startBlock in the sorted indices — confirm against uintSlice.search.
	location := sba.indices.search(startBlock)
	if len(sba.indices) == int(location) {
		return buffer[:0]
	}

	// blocks and indices are sliced at the same location so they stay paired.
	remainingBlocks, remainingIndices := sba.blocks[location:], sba.indices[location:]
	return getSetBitsInBlocks(startBlock, startOffset, remainingBlocks, remainingIndices, buffer)
}

// ToNums converts this sparse bitarray to a list of numbers contained
// within it.
func (sba *sparseBitArray) ToNums() []uint64 {
Expand Down Expand Up @@ -225,6 +246,15 @@ func (sba *sparseBitArray) Equals(other BitArray) bool {
return true
}

// Count reports how many bits are set in this array by adding up the
// population count of each stored block.
func (sba *sparseBitArray) Count() int {
	var total int
	for idx := 0; idx < len(sba.blocks); idx++ {
		total += bits.OnesCount64(uint64(sba.blocks[idx]))
	}
	return total
}

// Or will perform a bitwise or operation with the provided bitarray and
// return a new result bitarray.
func (sba *sparseBitArray) Or(other BitArray) BitArray {
Expand Down
68 changes: 68 additions & 0 deletions bitarray/sparse_bitarray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetCompressedBit(t *testing.T) {
Expand Down Expand Up @@ -76,6 +77,73 @@ func BenchmarkSetCompressedBit(b *testing.B) {
}
}

// TestGetSetCompressedBits exercises GetSetBits on a sparse bit array: empty
// buffers, a buffer smaller than the number of set bits, various starting
// positions, and the array after clearing bits and after a reset.
func TestGetSetCompressedBits(t *testing.T) {
	ba := newSparseBitArray()
	buf := make([]uint64, 0, 5)

	for _, position := range []uint64{1, 4, 8, 63, 64, 200, 1000} {
		require.NoError(t, ba.SetBit(position))
	}

	// A buffer without capacity comes back unchanged and empty.
	assert.Equal(t, []uint64(nil), ba.GetSetBits(0, nil))
	assert.Equal(t, []uint64{}, ba.GetSetBits(0, []uint64{}))

	queries := []struct {
		from     uint64
		expected []uint64
	}{
		{from: 0, expected: []uint64{1, 4, 8, 63, 64}},
		{from: 10, expected: []uint64{63, 64, 200, 1000}},
		{from: 63, expected: []uint64{63, 64, 200, 1000}},
		{from: 128, expected: []uint64{200, 1000}},
	}
	for _, q := range queries {
		assert.Equal(t, q.expected, ba.GetSetBits(q.from, buf))
	}

	for _, position := range []uint64{4, 64} {
		require.NoError(t, ba.ClearBit(position))
	}
	assert.Equal(t, []uint64{1, 8, 63, 200, 1000}, ba.GetSetBits(0, buf))
	assert.Empty(t, ba.GetSetBits(1001, buf))

	ba.Reset()
	assert.Empty(t, ba.GetSetBits(0, buf))
}

// BenchmarkGetSetCompressedBits measures GetSetBits over a sparse bit array
// in which every multiple of 5 or 13 below 168000 is set, using a buffer
// sized to the array's capacity.
func BenchmarkGetSetCompressedBits(b *testing.B) {
	ba := newSparseBitArray()
	for position := uint64(0); position < 168000; position++ {
		if position%13 != 0 && position%5 != 0 {
			continue
		}
		require.NoError(b, ba.SetBit(position))
	}

	buf := make([]uint64, 0, ba.Capacity())

	// Exclude the setup above from the measurement.
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		ba.GetSetBits(0, buf)
	}
}

// TestCompressedCount checks that Count on a sparse bit array follows the
// number of set bits as bits are set, cleared, and the array is reset.
func TestCompressedCount(t *testing.T) {
	ba := newSparseBitArray()
	assert.Equal(t, 0, ba.Count())

	require.NoError(t, ba.SetBit(0))
	assert.Equal(t, 1, ba.Count())

	// Six more bits, spread over several blocks.
	for _, position := range []uint64{40, 64, 100, 200, 469, 500} {
		require.NoError(t, ba.SetBit(position))
	}
	assert.Equal(t, 7, ba.Count())

	require.NoError(t, ba.ClearBit(200))
	assert.Equal(t, 6, ba.Count())

	ba.Reset()
	assert.Equal(t, 0, ba.Count())
}

func TestClearCompressedBit(t *testing.T) {
ba := newSparseBitArray()
ba.SetBit(5)
Expand Down