Skip to content

Commit

Permalink
Enabled linear optimization for logical indexes only if they are the …
Browse files Browse the repository at this point in the history
…only index variable.

`zeros(2,3,4)[2, 1]` is always invalid, thus this optimization is illegal for trailing dimension.
  • Loading branch information
N5N3 committed Jun 30, 2022
1 parent b11ccae commit 2b1fa10
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
18 changes: 11 additions & 7 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -829,18 +829,22 @@ ensure_indexable(I::Tuple{}) = ()
# until Julia gets smart enough to elide the call on its own:
@inline to_indices(A, I::Tuple{Vararg{Union{Integer, CartesianIndex}}}) = to_indices(A, (), I)
# But some index types require more context spanning multiple indices
# CartesianIndex is unfolded outside the inner to_indices for better inference
# `CartesianIndex` is unfolded directly to `Int`s for better inference
_to_indices1(A, inds, I1::CartesianIndex) = map(Fix1(to_index, A), I1.I)
_cutdim(inds, I1::CartesianIndex) = IteratorsMD.split(inds, Val(length(I1)))[2]
# For arrays of CartesianIndex, we just skip the appropriate number of inds
# For arrays of `CartesianIndex`, we just skip the appropriate number of inds
_cutdim(inds, I1::AbstractArray{CartesianIndex{N}}) where {N} = IteratorsMD.split(inds, Val(N))[2]
# And boolean arrays behave similarly; they also skip their number of dimensions
_cutdim(inds::Tuple, I1::AbstractArray{Bool}) = IteratorsMD.split(inds, Val(ndims(I1)))[2]
# As an optimization, we allow trailing Array{Bool} and BitArray to be linear over trailing dimensions
@inline to_indices(A, inds, I::Tuple{Union{Array{Bool,N}, BitArray{N}}}) where {N} =
(_maybe_linear_logical_index(IndexStyle(A), A, I[1]),)
_maybe_linear_logical_index(::IndexStyle, A, i) = to_index(A, i)
_maybe_linear_logical_index(::IndexLinear, A, i) = LogicalIndex{Int}(i)
# As an optimization, we allow the only `AbstractArray{Bool}` to be linear-iterated
@inline to_indices(A, I::Tuple{CartesianIndex{0},Vararg{Union{CartesianIndex{0},AbstractArray{Bool}}}}) = to_indices(A, tail(I))
@inline function to_indices(A, I::Tuple{AbstractArray{Bool},Vararg{CartesianIndex{0}}})
if ndims(A) == ndims(I[1]) && IndexStyle(A) === IndexLinear()
(LogicalIndex{Int}(I[1]),)
else
(to_index(A, I[1]),)
end
end

# Colons get converted to slices by `uncolon`
_to_indices1(A, inds, I1::Colon) = (uncolon(inds),)
Expand Down
20 changes: 18 additions & 2 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,17 @@ function test_vector_indexing(::Type{T}, shape, ::Type{TestAbstractArray}) where

mask = bitrand(shape)
@testset "test logical indexing" begin
let
masks1 = (mask,)
masks2 = ntuple(Returns(CartesianIndex()), 7)
masks2 = Base.setindex(masks2, mask, rand(1:7))
@test only(@inferred(to_indices(A, masks1))) isa Base.LogicalIndex{Int}
@test only(@inferred(to_indices(A, masks2))) isa Base.LogicalIndex{Int}
if IndexStyle(B) isa IndexCartesian
@test only(@inferred(to_indices(B, masks1))) === Base.LogicalIndex(mask)
@test only(@inferred(to_indices(B, masks2))) === Base.LogicalIndex(mask)
end
end
@test B[mask] == A[mask] == B[findall(mask)] == A[findall(mask)] == LinearIndices(mask)[findall(mask)]
@test B[vec(mask)] == A[vec(mask)] == LinearIndices(mask)[findall(mask)]
mask1 = bitrand(size(A, 1))
Expand All @@ -466,10 +477,15 @@ function test_vector_indexing(::Type{T}, shape, ::Type{TestAbstractArray}) where
@test B[mask1, 1, trailing2] == A[mask1, 1, trailing2] == LinearIndices(mask)[findall(mask1)]

if ndims(B) > 1
slice = ntuple(Returns(:), ndims(B)-1)
maskfront = bitrand(shape[1:end-1])
Bslice = B[ntuple(i->(:), ndims(B)-1)..., 1]
@test B[maskfront,1] == Bslice[maskfront]
Bslicefront = B[slice..., 1]
@test B[maskfront, 1] == Bslicefront[maskfront]
@test size(B[maskfront, 1:1]) == (sum(maskfront), 1)
maskend = bitrand(shape[2:end])
Bsliceend = B[1, slice...]
@test B[1 ,maskend] == Bsliceend[maskend]
@test size(B[1:1, maskend]) == (1, sum(maskend))
end
end
end
Expand Down

0 comments on commit 2b1fa10

Please sign in to comment.