Fix overflow exception in BitArray #8494

Closed
42 changes: 42 additions & 0 deletions spec/std/bit_array_spec.cr
@@ -74,6 +74,27 @@ describe "BitArray" do
(b == c).should be_false
(a == d).should be_false
end

+ it "compares last bit" do
+ a = BitArray.new(1000)
+ b = BitArray.new(1000)
+ a.should eq(b)
+ a[-1] = true
+ a.should_not eq(b)
+ end
+
+ it "compares true-initialized" do
+ BitArray.new(0, true).should eq(BitArray.new(0))
+ {31, 32, 33, 63, 64, 65}.each do |size|
+ ary = BitArray.new(size, true)
+ bry = BitArray.new(size)
+ bry.should_not eq(ary)
+ 0.to(size - 2) { |i| bry[i] = true }
+ bry.should_not eq(ary)
+ bry[-1] = true
+ bry.should eq(ary)
+ end
+ end
end

describe "[]" do
@@ -225,6 +246,19 @@ describe "BitArray" do

ba[28..40].should eq(from_int(13, 0b1111111111111))
end

+ it "gets 32-bit boundaries" do
+ {32, 64, 128}.each do |ba_size|
+ ba = BitArray.new(ba_size, true)
+ {0, 31, 32}.each do |start|
+ {0, 31, 32}.each do |len|
+ size = {ba_size - start, len}.min
+ ba[start, len].size.should eq(size)
+ ba[start, len].should eq(BitArray.new(size, true))
+ end
+ end
+ end
+ end
end

it "toggles a bit" do
@@ -337,4 +371,12 @@ describe "BitArray" do
iter.next.should be_true
iter.next.should be_a(Iterator::Stop)
end

+ it "hashes" do
+ 1.to(17) do |len|
+ ba = BitArray.new(len, true)
+ 0.upto(len - 1) { |i| ba[i] = false }
+ BitArray.new(len).hash.should eq(ba.hash)
+ end
+ end
end
31 changes: 22 additions & 9 deletions src/bit_array.cr
@@ -32,10 +32,17 @@ struct BitArray

def ==(other : BitArray)
return false if size != other.size
- # NOTE: If BitArray implements resizing, there may be more than 1 binary
- # representation and their hashes for equivalent BitArrays after a downsize as the
- # discarded bits may not have been zeroed.
- return LibC.memcmp(@bits, other.@bits, malloc_size) == 0
+ return true if size == 0
Member

Could you explain why this entire change is needed? I don't understand why the last bit needs special handling. And if it does, there should be a comment in the code explaining why.

Contributor

It needs special handling since unused bits at the end of the array are not guaranteed to be in a defined state, so they have to be masked out (i.e., set to zero) before comparing.

Member

No, that's not correct. Pointer.malloc makes sure everything is zero. The comparison is broken because LibC.memcmp expects the number of bytes but it uses malloc_size which is in terms of UInt32. Changing it to use malloc_size * sizeof(UInt32) makes it work.

Member

Oh, I see. The method [] copies bits over from an existing bit array without zeroing the unused part, which is why memcmp doesn't work.

I think we should fix [] to guarantee the invariant that unused bits are always zero.
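
As a rough illustration of the invariant discussed in this thread (not code from this PR), the sketch below clears the unused high bits of the last 32-bit word. The helper name `mask_last_word` is hypothetical, and the sketch assumes bit `i` is stored at position `i % 32` of its word, so the unused bits are the upper ones.

```crystal
# Hypothetical helper, not part of BitArray: clears the bits of the last
# 32-bit word that lie beyond `size`, assuming bit i lives at position i % 32.
def mask_last_word(word : UInt32, size : Int32) : UInt32
  used = size % 32
  # When size is a multiple of 32 the last word is fully used.
  return word if used == 0
  # `used` is in 1..31 here, so the shift and subtraction cannot overflow.
  word & ((1_u32 << used) - 1_u32)
end

puts mask_last_word(0xFFFF_FFFF_u32, 33) # only bit 0 of the last word is in use
puts mask_last_word(0xFFFF_FFFF_u32, 64) # the whole word is in use
```

If every method that writes into the buffer applied a mask like this to the final word, a byte-wise comparison of two buffers of equal size would no longer depend on leftover bits.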

Contributor

@RX14, Dec 5, 2019

I don't really care which method is used, this PR is certainly a valid method.

Nice catch on the memcmp size being wrong though.
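
The byte-count point raised in this thread can be sketched in isolation. This is a minimal illustration, not code from the PR: the buffers `a` and `b` and the variable names are made up, and it relies only on `Pointer.malloc` zero-initializing memory and on `LibC.memcmp` taking a length in bytes.

```crystal
# Two buffers of four UInt32 words each; Pointer.malloc zero-initializes them.
words = 4
a = Pointer(UInt32).malloc(words)
b = Pointer(UInt32).malloc(words)
b[3] = 1_u32 # the buffers differ only in the last word

# Passing the word count makes memcmp look at the first 4 bytes only,
# so the difference in the last word goes unnoticed.
equal_by_words = LibC.memcmp(a, b, words) == 0

# Passing the byte count covers all 16 bytes and detects the difference.
equal_by_bytes = LibC.memcmp(a, b, words * sizeof(UInt32)) == 0

puts equal_by_words # true
puts equal_by_bytes # false
```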

+ return false if LibC.memcmp(@bits, other.@bits, malloc_size - 1) != 0
+ last = @bits[malloc_size - 1]
+ other_last = other.@bits[malloc_size - 1]
+ return true if last == other_last
+ trailing = 32 - size % 32
+ if trailing != 32
+ last << trailing == other_last << trailing
+ else
+ false
+ end
end

def ==(other)
@@ -133,18 +140,18 @@ struct BitArray
bits = @bits[0]

bits >>= start
- bits &= (1 << count) - 1
+ bits &= (1 << count) &- 1

BitArray.new(count).tap { |ba| ba.@bits[0] = bits }
elsif size <= 64
# Original fits in int64, we can use bitshifts
bits = @bits.as(UInt64*)[0]

bits >>= start
- bits &= (1 << count) - 1
+ bits &= (1 << count) &- 1

if count <= 32
- BitArray.new(count).tap { |ba| ba.@bits[0] = bits.to_u32 }
+ BitArray.new(count).tap { |ba| ba.@bits[0] = bits.to_u32! }
Contributor

I don't see when this can overflow if bits &= (1 << count) &- 1 is working correctly. However, it is not.

(1 << count) is simply zero when count >= 32, because 1 is an Int32. This doesn't matter though, since we don't care about the most significant bits being zero. So I think the bits &= (1 << count) &- 1 line can be removed and the to_u32! can stay. And I think the masking in the size <= 32 case can be removed too. I'd like a comment left about leaving the trailing upper bits, though.

Usages of (1 << count) in the rest of the file must be audited though. I think they're all fine, but I'd like a second pair of eyes.
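
For readers auditing those usages, here is a small sketch of how Int32 shifts and the wrapping `&-` operator behave at the 32-bit boundary. It is an illustration only, with made-up variable names, and it assumes a Crystal version in which integer overflow checking is enabled by default (0.31 or later).

```crystal
# Shifting an Int32 by 31 sets the sign bit; shifting by the full width gives 0.
shifted_31 = 1 << 31
shifted_32 = 1 << 32

# Checked subtraction raises on overflow, which is the exception this PR targets.
raised = begin
  shifted_31 - 1
  false
rescue OverflowError
  true
end

# The wrapping operator `&-` wraps around instead of raising.
wrapped = shifted_31 &- 1  # wraps to Int32::MAX, i.e. a mask of 31 ones
all_ones = shifted_32 &- 1 # 0 - 1 is simply -1, no wrap involved

puts raised
puts wrapped == Int32::MAX
puts all_ones
```

So with an Int32 literal, a count of 31 is the case where the checked `- 1` overflows, while a count of 32 or more produces a zero shift result rather than a power of two.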

else
BitArray.new(count).tap { |ba| ba.@bits.as(UInt64*)[0] = bits }
end
@@ -162,7 +169,7 @@ struct BitArray
bits = @bits[start_bit_index + i + 1]

high_bits = bits
- high_bits &= (1 << start_sub_index) - 1
+ high_bits &= (1 << start_sub_index) &- 1
high_bits <<= 32 - start_sub_index

ba.@bits[i] = low_bits | high_bits
@@ -239,7 +246,13 @@ struct BitArray
# See `Object#hash(hasher)`
def hash(hasher)
hasher = size.hash(hasher)
- hasher = to_slice.hash(hasher)
+ bytes, bits = @size.divmod(8)
+ if bytes > 0
+ hasher = Slice.new(@bits.as(Pointer(UInt8)), bytes).hash(hasher)
+ end
+ if bits != 0
+ hasher = (@bits.as(Pointer(UInt8))[bytes] << 8 - bits).hash(hasher)
+ end
hasher
end
