Skip to content

Commit

Permalink
Optimize findall(f, ::AbstractArray{Bool})
Browse files Browse the repository at this point in the history
* Take shortcuts if f(::Bool) always returns true or false
* Avoid branching in main loop to please branch predictor
* Switch to indexing-agnostic code
* Fix regression mentioned in #42187
  • Loading branch information
jakobnissen committed Sep 12, 2021
1 parent 211ed19 commit 8340cc9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
40 changes: 33 additions & 7 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2336,20 +2336,46 @@ Int64[]
function findall(A)
collect(first(p) for p in pairs(A) if last(p))
end

# Allocating result upfront is faster (possible only when collection can be iterated twice)
function findall(A::AbstractArray{Bool})
n = count(A)
function findall(f::Function, A::AbstractArray{Bool})
# Compute f for true and false only once
ft, ff = f(true), f(false)
(ft | ff) || return Vector{eltype(keys(A))}()
(ft & ff) && return vec(Array(keys(A)))
n = let
c = count(A)
ft ? c : length(A) - c
end
I = Vector{eltype(keys(A))}(undef, n)
_findall(ff, I, A)
end

function _findall(invert::Bool, I::Vector, A::AbstractArray{Bool})
cnt = 1
for (i,a) in pairs(A)
if a
I[cnt] = i
cnt += 1
end
len = length(I)
for (k, v) in pairs(A)
cnt > len && break
I[cnt] = k
cnt += v invert
end
I
end

function _findall(invert::Bool, I::Vector, A::AbstractVector{Bool})
i = firstindex(A)
cnt = 1
len = length(I)
@inbounds while cnt len
I[cnt] = i
cnt += A[i] invert
i = nextind(A, i)
end
I
end

findall(A::AbstractArray{Bool}) = findall(identity, A)

findall(x::Bool) = x ? [1] : Vector{Int}()
findall(testf::Function, x::Number) = testf(x) ? [1] : Vector{Int}()
findall(p::Fix2{typeof(in)}, x::Number) = x in p.x ? [1] : Vector{Int}()
Expand Down
8 changes: 8 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -547,9 +547,17 @@ end

@testset "findall, findfirst, findnext, findlast, findprev" begin
a = [0,1,2,3,0,1,2,3]
m = [false false; true false]
@test findall(!iszero, a) == [2,3,4,6,7,8]
@test findall(a.==2) == [3,7]
@test findall(isodd,a) == [2,4,6,8]
@test findall(Bool[]) == Int[]
@test findall([false, false]) == Int[]
@test findall(m) == [k for (k,v) in pairs(m) if v]
@test findall(!, [false, true, true]) == [1]
@test findall(i -> true, [false, true, false]) == [1, 2, 3]
@test findall(i -> false, rand(2, 2)) == Int[]
@test findall(!, m) == [k for (k,v) in pairs(m) if !v]
@test findfirst(!iszero, a) == 2
@test findfirst(a.==0) == 1
@test findfirst(a.==5) == nothing
Expand Down

0 comments on commit 8340cc9

Please sign in to comment.