Skip to content

Commit

Permalink
Solve all problems with extra indirection
Browse files Browse the repository at this point in the history
The `getindex` and `setindex!` methods for `Trues` are limited
while also risking  ambiguities. This replaces those definitions
with a specialization for `to_indices` that avoids such problems.

Closes #162
  • Loading branch information
timholy committed Oct 1, 2021
1 parent a8bffd3 commit 44e223b
Showing 1 changed file with 3 additions and 28 deletions.
31 changes: 3 additions & 28 deletions src/trues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,8 @@ const Falses = Zeros{Bool, N, Axes} where {N, Axes}


# y[mask] = x when mask isa Trues (cf y[:] = x)
# Supported here only for arrays with standard OneTo axes.
function Base.setindex!(y::AbstractArray{T,N}, x, mask::Trues{N, NTuple{N,Base.OneTo{Int}}}) where {T,N}
if axes(x) isa NTuple{N,Base.OneTo{Int}} && axes(y) isa NTuple{N,Base.OneTo{Int}}
@boundscheck size(y) == size(mask) || throw(BoundsError(y, mask))
@boundscheck size(x) == size(mask) || throw(DimensionMismatch(
"tried to assign $(length(x)) elements to $(length(y)) destinations"))
@boundscheck checkbounds(y, mask)
return copyto!(y, x)
end
return setindex!(y, x, trues(size(mask))) # fall back on usual setindex!
end

# x[mask] when mask isa Trues (cf x[trues(size(x))] or x[:])
# Supported here only for arrays with standard OneTo axes.
function Base.getindex(x::AbstractArray{T,N}, mask::Trues{N, NTuple{N,Base.OneTo{Int}}}) where {T,N}
if axes(x) isa NTuple{N,Base.OneTo{Int}}
@boundscheck size(x) == size(mask) || throw(BoundsError(x, mask))
return vec(x)
end
return x[trues(size(x))] # else revert to usual getindex method
end

# https://github.com/JuliaArrays/FillArrays.jl/issues/148 and 150
function Base.getindex(
a::AbstractFill{T, N, Tuple{Vararg{Base.OneTo{Int}, N}}},
b::Trues{N, Tuple{Vararg{Base.OneTo{Int}, N}}},
) where {T, N}
@boundscheck size(a) == size(b) || throw(BoundsError(a, b))
return Fill(getindex_value(a), length(a))
function Base.to_indices(A::AbstractArray{T,N}, inds, I::Tuple{Trues{N}}) where {T,N}
@boundscheck axes(A) == axes(I[1]) || Base.throw_boundserror(A, I[1])
(vec(LinearIndices(A)),)
end

0 comments on commit 44e223b

Please sign in to comment.