From 44e223b463e358cca518d5a35d21601a0191f61e Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Fri, 1 Oct 2021 17:02:39 -0500 Subject: [PATCH] Solve all problems with extra indirection The `getindex` and `setindex!` methods for `Trues` are limited while also risking ambiguities. This replaces those definitions with a specialization for `to_indices` that avoids such problems. Closes #162 --- src/trues.jl | 31 +++---------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/src/trues.jl b/src/trues.jl index 3381db2d..d7e6508b 100644 --- a/src/trues.jl +++ b/src/trues.jl @@ -22,33 +22,8 @@ const Falses = Zeros{Bool, N, Axes} where {N, Axes} # y[mask] = x when mask isa Trues (cf y[:] = x) -# Supported here only for arrays with standard OneTo axes. -function Base.setindex!(y::AbstractArray{T,N}, x, mask::Trues{N, NTuple{N,Base.OneTo{Int}}}) where {T,N} - if axes(x) isa NTuple{N,Base.OneTo{Int}} && axes(y) isa NTuple{N,Base.OneTo{Int}} - @boundscheck size(y) == size(mask) || throw(BoundsError(y, mask)) - @boundscheck size(x) == size(mask) || throw(DimensionMismatch( - "tried to assign $(length(x)) elements to $(length(y)) destinations")) - @boundscheck checkbounds(y, mask) - return copyto!(y, x) - end - return setindex!(y, x, trues(size(mask))) # fall back on usual setindex! -end - -# x[mask] when mask isa Trues (cf x[trues(size(x))] or x[:]) -# Supported here only for arrays with standard OneTo axes. -function Base.getindex(x::AbstractArray{T,N}, mask::Trues{N, NTuple{N,Base.OneTo{Int}}}) where {T,N} - if axes(x) isa NTuple{N,Base.OneTo{Int}} - @boundscheck size(x) == size(mask) || throw(BoundsError(x, mask)) - return vec(x) - end - return x[trues(size(x))] # else revert to usual getindex method -end -# https://github.com/JuliaArrays/FillArrays.jl/issues/148 and 150 -function Base.getindex( - a::AbstractFill{T, N, Tuple{Vararg{Base.OneTo{Int}, N}}}, - b::Trues{N, Tuple{Vararg{Base.OneTo{Int}, N}}}, -) where {T, N} - @boundscheck size(a) == size(b) || throw(BoundsError(a, b)) - return Fill(getindex_value(a), length(a)) +function Base.to_indices(A::AbstractArray{T,N}, inds, I::Tuple{Trues{N}}) where {T,N} + @boundscheck axes(A) == axes(I[1]) || Base.throw_boundserror(A, I[1]) + (vec(LinearIndices(A)),) end