From 2b1fa102ba420d67ab86a21ca132be8e972f7e2c Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Thu, 30 Jun 2022 09:30:07 +0800 Subject: [PATCH] Enabled linear optimization for logical indexes only if they are the only index variable. `zeros(2,3,4)[2, 1]` is always invalid, thus this optimization is illegal for trailing dimension. --- base/multidimensional.jl | 18 +++++++++++------- test/abstractarray.jl | 20 ++++++++++++++++++-- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/base/multidimensional.jl b/base/multidimensional.jl index 3eecdf17e53181..d7f00bf2e3056e 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -829,18 +829,22 @@ ensure_indexable(I::Tuple{}) = () # until Julia gets smart enough to elide the call on its own: @inline to_indices(A, I::Tuple{Vararg{Union{Integer, CartesianIndex}}}) = to_indices(A, (), I) # But some index types require more context spanning multiple indices -# CartesianIndex is unfolded outside the inner to_indices for better inference +# `CartesianIndex` is unfolded directly to `Int`s for better inference _to_indices1(A, inds, I1::CartesianIndex) = map(Fix1(to_index, A), I1.I) _cutdim(inds, I1::CartesianIndex) = IteratorsMD.split(inds, Val(length(I1)))[2] -# For arrays of CartesianIndex, we just skip the appropriate number of inds +# For arrays of `CartesianIndex`, we just skip the appropriate number of inds _cutdim(inds, I1::AbstractArray{CartesianIndex{N}}) where {N} = IteratorsMD.split(inds, Val(N))[2] # And boolean arrays behave similarly; they also skip their number of dimensions _cutdim(inds::Tuple, I1::AbstractArray{Bool}) = IteratorsMD.split(inds, Val(ndims(I1)))[2] -# As an optimization, we allow trailing Array{Bool} and BitArray to be linear over trailing dimensions -@inline to_indices(A, inds, I::Tuple{Union{Array{Bool,N}, BitArray{N}}}) where {N} = - (_maybe_linear_logical_index(IndexStyle(A), A, I[1]),) -_maybe_linear_logical_index(::IndexStyle, A, i) = to_index(A, i) -_maybe_linear_logical_index(::IndexLinear, A, i) = LogicalIndex{Int}(i) +# As an optimization, we allow the only `AbstractArray{Bool}` to be linear-iterated +@inline to_indices(A, I::Tuple{CartesianIndex{0},Vararg{Union{CartesianIndex{0},AbstractArray{Bool}}}}) = to_indices(A, tail(I)) +@inline function to_indices(A, I::Tuple{AbstractArray{Bool},Vararg{CartesianIndex{0}}}) + if ndims(A) == ndims(I[1]) && IndexStyle(A) === IndexLinear() + (LogicalIndex{Int}(I[1]),) + else + (to_index(A, I[1]),) + end +end # Colons get converted to slices by `uncolon` _to_indices1(A, inds, I1::Colon) = (uncolon(inds),) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 111e2cabbe7c2c..5f0af49948252b 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -457,6 +457,17 @@ function test_vector_indexing(::Type{T}, shape, ::Type{TestAbstractArray}) where mask = bitrand(shape) @testset "test logical indexing" begin + let + masks1 = (mask,) + masks2 = ntuple(Returns(CartesianIndex()), 7) + masks2 = Base.setindex(masks2, mask, rand(1:7)) + @test only(@inferred(to_indices(A, masks1))) isa Base.LogicalIndex{Int} + @test only(@inferred(to_indices(A, masks2))) isa Base.LogicalIndex{Int} + if IndexStyle(B) isa IndexCartesian + @test only(@inferred(to_indices(B, masks1))) === Base.LogicalIndex(mask) + @test only(@inferred(to_indices(B, masks2))) === Base.LogicalIndex(mask) + end + end @test B[mask] == A[mask] == B[findall(mask)] == A[findall(mask)] == LinearIndices(mask)[findall(mask)] @test B[vec(mask)] == A[vec(mask)] == LinearIndices(mask)[findall(mask)] mask1 = bitrand(size(A, 1)) @@ -466,10 +477,15 @@ function test_vector_indexing(::Type{T}, shape, ::Type{TestAbstractArray}) where @test B[mask1, 1, trailing2] == A[mask1, 1, trailing2] == LinearIndices(mask)[findall(mask1)] if ndims(B) > 1 + slice = ntuple(Returns(:), ndims(B)-1) maskfront = bitrand(shape[1:end-1]) - Bslice = B[ntuple(i->(:), ndims(B)-1)..., 1] - @test B[maskfront,1] == Bslice[maskfront] + Bslicefront = B[slice..., 1] + @test B[maskfront, 1] == Bslicefront[maskfront] @test size(B[maskfront, 1:1]) == (sum(maskfront), 1) + maskend = bitrand(shape[2:end]) + Bsliceend = B[1, slice...] + @test B[1 ,maskend] == Bsliceend[maskend] + @test size(B[1:1, maskend]) == (1, sum(maskend)) end end end