Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SubIndex, LinearSubIndex, and PermutedIndex types #202

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.33"

Tokazama marked this conversation as resolved.
Show resolved Hide resolved
version = "3.2"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
249 changes: 245 additions & 4 deletions src/array_index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const MatrixIndex = ArrayIndex{2}

const VectorIndex = ArrayIndex{1}

Base.ndims(::ArrayIndex{N}) where {N} = N
Base.ndims(::Type{<:ArrayIndex{N}}) where {N} = N

struct BidiagonalIndex <: MatrixIndex
Expand Down Expand Up @@ -183,6 +184,10 @@ function BandedBlockBandedMatrixIndex(
rowindobj, colindobj
end

Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1
Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count
Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count

"""
StrideIndex(x)

Expand All @@ -204,11 +209,124 @@ struct StrideIndex{N,R,C,S,O} <: ArrayIndex{N}
end
end

Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1
Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count
Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count
"""
PermutedIndex

Subtypes of `ArrayIndex` that is responsible for permuting each index prior to accessing
parent indices.
"""
struct PermutedIndex{N,I1,I2} <: ArrayIndex{N}
Tokazama marked this conversation as resolved.
Show resolved Hide resolved
PermutedIndex{N,I1,I2}() where {N,I1,I2} = new{N,I1,I2}()
function PermutedIndex(p::Tuple{Vararg{StaticInt,N}}, ip::Tuple{Vararg{StaticInt}}) where {N}
PermutedIndex{N,known(p),known(ip)}()
end
end

function Base.getindex(x::PermutedIndex{2,(2,1),(2,)}, i::AbstractCartesianIndex{2})
getfield(Tuple(i), 2)
end
@inline function Base.getindex(x::PermutedIndex{N,I1,I2}, i::AbstractCartesianIndex{N}) where {N,I1,I2}
return NDIndex(permute(Tuple(i), Val(I2)))
end

"""
SubIndex(indices)

Subtype of `ArrayIndex` that provides a multidimensional view of another `ArrayIndex`.
"""
struct SubIndex{N,I} <: ArrayIndex{N}
indices::I

SubIndex{N}(inds::Tuple) where {N} = new{N,typeof(inds)}(inds)
end

@inline function Base.getindex(x::SubIndex{N}, i::AbstractCartesianIndex{N}) where {N}
return NDIndex(_reindex(x.indices, Tuple(i)))
end
@generated function _reindex(subinds::S, inds::I) where {S,I}
inds_i = 1
subinds_i = 1
NS = known_length(S)
NI = known_length(I)
out = Expr(:tuple)
while inds_i <= NI
subinds_type = S.parameters[subinds_i]
if subinds_type <: Integer
push!(out.args, :(getfield(subinds, $subinds_i)))
subinds_i += 1
elseif eltype(subinds_type) <: AbstractCartesianIndex
push!(out.args, :(Tuple(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))...))
inds_i += 1
subinds_i += 1
else
push!(out.args, :(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)])))
inds_i += 1
subinds_i += 1
end
end
if subinds_i <= NS
for i in subinds_i:NS
push!(out.args, :(getfield(subinds, $subinds_i)))
end
end
return Expr(:block, Expr(:meta, :inline), :($out))
end

"""
LinearSubIndex(offset, stride)

Subtype of `ArrayIndex` that provides linear indexing for `Base.FastSubArray` and
`FastContiguousSubArray`.
"""
struct LinearSubIndex{O<:CanonicalInt,S<:CanonicalInt} <: VectorIndex
offset::O
stride::S
end

const OffsetIndex{O} = LinearSubIndex{O,StaticInt{1}}
OffsetIndex(offset::CanonicalInt) = LinearSubIndex(offset, static(1))

@inline function Base.getindex(x::LinearSubIndex, i::CanonicalInt)
getfield(x, :offset) + getfield(x, :stride) * i
end

"""
ComposedIndex(i1, i2)

A subtype of `ArrayIndex` that lazily combines index `i1` and `i2`. Indexing a
`ComposedIndex` whith `i` is equivalent to `i2[i1[i]]`.
"""
struct ComposedIndex{N,I1,I2} <: ArrayIndex{N}
i1::I1
i2::I2

ComposedIndex(i1::I1, i2::I2) where {I1,I2} = new{ndims(I1),I1,I2}(i1, i2)
end
# we should be able to assume that if `i1` was indexed without error than it's inbounds
@propagate_inbounds function Base.getindex(x::ComposedIndex)
ii = getfield(x, :i1)[]
@inbounds(getfield(x, :i2)[ii])
end
@propagate_inbounds function Base.getindex(x::ComposedIndex, i::CanonicalInt)
ii = getfield(x, :i1)[i]
@inbounds(getfield(x, :i2)[ii])
end
@propagate_inbounds function Base.getindex(x::ComposedIndex, i::AbstractCartesianIndex)
ii = getfield(x, :i1)[i]
@inbounds(getfield(x, :i2)[ii])
end

Base.getindex(x::ArrayIndex, i::ArrayIndex) = ComposedIndex(i, x)
@inline function Base.getindex(x::ComposedIndex, i::ArrayIndex)
ComposedIndex(getfield(x, :i1)[i], getfield(x, :i2))
end
@inline function Base.getindex(x::ArrayIndex, i::ComposedIndex)
ComposedIndex(getfield(i, :i1), x[getfield(i, :i2)])
end
@inline function Base.getindex(x::ComposedIndex, i::ComposedIndex)
ComposedIndex(getfield(i, :i1), ComposedIndex(getfield(x, :i1)[getfield(i, :i2)], getfield(x, :i2)))
end

## getindex
@propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)]
@propagate_inbounds function Base.getindex(ind::BidiagonalIndex, i::Int)
@boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i))
Expand Down Expand Up @@ -288,3 +406,126 @@ end
end
return Expr(:block, Expr(:meta, :inline), out)
end

@inline function Base.getindex(x::StrideIndex, i::SubIndex{N,I}) where {N,I}
_composed_sub_strides(stride_preserving_index(I), x, i)
end
_composed_sub_strides(::False, x::StrideIndex, i::SubIndex) = ComposedIndex(i, x)
@inline function _composed_sub_strides(::True, x::StrideIndex{N,R,C}, i::SubIndex{Ns,I}) where {N,R,C,Ns,I<:Tuple{Vararg{Any,N}}}
c = static(C)
if _get_tuple(I, c) <: AbstractUnitRange
c2 = known(getfield(_from_sub_dims(I), C))
elseif (_get_tuple(I, c) <: AbstractArray) && (_get_tuple(I, c) <: Integer)
c2 = -1
else
c2 = nothing
end

pdims = _to_sub_dims(I)
o = offsets(x)
s = strides(x)
inds = getfield(i, :indices)
out = StrideIndex{Ns,permute(R, pdims),c2}(
eachop(getmul, pdims, map(maybe_static_step, inds), s),
permute(o, pdims)
)
return OffsetIndex(reduce_tup(+, map(*, map(_diff, inds, o), s)))[out]
end
@inline _diff(::Base.Slice, ::Any) = Zero()
@inline _diff(x::AbstractRange, o) = static_first(x) - o
@inline _diff(x::Integer, o) = x - o

@inline function Base.getindex(x::StrideIndex{1,R,C}, ::PermutedIndex{2,(2,1),(2,)}) where {R,C}
if C === nothing
c2 = nothing
elseif C === 1
c2 = 2
else
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should add

elseif C === 2
  c2 = 1

Although given the getfield(iperm, C) fallback in the other method, is this somehow not allowed?
Not sure what PermutedIndex is without looking at the code more closely.

c2 = -1
end
s = getfield(strides(x), 1)
return StrideIndex{2,(2,1),c2}((s, s), (static(1), offset1(x)))
end
@inline function Base.getindex(x::StrideIndex{N,R,C}, ::PermutedIndex{N,perm,iperm}) where {N,R,C,perm,iperm}
if C === nothing || C === -1
c2 = C
else
c2 = getfield(iperm, C)
end
return StrideIndex{N,permute(R, Val(perm)),c2}(
permute(strides(x), Val(perm)),
permute(offsets(x), Val(perm)),
)
end
@inline function Base.getindex(x::PermutedIndex, i::PermutedIndex)
PermutedIndex(
permute(to_parent_dims(x), to_parent_dims(i)),
permute(from_parent_dims(x), from_parent_dims(i))
)
end

@inline function Base.getindex(x::LinearSubIndex, i::LinearSubIndex)
s = getfield(x, :stride)
LinearSubIndex(
getfield(x, :offset) + getfield(i, :offset) * s,
getfield(i, :stride) * s
)
end
Base.getindex(::OffsetIndex{StaticInt{0}}, i::StrideIndex) = i


## ArrayIndex constructorrs
@inline _to_cartesian(a) = CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))
@inline function _to_linear(a)
N = ndims(a)
StrideIndex{N,ntuple(+, Val(N)),nothing}(size_to_strides(size(a), static(1)), offsets(a))
end

## DenseArray
"""
ArrayIndex{N}(A) -> index

Constructs a subtype of `ArrayIndex` such that an `N` dimensional indexing argument may be
converted to an appropriate state for accessing the buffer of `A`. For example:

```julia
julia> A = reshape(1:20, 4, 5);

julia> index = ArrayInterface.ArrayIndex{2}(A);

julia> ArrayInterface.buffer(A)[index[2, 2]] == A[2, 2]
true

```
"""
ArrayIndex{N}(x::DenseArray) where {N} = StrideIndex(x)
ArrayIndex{1}(x::DenseArray) = OffsetIndex(static(0))

ArrayIndex{1}(x::ReshapedArray) = OffsetIndex(static(0))
ArrayIndex{N}(x::ReshapedArray) where {N} = _to_linear(x)

ArrayIndex{1}(x::AbstractRange) = OffsetIndex(static(0))

## SubArray
ArrayIndex{N}(x::SubArray) where {N} = SubIndex{ndims(x)}(getfield(x, :indices))
function ArrayIndex{1}(x::SubArray{<:Any,N}) where {N}
ComposedIndex(_to_cartesian(x), SubIndex{N}(getfield(x, :indices)))
end
ArrayIndex{1}(x::Base.FastContiguousSubArray) = OffsetIndex(getfield(x, :offset1))
function ArrayIndex{1}(x::Base.FastSubArray)
LinearSubIndex(getfield(x, :offset1), getfield(x, :stride1))
end

## Permuted arrays
ArrayIndex{2}(::MatAdjTrans) = PermutedIndex{2,(2,1),(2,1)}()
ArrayIndex{2}(::VecAdjTrans) = PermutedIndex{2,(2,1),(2,)}()
ArrayIndex{1}(x::MatAdjTrans) = ComposedIndex(_to_cartesian(x), ArrayIndex{2}(x))
ArrayIndex{1}(x::VecAdjTrans) = OffsetIndex(static(0)) # just unwrap permuting struct
ArrayIndex{1}(::PermutedDimsArray{<:Any,1}) = OffsetIndex(static(0))
function ArrayIndex{N}(::PermutedDimsArray{<:Any,N,perm,iperm}) where {N,perm,iperm}
PermutedIndex{N,perm,iperm}()
end
function ArrayIndex{1}(x::PermutedDimsArray{<:Any,N,perm,iperm}) where {N,perm,iperm}
ComposedIndex(_to_cartesian(x), PermutedIndex{N,perm,iperm}())
end

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because many SubArrays and PermutedDimArrays can be represented by StridedIndex, can we have ArrayIndex for these return a StridedIndex?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way I'm treating ArrayIndex here is a one layer deep index representation. The problem with going straight to StrideIndex is that whatever is calling ArrayIndex(::AbstractArray) doesn't know how deeply nested the buffer is. I figured the user level interface would be more like layout(A, accessor) = ArrayIndex(A), buffer(A) and we could special case for cartesian indexing going straight to StrideIndex.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chriselrod, should this PR have something like the layout method implemented so that it is immediately more useful? I was trying to make a PR that was still functional but didn't create too many new things at once, but I can see how that lack of a fully implemented interface runs the risk of breaking changes in the near future. Alternatively, I could add warnings to the docs that this is experimental, so that we can drop/change stuff without worrying about breaking changes for now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, I could add warnings to the docs that this is experimental, so that we can drop/change stuff without worrying about breaking changes for now.

Sure, although why not have these things live in a separate, experimental repo?
Are they part of the array interface, or something other libraries would want to extend?

It's hard for me to assess changes like these if I don't know the vision for their future use.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully recent commits involving the indexing pipeline make it more clear that these changes allow unique interactions between layers of nested arrays. If we were to have a separate package for this we'd need ArrayInterface.jl to depend on it so that we could complete the indexing interface.

That being said, we don't need to strictly use what I've proposed here. The point of the ArrayInterface.compose methods here is to support a simple rule based system for composing simplified/efficient index transformations. I'm sure will need the new types here, but I'm completely open to there being a more optimal way to doing this.

9 changes: 2 additions & 7 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,9 @@ end

# TODO delete this once the layout interface is working
_array_index(::IndexLinear, a, i::CanonicalInt) = i
@inline function _array_index(::IndexStyle, a, i::CanonicalInt)
CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))[i]
end
@inline _array_index(::IndexStyle, a, i::CanonicalInt) = @inbounds(_to_cartesian(a)[i])
_array_index(::IndexLinear, a, i::AbstractCartesianIndex{1}) = getfield(Tuple(i), 1)
@inline function _array_index(::IndexLinear, a, i::AbstractCartesianIndex)
N = ndims(a)
StrideIndex{N,ntuple(+, Val(N)),nothing}(size_to_strides(size(a), static(1)), offsets(a))[i]
end
@inline _array_index(::IndexLinear, a, i::AbstractCartesianIndex) = _to_linear(a)[i]
_array_index(::IndexStyle, a, i::AbstractCartesianIndex) = i

"""
Expand Down
65 changes: 58 additions & 7 deletions test/array_index.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,65 @@

function test_array_index(x)
@testset "$x" begin
linear_idx = @inferred(ArrayInterface.ArrayIndex{1}(x))
b = ArrayInterface.buffer(x)
for i in eachindex(IndexLinear(), x)
@test b[linear_idx[i]] == x[i]
end
cartesian_idx = @inferred(ArrayInterface.ArrayIndex{ndims(x)}(x))
for i in eachindex(IndexCartesian(), x)
@test b[cartesian_idx[i]] == x[i]
end
end
end

A = zeros(3, 4, 5);
A[:] = 1:60
Ap = @view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])';
A[:] = 1:60;
Aperm = PermutedDimsArray(A,(3,1,2));
Aview = @view(Aperm[:,1:2,1]);
Ap = Aview';
Apperm = PermutedDimsArray(Ap, (2, 1));

ap_index = ArrayInterface.StrideIndex(Ap)
for x_i in axes(Ap, 1)
for y_i in axes(Ap, 2)
@test ap_index[x_i, y_i] == ap_index[x_i, y_i]
end
test_array_index(A)
test_array_index(Aperm)
test_array_index(Aview)
test_array_index(Ap)
test_array_index(view(A, :, :, 1)) # FastContiguousSubArray
test_array_index(view(A, 2, :, :)) # FastSubArray

idx = @inferred(ArrayInterface.ArrayIndex{3}(A)[ArrayInterface.ArrayIndex{3}(Aperm)])
for i in eachindex(IndexCartesian(), Aperm)
@test A[idx[i]] == Aperm[i]
end
idx = @inferred(idx[ArrayInterface.ArrayIndex{2}(Aview)])
for i in eachindex(IndexCartesian(), Aview)
@test A[idx[i]] == Aview[i]
end

idx_perm = @inferred(ArrayInterface.ArrayIndex{2}(Ap)[ArrayInterface.ArrayIndex{2}(Apperm)])
idx = @inferred(idx[idx_perm])
for i in eachindex(IndexCartesian(), Apperm)
@test A[idx[i]] == Apperm[i]
end

v = Vector{Int}(undef, 4);
vp = v'
vnot = @inferred(ArrayInterface.ArrayIndex{1}(v))
vidx = @inferred(vnot[ArrayInterface.StrideIndex(v)])
@test @inferred(vidx[ArrayInterface.ArrayIndex{2}(vp)]) isa ArrayInterface.StrideIndex{2,(2,1)}


idx = @inferred(ArrayInterface.ArrayIndex{1}(1:2))
@test idx[@inferred(ArrayInterface.ArrayIndex{1}((1:2)'))] isa ArrayInterface.OffsetIndex{StaticInt{0}}
@test @inferred(ArrayInterface.ArrayIndex{2}((1:2)'))[CartesianIndex(1, 2)] == 2
@test @inferred(ArrayInterface.ArrayIndex{1}(1:2)) isa ArrayInterface.OffsetIndex{StaticInt{0}}
@test @inferred(ArrayInterface.ArrayIndex{1}((1:2)')) isa ArrayInterface.OffsetIndex{StaticInt{0}}
@test @inferred(ArrayInterface.ArrayIndex{1}(PermutedDimsArray(1:2, (1,)))) isa ArrayInterface.OffsetIndex{StaticInt{0}}
@test @inferred(ArrayInterface.ArrayIndex{1}(reshape(1:10, 2, 5))) isa ArrayInterface.OffsetIndex{StaticInt{0}}
@test @inferred(ArrayInterface.ArrayIndex{2}(reshape(1:10, 2, 5))) isa ArrayInterface.StrideIndex

ap_index = ArrayInterface.StrideIndex(Ap)
@test @inferred(ndims(ap_index)) == ndims(Ap)
@test @inferred(ArrayInterface.known_offsets(ap_index)) === ArrayInterface.known_offsets(Ap)
@test @inferred(ArrayInterface.known_offset1(ap_index)) === ArrayInterface.known_offset1(Ap)
@test @inferred(ArrayInterface.offsets(ap_index, 1)) === ArrayInterface.offset1(Ap)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ end
end
end

@testset "" begin
@testset "ArrayIndex" begin
include("array_index.jl")
end

Expand Down