Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster findall for bitarrays #29888

Merged
merged 3 commits into from
Nov 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 58 additions & 25 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1510,37 +1510,70 @@ function findprev(testf::Function, B::BitArray, start::Integer)
end
#findlast(testf::Function, B::BitArray) = findprev(testf, B, 1) ## defined in array.jl

# findall helper functions
# Generic case (>2 dimensions)
function allindices!(I, B::BitArray)
ind = first(keys(B))
for k = 1:length(B)
I[k] = ind
ind = nextind(B, ind)
end
end

# Optimized case for vector
function allindices!(I, B::BitVector)
I[:] .= 1:length(B)
end

# Optimized case for matrix
function allindices!(I, B::BitMatrix)
k = 1
for c = 1:size(B,2), r = 1:size(B,1)
I[k] = CartesianIndex(r, c)
k += 1
end
end

@inline _overflowind(i1, irest::Tuple{}, size) = (i1, irest)
@inline function _overflowind(i1, irest, size)
i2 = irest[1]
while i1 > size[1]
i1 -= size[1]
i2 += 1
end
i2, irest = _overflowind(i2, tail(irest), tail(size))
return (i1, (i2, irest...))
end

@inline _toind(i1, irest::Tuple{}) = i1
@inline _toind(i1, irest) = CartesianIndex(i1, irest...)

function findall(B::BitArray)
l = length(B)
nnzB = count(B)
ind = first(keys(B))
I = Vector{typeof(ind)}(undef, nnzB)
I = Vector{eltype(keys(B))}(undef, nnzB)
nnzB == 0 && return I
nnzB == length(B) && (allindices!(I, B); return I)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allindices! seems like it should be able to be faster/more generic/less code. It's a little annoying though since we don't yet have the generic Vector(itr) constructor. Maybe it should just be vec(collect(keys(B))) and move the short circuit return to be before you construct I.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It annoyed me too that I needed almost as many lines of code for the allindices! functions as for findall itself. vec(collect(keys(B))) is a great suggestion for vectors and arrays of dim >= 3, but I am seeing much worse performance for matrices (2 dims). This is the simple test script I'm using:

for B in [trues(100000), trues(200, 200), trues(50, 50, 50), trues(16, 16, 16, 16)]
    print(size(B)); @btime findall_optimized($B)
    print(size(B)); @btime vec(collect(keys($B)))
end

With results:

(100000,)  56.197 μs (3 allocations: 781.38 KiB)
(100000,)  55.882 μs (3 allocations: 781.34 KiB)
(200, 200)  49.331 μs (2 allocations: 625.08 KiB)
(200, 200)  72.926 μs (5 allocations: 625.19 KiB)
(50, 50, 50)  222.002 μs (2 allocations: 2.86 MiB)
(50, 50, 50)  225.390 μs (5 allocations: 2.86 MiB)
(16, 16, 16, 16)  151.709 μs (2 allocations: 2.00 MiB)
(16, 16, 16, 16)  155.849 μs (6 allocations: 2.00 MiB)

In fact, for matrices, it would be better then to turn off this special case optimization. Timings for findall_optimized without using allindices!:

(100000,)  74.627 μs (2 allocations: 781.33 KiB)
(200, 200)  52.787 μs (2 allocations: 625.08 KiB)
(50, 50, 50)  234.702 μs (2 allocations: 2.86 MiB)
(16, 16, 16, 16)  165.563 μs (2 allocations: 2.00 MiB)

While I think some performance can be sacrificed for simpler code, IMO the degradation for matrices is a bit much. Can you think of a performant solution that works for arrays of all dimensions? If not, two alternatives are: 1) keep allindices! (or _allindices!) but with only two cases: the BitMatrix one as is, and vec(collect(keys(B))) for all other BitArrays; or 2) make vec(collect(keys(B))) fast for matrices.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the thorough testing here. I think what you have makes sense and is just fine.

Bc = B.chunks
Icount = 1
for i = 1:length(Bc)-1
u = UInt64(1)
c = Bc[i]
for j = 1:64
if c & u != 0
I[Icount] = ind
Icount += 1
end
ind = nextind(B, ind)
u <<= 1
end
end
u = UInt64(1)
c = Bc[end]
for j = 0:_mod64(l-1)
if c & u != 0
I[Icount] = ind
Icount += 1
Bs = size(B)
Bi = i1 = i = 1
irest = ntuple(one, ndims(B) - 1)
c = Bc[1]
@inbounds while true
while c == 0
Bi == length(Bc) && return I
i1 += 64
Bi += 1
c = Bc[Bi]
end
ind = nextind(B, ind)
u <<= 1

tz = trailing_zeros(c)
c = _blsr(c)

i1, irest = _overflowind(i1 + tz, irest, Bs)
I[i] = _toind(i1, irest)
i += 1
i1 -= tz
end
return I
end

# For performance
Expand Down
21 changes: 21 additions & 0 deletions test/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1159,9 +1159,30 @@ timesofar("datamove")
@test findnextnot((.~(b1 >> i)) .⊻ submask, j) == i+1
end

# Do a few more thorough tests for findall
b1 = bitrand(n1, n2)
@check_bit_operation findall(b1) Vector{CartesianIndex{2}}
@check_bit_operation findall(!iszero, b1) Vector{CartesianIndex{2}}

# tall-and-skinny (test index overflow logic in findall)
@check_bit_operation findall(bitrand(1, 1, 1, 250)) Vector{CartesianIndex{4}}

# empty dimensions
@check_bit_operation findall(bitrand(0, 0, 10)) Vector{CartesianIndex{3}}

# sparse (test empty 64-bit chunks in findall)
b1 = falses(8, 8, 8)
b1[3,3,3] = b1[6,6,6] = true
@check_bit_operation findall(b1) Vector{CartesianIndex{3}}

# BitArrays of various dimensions
for dims = 0:8
t = Tuple(fill(2, dims))
ret_type = Vector{dims == 1 ? Int : CartesianIndex{dims}}
@check_bit_operation findall(trues(t)) ret_type
@check_bit_operation findall(falses(t)) ret_type
@check_bit_operation findall(bitrand(t)) ret_type
end
end

timesofar("find")
Expand Down