Skip to content

Commit

Permalink
Get rid of stail indexing code
Browse files Browse the repository at this point in the history
  • Loading branch information
Tokazama committed Oct 15, 2021
1 parent 9f631ad commit fc95c21
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 134 deletions.
3 changes: 0 additions & 3 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,10 @@ ArrayInterface.safevec
ArrayInterface.setindex!
ArrayInterface.size
ArrayInterface.strides
ArrayInterface.to_axes
ArrayInterface.to_axis
ArrayInterface.to_dims
ArrayInterface.to_index
ArrayInterface.to_indices
ArrayInterface.to_parent_dims
ArrayInterface.unsafe_reconstruct
ArrayInterface.zeromatrix
```

Expand Down
108 changes: 9 additions & 99 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,99 +199,6 @@ end
return LogicalIndex{Int}(arg)
end

"""
unsafe_reconstruct(A, data; kwargs...)
Reconstruct `A` given the values in `data`. New methods using `unsafe_reconstruct`
should only dispatch on `A`.
"""
function unsafe_reconstruct(axis::OneTo, data; kwargs...)
if axis === data
return axis
else
return OneTo(data)
end
end
function unsafe_reconstruct(axis::UnitRange, data; kwargs...)
if axis === data
return axis
else
return UnitRange(first(data), last(data))
end
end
function unsafe_reconstruct(axis::OptionallyStaticUnitRange, data; kwargs...)
if axis === data
return axis
else
return OptionallyStaticUnitRange(static_first(data), static_last(data))
end
end
function unsafe_reconstruct(A::AbstractUnitRange, data; kwargs...)
return static_first(data):static_last(data)
end

"""
to_axes(A, inds) -> Tuple
Construct new axes given the corresponding `inds` constructed after
`to_indices(A, args) -> inds`. This method iterates through each pair of axes and
indices calling [`to_axis`](@ref).
"""
@inline function to_axes(A, inds::Tuple)
if ndims(A) === 1
return (to_axis(axes(A, 1), first(inds)),)
elseif isone(sum(index_ndims(inds)))
return (to_axis(eachindex(IndexLinear(), A), first(inds)),)
else
return to_axes(A, axes(A), inds)
end
end
# drop this dimension
to_axes(A, a::Tuple, i::Tuple{<:Integer,Vararg{Any}}) = to_axes(A, tail(a), tail(i))
to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(index_ndims(I), A, a, i)
function _to_axes(::StaticInt{1}, A, axs::Tuple, inds::Tuple)
return (to_axis(first(axs), first(inds)), to_axes(A, tail(axs), tail(inds))...)
end
@propagate_inbounds function _to_axes(::StaticInt{N}, A, axs::Tuple, inds::Tuple) where {N}
axes_front, axes_tail = Base.IteratorsMD.split(axs, Val(N))
if IndexStyle(A) === IndexLinear()
axis = to_axis(LinearIndices(axes_front), getfield(inds, 1))
else
axis = to_axis(CartesianIndices(axes_front), getfield(inds, 1))
end
return (axis, to_axes(A, axes_tail, tail(inds))...)
end
to_axes(A, ::Tuple{Ax,Vararg{Any}}, ::Tuple{}) where {Ax} = ()
to_axes(A, ::Tuple{}, ::Tuple{}) = ()

"""
to_axis(old_axis, index) -> new_axis
Construct an `new_axis` for a newly constructed array that corresponds to the
previously executed `to_index(old_axis, arg) -> index`. `to_axis` assumes that
`index` has already been confirmed to be in bounds. The underlying indices of
`new_axis` begins at one and extends the length of `index` (i.e., one-based indexing).
"""
@inline function to_axis(axis, inds)
if !can_change_size(axis) &&
(known_length(inds) !== nothing && known_length(axis) === known_length(inds))
return axis
else
return to_axis(IndexStyle(axis), axis, inds)
end
end

# don't need to check size b/c slice means it's the entire axis
@inline function to_axis(axis, inds::Slice)
if can_change_size(axis)
return copy(axis)
else
return axis
end
end
to_axis(S::IndexLinear, axis, inds) = StaticInt(1):static_length(inds)


################
### getindex ###
################
Expand All @@ -300,8 +207,8 @@ to_axis(S::IndexLinear, axis, inds) = StaticInt(1):static_length(inds)
Retrieve the value(s) stored at the given key or index within a collection. Creating
another instance of `ArrayInterface.getindex` should only be done by overloading `A`.
Changing indexing based on a given argument from `args` should be done through,
[`to_index`](@ref), or [`to_axis`](@ref).
Changing indexing based on a given argument from `args` should be done through
[`to_index`](@ref).
"""
@propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args)...)
@propagate_inbounds function getindex(A; kwargs...)
Expand All @@ -326,14 +233,17 @@ end
end

## CartesianIndices/LinearIndices
_ints2range(x::Integer) = x:x
_ints2range(x::AbstractRange) = x
# TODO replace _ints2range with something that actually indexes each axis
_ints2range(::Tuple{}) = ()
@inline _ints2range(x::Tuple{Any,Vararg{Any}}) = (getfield(x, 1), _ints2range(tail(x))...)
@inline _ints2range(x::Tuple{<:Integer,Vararg{Any}}) = _ints2range(tail(x))

unsafe_getindex(A::CartesianIndices, i::Vararg{CanonicalInt,N}) where {N} = @inbounds(A[CartesianIndex(i)])
@inline function unsafe_getindex(A::CartesianIndices{N}, inds::Vararg{Any}) where {N}
if (length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False()
return Base._getindex(IndexStyle(A), A, inds...)
else
return CartesianIndices(to_axes(A, _ints2range.(inds)))
return CartesianIndices(_ints2range(inds))
end
end

Expand All @@ -348,7 +258,7 @@ end
if isone(sum(index_ndims(inds)))
return @inbounds(eachindex(A)[first(inds)])
elseif stride_preserving_index(typeof(inds)) === True()
return LinearIndices(to_axes(A, _ints2range.(inds)))
return LinearIndices(_ints2range(inds))
else
return Base._getindex(IndexStyle(A), A, inds...)
end
Expand Down
32 changes: 0 additions & 32 deletions test/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,6 @@ end
@test_throws ArgumentError ArrayInterface.to_index(axis, error)
end

@testset "unsafe_reconstruct" begin
one_to = Base.OneTo(10)
opt_ur = StaticInt(1):10
ur = 1:10
@test @inferred(ArrayInterface.unsafe_reconstruct(one_to, opt_ur)) === one_to
@test @inferred(ArrayInterface.unsafe_reconstruct(one_to, one_to)) === one_to

@test @inferred(ArrayInterface.unsafe_reconstruct(opt_ur, opt_ur)) === opt_ur
@test @inferred(ArrayInterface.unsafe_reconstruct(opt_ur, one_to)) === opt_ur

@test @inferred(ArrayInterface.unsafe_reconstruct(ur, ur)) === ur
@test @inferred(ArrayInterface.unsafe_reconstruct(ur, one_to)) === ur
end

@testset "to_indices" begin
a = ones(2,2,1)
v = ones(2)
Expand Down Expand Up @@ -82,24 +68,6 @@ end
# FIXME @test_throws ErrorException ArrayInterface.to_indices(ones(2,2,2), (1, 1))
end

@testset "to_axes" begin
A = ones(3, 3)
axis = StaticInt(1):StaticInt(3)
inds = StaticInt(1):StaticInt(2)
multi_inds = [CartesianIndex(1, 1), CartesianIndex(1, 2)]

@test @inferred(ArrayInterface.to_axes(A, (axis, axis), (inds, inds))) === (inds, inds)
# vector indexing
@test @inferred(ArrayInterface.to_axes(ones(3), (axis,), (inds,))) === (inds,)
# linear indexing
@test @inferred(ArrayInterface.to_axes(A, (axis, axis), (inds,))) === (inds,)
# multidim arg
@test @inferred(ArrayInterface.to_axes(A, (axis, axis), (multi_inds,))) === (static(1):2,)

@test ArrayInterface.to_axis(axis, axis) === axis
@test ArrayInterface.to_axis(axis, ArrayInterface.indices(axis)) === axis
end

@testset "0-dimensional" begin
x = Array{Int,0}(undef)
ArrayInterface.setindex!(x, 1)
Expand Down

0 comments on commit fc95c21

Please sign in to comment.