From 6de3ecf2e17f65b284318a9e6570f03d2f3b5f8e Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 5 Sep 2021 17:47:35 -0400 Subject: [PATCH 01/16] Some basic subtypes of ArrayIndex --- src/array_index.jl | 153 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 149 insertions(+), 4 deletions(-) diff --git a/src/array_index.jl b/src/array_index.jl index 3e1ce456f..f768ba06e 100644 --- a/src/array_index.jl +++ b/src/array_index.jl @@ -183,6 +183,10 @@ function BandedBlockBandedMatrixIndex( rowindobj, colindobj end +Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1 +Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count +Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count + """ StrideIndex(x) @@ -204,11 +208,122 @@ struct StrideIndex{N,R,C,S,O} <: ArrayIndex{N} end end -Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1 -Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count -Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count +""" + PermutedIndex + +Subtypes of `ArrayIndex` that is responsible for permuting each index prior to accessing +parent indices. +""" +struct PermutedIndex{N,I1,I2} <: ArrayIndex{N} + PermutedIndex{N,I1,I2}() where {N,I1,I2} = new{N,I1::NTuple{N,Int},I2::NTuple{N,Int}}() + # the only time we permit the inverse permutation to not have the same length as the permutation + PermutedIndex{2,(2,1),(2,)}() = new{2,(2,1),(2,)}() + + function PermutedIndex(perm::Tuple{Vararg{StaticInt,N}}, iperm::Tuple{Vararg{StaticInt}}) where {N} + PermutedIndex{N,known(perm),known(iperm)}() + end + PermutedIndex(::A) where {A} = PermutedIndex(to_parent_dims(A), from_parent_dims(A)) +end + +""" + SubIndex(indices) + +Subtype of `ArrayIndex` that provides a multidimensional view of another `ArrayIndex`. +""" +struct SubIndex{N,I} <: ArrayIndex{N} + indices::I + + SubIndex{N}(inds::Tuple) where {N} = new{N,typeof(inds)}(inds) + SubIndex(x::SubArray{T,N}) where {T,N} = SubIndex{N}(getfield(x, :indices)) +end + +@inline function Base.getindex(x::SubIndex{N}, i::AbstractCartesianIndex{N}) where {N} + return NDIndex(_reindex(x.indices, Tuple(i))) +end +@generated function _reindex(subinds::S, inds::I) where {S,I} + inds_i = 1 + subinds_i = 1 + NS = known_length(S) + NI = known_length(I) + out = Expr(:tuple) + while inds_i <= NI + subinds_type = S.parameters[subinds_i] + if subinds_type <: Integer + push!(out.args, :(getfield(subinds, $subinds_i))) + subinds_i += 1 + elseif eltype(subinds_type) <: AbstractCartesianIndex + push!(out.args, :(Tuple(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))...)) + inds_i += 1 + subinds_i += 1 + else + push!(out.args, :(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))) + inds_i += 1 + subinds_i += 1 + end + end + if subinds_i <= NS + for i in subinds_i:NS + push!(out.args, :(getfield(subinds, $subinds_i))) + end + end + return Expr(:block, Expr(:meta, :inline), :($out)) +end + +""" + LinearSubIndex(offset, stride) + +Subtype of `ArrayIndex` that provides linear indexing for `Base.FastSubArray` and +`FastContiguousSubArray`. +""" +struct LinearSubIndex{O<:CanonicalInt,S<:CanonicalInt} <: VectorIndex + offset::O + stride::S +end + +const OffsetIndex{O} = LinearSubIndex{O,StaticInt{1}} +OffsetIndex(offset::CanonicalInt) = LinearSubIndex(offset, static(1)) + +@inline function Base.getindex(x::LinearSubIndex, i::CanonicalInt) + getfield(x, :offset) + getfield(x, :stride) * i +end + +""" + ComposedIndex(i1, i2) + +A subtype of `ArrayIndex` that lazily combines index `i1` and `i2`. Indexing a +`ComposedIndex` whith `i` is equivalent to `i2[i1[i]]`. +""" +struct ComposedIndex{N,I1,I2} <: ArrayIndex{N} + i1::I1 + i2::I2 + + ComposedIndex(i1::I1, i2::I2) where {I1,I2} = new{ndims(I1),I1,I2}(i1, i2) +end +# we should be able to assume that if `i1` was indexed without error than it's inbounds +@propagate_inbounds function Base.getindex(x::ComposedIndex) + ii = getfield(x, :i1)[] + @inbounds(getfield(x, :i2)[ii]) +end +@propagate_inbounds function Base.getindex(x::ComposedIndex, i::CanonicalInt) + ii = getfield(x, :i1)[i] + @inbounds(getfield(x, :i2)[ii]) +end +@propagate_inbounds function Base.getindex(x::ComposedIndex, i::AbstractCartesianIndex) + ii = getfield(x, :i1)[i] + @inbounds(getfield(x, :i2)[ii]) +end + +Base.getindex(x::ArrayIndex, i::ArrayIndex) = ComposedIndex(i, x) +@inline function Base.getindex(x::ComposedIndex, i::ArrayIndex) + ComposedIndex(getfield(x, :i1)[i], getfield(x, :i2)) +end +@inline function Base.getindex(x::ArrayIndex, i::ComposedIndex) + ComposedIndex(getfield(i, :i1), x[getfield(i, :i2)]) +end +@inline function Base.getindex(x::ComposedIndex, i::ComposedIndex) + ComposedIndex(getfield(i, :i1), ComposedIndex(getfield(x, :i1)[getfield(i, :i2)], getfield(x, :i2))) +end -## getindex @propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)] @propagate_inbounds function Base.getindex(ind::BidiagonalIndex, i::Int) @boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i)) @@ -288,3 +403,33 @@ end end return Expr(:block, Expr(:meta, :inline), out) end + +@inline function Base.getindex(x::StrideIndex{1,R,C}, ::PermutedIndex{2,(2,1),(2,)}) where {R,C} + if C === nothing + c2 = nothing + elseif C === 1 + c2 = 2 + else + c2 = -1 + end + s = getfield(strides(x), 1) + return StrideIndex{2,(2,1),c2}((s, s), (static(1), offset1(x))) +end +@inline function Base.getindex(x::StrideIndex{N,R,C}, ::PermutedIndex{N,perm,iperm}) where {N,R,C,perm,iperm} + if C === nothing || C === -1 + c2 = C + else + c2 = getfield(iperm, C) + end + return StrideIndex{N,permute(R, Val(perm)),c2}( + permute(strides(x), Val(perm)), + permute(offsets(x), Val(perm)), + ) +end +@inline function Base.getindex(x::PermutedIndex, i::PermutedIndex) + PermutedIndex( + permute(to_parent_dims(x), to_parent_dims(i)), + permute(from_parent_dims(x), from_parent_dims(i)) + ) +end + From 2d5c2601516d2e67d9d2b1a0800fdd2868865db9 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 5 Sep 2021 20:10:22 -0400 Subject: [PATCH 02/16] Add tests --- src/array_index.jl | 83 +++++++++++++++-- src/indexing.jl | 9 +- test/array_index.jl | 210 ++++++++++++++++++++++++++++++++++++++++++-- test/runtests.jl | 2 +- 4 files changed, 280 insertions(+), 24 deletions(-) diff --git a/src/array_index.jl b/src/array_index.jl index f768ba06e..1f68f9c1d 100644 --- a/src/array_index.jl +++ b/src/array_index.jl @@ -214,15 +214,13 @@ end Subtypes of `ArrayIndex` that is responsible for permuting each index prior to accessing parent indices. """ -struct PermutedIndex{N,I1,I2} <: ArrayIndex{N} - PermutedIndex{N,I1,I2}() where {N,I1,I2} = new{N,I1::NTuple{N,Int},I2::NTuple{N,Int}}() - # the only time we permit the inverse permutation to not have the same length as the permutation - PermutedIndex{2,(2,1),(2,)}() = new{2,(2,1),(2,)}() +struct PermutedIndex{N,I1,I2} <: ArrayIndex{N} end - function PermutedIndex(perm::Tuple{Vararg{StaticInt,N}}, iperm::Tuple{Vararg{StaticInt}}) where {N} - PermutedIndex{N,known(perm),known(iperm)}() - end - PermutedIndex(::A) where {A} = PermutedIndex(to_parent_dims(A), from_parent_dims(A)) +function Base.getindex(x::PermutedIndex{2,(2,1),(2,)}, i::AbstractCartesianIndex{2}) + getfield(Tuple(i), 2) +end +@inline function Base.getindex(x::PermutedIndex{N,I1,I2}, i::AbstractCartesianIndex{N}) where {N,I1,I2} + return NDIndex(permute(Tuple(i), Val(I2))) end """ @@ -234,7 +232,6 @@ struct SubIndex{N,I} <: ArrayIndex{N} indices::I SubIndex{N}(inds::Tuple) where {N} = new{N,typeof(inds)}(inds) - SubIndex(x::SubArray{T,N}) where {T,N} = SubIndex{N}(getfield(x, :indices)) end @inline function Base.getindex(x::SubIndex{N}, i::AbstractCartesianIndex{N}) where {N} @@ -404,6 +401,34 @@ end return Expr(:block, Expr(:meta, :inline), out) end +@inline function Base.getindex(x::StrideIndex, i::SubIndex{N,I}) where {N,I} + _composed_sub_strides(stride_preserving_index(I), x, i) +end +_composed_sub_strides(::False, x::StrideIndex, i::SubIndex) = ComposedIndex(i, x) +@inline function _composed_sub_strides(::True, x::StrideIndex{N,R,C}, i::SubIndex{Ns,I}) where {N,R,C,Ns,I<:Tuple{Vararg{Any,N}}} + c = static(C) + if _get_tuple(I, c) <: AbstractUnitRange + c2 = known(getfield(_from_sub_dims(I), C)) + elseif (_get_tuple(I, c) <: AbstractArray) && (_get_tuple(I, c) <: Integer) + c2 = -1 + else + c2 = nothing + end + + pdims = _to_sub_dims(I) + o = offsets(x) + s = strides(x) + inds = getfield(i, :indices) + out = StrideIndex{Ns,permute(R, pdims),c2}( + eachop(getmul, pdims, map(maybe_static_step, inds), s), + permute(o, pdims) + ) + return OffsetIndex(reduce_tup(+, map(*, map(_diff, inds, o), s)))[out] +end +@inline _diff(::Base.Slice, ::Any) = Zero() +@inline _diff(x::AbstractRange, o) = static_first(x) - o +@inline _diff(x::Integer, o) = x - o + @inline function Base.getindex(x::StrideIndex{1,R,C}, ::PermutedIndex{2,(2,1),(2,)}) where {R,C} if C === nothing c2 = nothing @@ -433,3 +458,43 @@ end ) end +## ArrayIndex constructorrs +@inline _to_cartesian(a) = CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a)))) +@inline function _to_linear(a) + N = ndims(a) + StrideIndex{N,ntuple(+, Val(N)),nothing}(size_to_strides(size(a), static(1)), offsets(a)) +end + +## DenseArray +ArrayIndex{N}(x::DenseArray) where {N} = StrideIndex(x) +ArrayIndex{1}(x::DenseArray) = OffsetIndex(static(0)) + +ArrayIndex{1}(x::ReshapedArray) = OffsetIndex(static(0)) +ArrayIndex{N}(x::ReshapedArray) where {N} = _to_linear(x) + +## SubArray +ArrayIndex{N}(x::SubArray) where {N} = SubIndex{ndims(x)}(getfield(x, :indices)) +function ArrayIndex{1}(x::SubArray{<:Any,N}) where {N} + ComposedIndex(_to_cartesian(x), SubIndex{N}(getfield(x, :indices))) +end +ArrayIndex{1}(x::Base.FastContiguousSubArray) = OffsetIndex(getfield(x, :offset1)) +function ArrayIndex{1}(x::Base.FastSubArray) + LinearSubIndex(getfield(x, :offset1), getfield(x, :stride1)) +end + +## Permuted arrays +ArrayIndex{2}(::MatAdjTrans) = PermutedIndex{2,(2,1),(2,1)}() +ArrayIndex{2}(::VecAdjTrans) = PermutedIndex{2,(2,1),(2,)}() +ArrayIndex{1}(x::MatAdjTrans) = ComposedIndex(_to_cartesian(x), ArrayIndex{2}(x)) +ArrayIndex{1}(x::VecAdjTrans) = OffsetIndex(static(0)) # jus unwrap permuting struct + +function ArrayIndex{N}(::PermutedDimsArray{<:Any,N,perm,iperm}) where {N,perm,iperm} + PermutedIndex{N,perm,iperm}() +end +function ArrayIndex{1}(::PermutedDimsArray{<:Any,1,perm,iperm}) where {perm,iperm} + OffsetIndex(static(0)) +end +function ArrayIndex{1}(x::PermutedDimsArray{<:Any,N,perm,iperm}) where {N,perm,iperm} + ComposedIndex(_to_cartesian(x), PermutedIndex{N,perm,iperm}()) +end + diff --git a/src/indexing.jl b/src/indexing.jl index 11d0ab979..79e6b31d0 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -202,14 +202,9 @@ end # TODO delete this once the layout interface is working _array_index(::IndexLinear, a, i::CanonicalInt) = i -@inline function _array_index(::IndexStyle, a, i::CanonicalInt) - CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))[i] -end +@inline _array_index(::IndexStyle, a, i::CanonicalInt) = @inbounds(_to_cartesian(a)[i]) _array_index(::IndexLinear, a, i::AbstractCartesianIndex{1}) = getfield(Tuple(i), 1) -@inline function _array_index(::IndexLinear, a, i::AbstractCartesianIndex) - N = ndims(a) - StrideIndex{N,ntuple(+, Val(N)),nothing}(size_to_strides(size(a), static(1)), offsets(a))[i] -end +@inline _array_index(::IndexLinear, a, i::AbstractCartesianIndex) = _to_linear(a)[i] _array_index(::IndexStyle, a, i::AbstractCartesianIndex) = i """ diff --git a/test/array_index.jl b/test/array_index.jl index f59b70d73..6101bc8bc 100644 --- a/test/array_index.jl +++ b/test/array_index.jl @@ -1,14 +1,36 @@ -A = zeros(3, 4, 5); -A[:] = 1:60 -Ap = @view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])'; -ap_index = ArrayInterface.StrideIndex(Ap) -for x_i in axes(Ap, 1) - for y_i in axes(Ap, 2) - @test ap_index[x_i, y_i] == ap_index[x_i, y_i] +function test_array_index(x) + @testset "$x" begin + linear_idx = ArrayInterface.ArrayIndex{1}(x) + b = ArrayInterface.buffer(x) + for i in eachindex(IndexLinear(), x) + @test b[linear_idx[i]] == x[i] + end + cartesian_idx = ArrayInterface.ArrayIndex{ndims(x)}(x) + for i in eachindex(IndexCartesian(), x) + @test b[cartesian_idx[i]] == x[i] + end end end + + +A = zeros(3, 4, 5); +A[:] = 1:60; +Aperm = PermutedDimsArray(A,(3,1,2)); +Aview = @view(Aperm[:,1:2,1]); +Ap = Aview'; + +#ArrayInterface.ArrayIndex{1}(x) + +test_array_index(A) +test_array_index(Aperm) +test_array_index(Aview) +test_array_index(Ap) +test_array_index(view(A, :, :, 1)) # FastContiguousSubArray +test_array_index(view(A, 2, :, :)) # FastSubArray + +ap_index = ArrayInterface.StrideIndex(Ap) @test @inferred(ArrayInterface.known_offsets(ap_index)) === ArrayInterface.known_offsets(Ap) @test @inferred(ArrayInterface.known_offset1(ap_index)) === ArrayInterface.known_offset1(Ap) @test @inferred(ArrayInterface.offsets(ap_index, 1)) === ArrayInterface.offset1(Ap) @@ -18,3 +40,177 @@ end @test @inferred(ArrayInterface.contiguous_axis(ArrayInterface.StrideIndex{2,(1,2),nothing,NTuple{2,Int},NTuple{2,Int}})) == nothing @test @inferred(ArrayInterface.stride_rank(ap_index)) == (1, 3) + +#= +using Revise +using Pkg +Pkg.activate(".") +using ArrayInterface +using ArrayInterface: buffer, array_index, LinearAccess, CartesianAccess + +function test_layouts(x) + index = ArrayInterface.array_index(x, LinearAccess()) + for i in eachindex(IndexLinear(), x) + @test buffer(x)[index[i]] == x[i] + end + index = ArrayInterface.array_index(x, CartesianAccess()) + for i in eachindex(IndexCartesian(), x) + @test buffer(x)[index[i]] == x[i] + end + + lyt = ArrayInterface.layout(x, LinearAccess()) + for i in eachindex(IndexLinear(), x) + @test lyt[i] == x[i] + end + + lyt = ArrayInterface.layout(x, CartesianAccess()) + for i in eachindex(IndexCartesian(), x) + @test lyt[i] == x[i] + end +end + +A = zeros(Int, 3, 4, 5); +A[:] = 1:60; +Aperm = PermutedDimsArray(A, (3,1,2)); +Asub = @view(Aperm[:,1:2,1]); +Ap = Asub'; + +test_layouts(A) +test_layouts(Aperm) +test_layouts(Asub) +test_layouts(Ap) +test_layouts(view(A, :, :, 1)) # FastContiguousSubArray +test_layouts(view(A, 2, :, 1)) # FastSubArray + + + +lyt = ArrayInterface.layout(view(A, 2, :, 1), LinearAccess()) +for i in eachindex(IndexLinear(), x) + @test lyt[i] == x[i] +end + +function base_add(x) + out = zero(eltype(x)) + @inbounds for i in eachindex(IndexCartesian(), x) + out += x[i] + end + return out +end + +function layout_add(x) + out = zero(eltype(x)) + lyt = ArrayInterface.layout(x, ArrayInterface.CartesianAccess()) + @inbounds for i in eachindex(IndexCartesian(), x) + out += lyt[i] + end + return out +end + + +@btime base_add($Ap) + +@btime layout_add($Ap) + + +#= + +lyt = ArrayInterface.layout(A, CartesianAccess()) +lyt = ArrayInterface.layout(Ap, CartesianAccess()) + +lyt = ArrayInterface.layout(Ap, LinearAccess()) +for i in eachindex(IndexCartesian(), Ap) + @test lyt[i] == Ap[i] +end + +@testset "FastContiguousSubArray" begin + test_array_index(view(A, :, :, 1)) +end +@testset "FastSubArray" begin + test_array_index(view(A, 2, :, 1)) +end + +A = zeros(Int, 3, 4, 5); +A[:] = 1:60; +Aperm = PermutedDimsArray(A, (3,1,2)); +Asub = @view(Aperm[:,1:2,1]); +Ap = Asub'; + +test_layout(Asub) + +i1 = ArrayInterface.array_index(Aperm, LinearAccess()) +i2 = ArrayInterface.array_index(parent(Aperm), CartesianAccess()) +i1[i2] + +i = ArrayInterface.array_index(Aperm, LinearAccess()) +lyt[i] + lyt = ArrayInterface.layout(Aperm, CartesianAccess()) + for i in eachindex(IndexCartesian(), Aperm) + @test lyt[i] == Aperm[i] + end + +lyt = ArrayInterface.layout(A, CartesianAccess()) +lyt = ArrayInterface.layout(Aperm, CartesianAccess()) +lyt = ArrayInterface.layout(Asub, CartesianAccess()) +lyt = ArrayInterface.layout(Ap, CartesianAccess()) + +test_array_index(A) +test_array_index(Aperm) +test_array_index(Asub) +test_array_index(Ap) + +test_layout(A) +test_layout(Aperm) +test_layout(Asub) +test_layout(Ap) +=# + +#= +Asub = view(A, 2, :, 1); +index = array_index(Asub, LinearAccess()) +=# + +# SubArray +Aview = view(A, 2, :, 1); +index = ArrayInterface.SubIndex(Aview) +shaped = ArrayInterface.ShapedIndex(A) +for i in eachindex(Aview) + @test shaped[index[i]] == Aview[i] +end + +stride_index = ArrayInterface.StrideIndex(A) +Aperm = PermutedDimsArray(A,(3,1,2)) +perm_index = ArrayInterface.PermutedIndex(Aperm) +Aview = view(Aperm, 2, 1:2, 1) +sub_index = ArrayInterface.SubIndex(Aview) +Aconj = Aview' +conj_index = ArrayInterface.ConjugateIndex() +multidim = ArrayInterface.MultidimIndex(Aconj) + +composed = stride_index ∘ perm_index ∘ sub_index ∘ conj_index ∘ multidim +x1 = stride_index ∘ perm_index + +for i in eachindex(IndexLinear(), Aconj) + i0 = Aconj[i] + i1 = multidim[i] + i2 = conj_index[i1] + i3 = sub_index[i2] + i4 = perm_index[i3] + i5 = stride_index[i4] + @test Aconj[i1] == i0 + @test Aview[i2] == i0 + @test Aperm[i3] == i0 + @test A[i4] == i0 + @test i5 == i0 + @test composed[i] == i0 +end + +@test @inferred(ArrayInterface.known_offsets(stride_index)) === ArrayInterface.known_offsets(A) +@test @inferred(ArrayInterface.known_offset1(stride_index)) === ArrayInterface.known_offset1(A) +@test @inferred(ArrayInterface.known_strides(stride_index)) === ArrayInterface.known_strides(A) + +A = zeros(3, 4, 5); +A[:] = 1:60; +Aview = view(A, :, 2, :); + +Aview = view(A, 2, :, :); +=# diff --git a/test/runtests.jl b/test/runtests.jl index 4e53a1386..fff1d634a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -668,7 +668,7 @@ end end end -@testset "" begin +@testset "ArrayIndex" begin include("array_index.jl") end From ace2c57148a598fe046c1fc72857658292b9dec3 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 5 Sep 2021 22:13:23 -0400 Subject: [PATCH 03/16] Test ArrayIndex[::ArrayIndex] --- src/array_index.jl | 13 +++++++++++++ test/array_index.jl | 36 ++++++++++++++++++++++++++++++------ 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/src/array_index.jl b/src/array_index.jl index 1f68f9c1d..3448d2e5a 100644 --- a/src/array_index.jl +++ b/src/array_index.jl @@ -15,6 +15,7 @@ const MatrixIndex = ArrayIndex{2} const VectorIndex = ArrayIndex{1} +Base.ndims(::ArrayIndex{N}) where {N} = N Base.ndims(::Type{<:ArrayIndex{N}}) where {N} = N struct BidiagonalIndex <: MatrixIndex @@ -458,6 +459,16 @@ end ) end +@inline function Base.getindex(x::LinearSubIndex, i::LinearSubIndex) + s = getfield(x, :stride) + LinearSubIndex( + getfield(x, :offset) + getfield(i, :offset) * s, + getfield(i, :stride) * s + ) +end +Base.getindex(::OffsetIndex{StaticInt{0}}, i::StrideIndex) = i + + ## ArrayIndex constructorrs @inline _to_cartesian(a) = CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a)))) @inline function _to_linear(a) @@ -472,6 +483,8 @@ ArrayIndex{1}(x::DenseArray) = OffsetIndex(static(0)) ArrayIndex{1}(x::ReshapedArray) = OffsetIndex(static(0)) ArrayIndex{N}(x::ReshapedArray) where {N} = _to_linear(x) +ArrayIndex{1}(x::AbstractRange) = OffsetIndex(static(0)) + ## SubArray ArrayIndex{N}(x::SubArray) where {N} = SubIndex{ndims(x)}(getfield(x, :indices)) function ArrayIndex{1}(x::SubArray{<:Any,N}) where {N} diff --git a/test/array_index.jl b/test/array_index.jl index 6101bc8bc..9a9657625 100644 --- a/test/array_index.jl +++ b/test/array_index.jl @@ -1,27 +1,24 @@ - function test_array_index(x) @testset "$x" begin - linear_idx = ArrayInterface.ArrayIndex{1}(x) + linear_idx = @inferred(ArrayInterface.ArrayIndex{1}(x)) b = ArrayInterface.buffer(x) for i in eachindex(IndexLinear(), x) @test b[linear_idx[i]] == x[i] end - cartesian_idx = ArrayInterface.ArrayIndex{ndims(x)}(x) + cartesian_idx = @inferred(ArrayInterface.ArrayIndex{ndims(x)}(x)) for i in eachindex(IndexCartesian(), x) @test b[cartesian_idx[i]] == x[i] end end end - A = zeros(3, 4, 5); A[:] = 1:60; Aperm = PermutedDimsArray(A,(3,1,2)); Aview = @view(Aperm[:,1:2,1]); Ap = Aview'; - -#ArrayInterface.ArrayIndex{1}(x) +Apperm = PermutedDimsArray(Ap, (2, 1)); test_array_index(A) test_array_index(Aperm) @@ -30,6 +27,24 @@ test_array_index(Ap) test_array_index(view(A, :, :, 1)) # FastContiguousSubArray test_array_index(view(A, 2, :, :)) # FastSubArray +idx = @inferred(ArrayInterface.ArrayIndex{3}(A)[ArrayInterface.ArrayIndex{3}(Aperm)]) +for i in eachindex(IndexCartesian(), Aperm) + @test A[idx[i]] == Aperm[i] +end +idx = @inferred(idx[ArrayInterface.ArrayIndex{2}(Aview)]) +for i in eachindex(IndexCartesian(), Aview) + @test A[idx[i]] == Aview[i] +end +idx = @inferred(idx[ArrayInterface.ArrayIndex{2}(Ap)]) +for i in eachindex(IndexCartesian(), Ap) + @test A[idx[i]] == Ap[i] +end +idx = @inferred(idx[ArrayInterface.ArrayIndex{2}(Apperm)]) +for i in eachindex(IndexCartesian(), Apperm) + @test A[idx[i]] == Apperm[i] +end + + ap_index = ArrayInterface.StrideIndex(Ap) @test @inferred(ArrayInterface.known_offsets(ap_index)) === ArrayInterface.known_offsets(Ap) @test @inferred(ArrayInterface.known_offset1(ap_index)) === ArrayInterface.known_offset1(Ap) @@ -41,6 +56,15 @@ ap_index = ArrayInterface.StrideIndex(Ap) @test @inferred(ArrayInterface.stride_rank(ap_index)) == (1, 3) + +#= +idx1 = ArrayInterface.ArrayIndex{ndims(A)}(A) +idx2 = ArrayInterface.ArrayIndex{ndims(Aperm)}(Aperm) +idx1[idx2] + +idx = ArrayInterface.ArrayIndex{1}(Aperm) +=# + #= using Revise using Pkg From a963ebdd5794866865cb5e927dcbdb5dc5d72355 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 5 Sep 2021 23:19:17 -0400 Subject: [PATCH 04/16] Test more constructors --- Project.toml | 2 +- src/array_index.jl | 7 +- test/array_index.jl | 191 ++------------------------------------------ 3 files changed, 11 insertions(+), 189 deletions(-) diff --git a/Project.toml b/Project.toml index f74d7ff68..0174a42be 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "3.1.31" +version = "3.2" [deps] IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" diff --git a/src/array_index.jl b/src/array_index.jl index 3448d2e5a..f587624cf 100644 --- a/src/array_index.jl +++ b/src/array_index.jl @@ -499,14 +499,11 @@ end ArrayIndex{2}(::MatAdjTrans) = PermutedIndex{2,(2,1),(2,1)}() ArrayIndex{2}(::VecAdjTrans) = PermutedIndex{2,(2,1),(2,)}() ArrayIndex{1}(x::MatAdjTrans) = ComposedIndex(_to_cartesian(x), ArrayIndex{2}(x)) -ArrayIndex{1}(x::VecAdjTrans) = OffsetIndex(static(0)) # jus unwrap permuting struct - +ArrayIndex{1}(x::VecAdjTrans) = OffsetIndex(static(0)) # just unwrap permuting struct +ArrayIndex{1}(::PermutedDimsArray{<:Any,1}) = OffsetIndex(static(0)) function ArrayIndex{N}(::PermutedDimsArray{<:Any,N,perm,iperm}) where {N,perm,iperm} PermutedIndex{N,perm,iperm}() end -function ArrayIndex{1}(::PermutedDimsArray{<:Any,1,perm,iperm}) where {perm,iperm} - OffsetIndex(static(0)) -end function ArrayIndex{1}(x::PermutedDimsArray{<:Any,N,perm,iperm}) where {N,perm,iperm} ComposedIndex(_to_cartesian(x), PermutedIndex{N,perm,iperm}()) end diff --git a/test/array_index.jl b/test/array_index.jl index 9a9657625..8ad2fecc0 100644 --- a/test/array_index.jl +++ b/test/array_index.jl @@ -44,6 +44,14 @@ for i in eachindex(IndexCartesian(), Apperm) @test A[idx[i]] == Apperm[i] end +idx = @inferred(ArrayInterface.ArrayIndex{1}(1:2)) +@inferred idx[@inferred(ArrayInterface.ArrayIndex{1}((1:2)'))] isa ArrayInterface.OffsetIndex{StaticInt{0}} +@test @inferred(ArrayInterface.ArrayIndex{2}((1:2)'))[CartesianIndex(1, 2)] == 2 +@test @inferred(ArrayInterface.ArrayIndex{1}(1:2)) isa ArrayInterface.OffsetIndex{StaticInt{0}} +@test @inferred(ArrayInterface.ArrayIndex{1}((1:2)')) isa ArrayInterface.OffsetIndex{StaticInt{0}} +@test @inferred(ArrayInterface.ArrayIndex{1}(PermutedDimsArray(1:2, (1,)))) isa ArrayInterface.OffsetIndex{StaticInt{0}} +@test @inferred(ArrayInterface.ArrayIndex{1}(reshape(1:10, 2, 5)) isa ArrayInterface.OffsetIndex{StaticInt{0}} +@test @inferred(ArrayInterface.ArrayIndex{2}(reshape(1:10, 2, 5)) isa ArrayInterface.StridIndex ap_index = ArrayInterface.StrideIndex(Ap) @test @inferred(ArrayInterface.known_offsets(ap_index)) === ArrayInterface.known_offsets(Ap) @@ -55,186 +63,3 @@ ap_index = ArrayInterface.StrideIndex(Ap) @test @inferred(ArrayInterface.contiguous_axis(ArrayInterface.StrideIndex{2,(1,2),nothing,NTuple{2,Int},NTuple{2,Int}})) == nothing @test @inferred(ArrayInterface.stride_rank(ap_index)) == (1, 3) - - -#= -idx1 = ArrayInterface.ArrayIndex{ndims(A)}(A) -idx2 = ArrayInterface.ArrayIndex{ndims(Aperm)}(Aperm) -idx1[idx2] - -idx = ArrayInterface.ArrayIndex{1}(Aperm) -=# - -#= -using Revise -using Pkg -Pkg.activate(".") -using ArrayInterface -using ArrayInterface: buffer, array_index, LinearAccess, CartesianAccess - -function test_layouts(x) - index = ArrayInterface.array_index(x, LinearAccess()) - for i in eachindex(IndexLinear(), x) - @test buffer(x)[index[i]] == x[i] - end - index = ArrayInterface.array_index(x, CartesianAccess()) - for i in eachindex(IndexCartesian(), x) - @test buffer(x)[index[i]] == x[i] - end - - lyt = ArrayInterface.layout(x, LinearAccess()) - for i in eachindex(IndexLinear(), x) - @test lyt[i] == x[i] - end - - lyt = ArrayInterface.layout(x, CartesianAccess()) - for i in eachindex(IndexCartesian(), x) - @test lyt[i] == x[i] - end -end - -A = zeros(Int, 3, 4, 5); -A[:] = 1:60; -Aperm = PermutedDimsArray(A, (3,1,2)); -Asub = @view(Aperm[:,1:2,1]); -Ap = Asub'; - -test_layouts(A) -test_layouts(Aperm) -test_layouts(Asub) -test_layouts(Ap) -test_layouts(view(A, :, :, 1)) # FastContiguousSubArray -test_layouts(view(A, 2, :, 1)) # FastSubArray - - - -lyt = ArrayInterface.layout(view(A, 2, :, 1), LinearAccess()) -for i in eachindex(IndexLinear(), x) - @test lyt[i] == x[i] -end - -function base_add(x) - out = zero(eltype(x)) - @inbounds for i in eachindex(IndexCartesian(), x) - out += x[i] - end - return out -end - -function layout_add(x) - out = zero(eltype(x)) - lyt = ArrayInterface.layout(x, ArrayInterface.CartesianAccess()) - @inbounds for i in eachindex(IndexCartesian(), x) - out += lyt[i] - end - return out -end - - -@btime base_add($Ap) - -@btime layout_add($Ap) - - -#= - -lyt = ArrayInterface.layout(A, CartesianAccess()) -lyt = ArrayInterface.layout(Ap, CartesianAccess()) - -lyt = ArrayInterface.layout(Ap, LinearAccess()) -for i in eachindex(IndexCartesian(), Ap) - @test lyt[i] == Ap[i] -end - -@testset "FastContiguousSubArray" begin - test_array_index(view(A, :, :, 1)) -end -@testset "FastSubArray" begin - test_array_index(view(A, 2, :, 1)) -end - -A = zeros(Int, 3, 4, 5); -A[:] = 1:60; -Aperm = PermutedDimsArray(A, (3,1,2)); -Asub = @view(Aperm[:,1:2,1]); -Ap = Asub'; - -test_layout(Asub) - -i1 = ArrayInterface.array_index(Aperm, LinearAccess()) -i2 = ArrayInterface.array_index(parent(Aperm), CartesianAccess()) -i1[i2] - -i = ArrayInterface.array_index(Aperm, LinearAccess()) -lyt[i] - lyt = ArrayInterface.layout(Aperm, CartesianAccess()) - for i in eachindex(IndexCartesian(), Aperm) - @test lyt[i] == Aperm[i] - end - -lyt = ArrayInterface.layout(A, CartesianAccess()) -lyt = ArrayInterface.layout(Aperm, CartesianAccess()) -lyt = ArrayInterface.layout(Asub, CartesianAccess()) -lyt = ArrayInterface.layout(Ap, CartesianAccess()) - -test_array_index(A) -test_array_index(Aperm) -test_array_index(Asub) -test_array_index(Ap) - -test_layout(A) -test_layout(Aperm) -test_layout(Asub) -test_layout(Ap) -=# - -#= -Asub = view(A, 2, :, 1); -index = array_index(Asub, LinearAccess()) -=# - -# SubArray -Aview = view(A, 2, :, 1); -index = ArrayInterface.SubIndex(Aview) -shaped = ArrayInterface.ShapedIndex(A) -for i in eachindex(Aview) - @test shaped[index[i]] == Aview[i] -end - -stride_index = ArrayInterface.StrideIndex(A) -Aperm = PermutedDimsArray(A,(3,1,2)) -perm_index = ArrayInterface.PermutedIndex(Aperm) -Aview = view(Aperm, 2, 1:2, 1) -sub_index = ArrayInterface.SubIndex(Aview) -Aconj = Aview' -conj_index = ArrayInterface.ConjugateIndex() -multidim = ArrayInterface.MultidimIndex(Aconj) - -composed = stride_index ∘ perm_index ∘ sub_index ∘ conj_index ∘ multidim -x1 = stride_index ∘ perm_index - -for i in eachindex(IndexLinear(), Aconj) - i0 = Aconj[i] - i1 = multidim[i] - i2 = conj_index[i1] - i3 = sub_index[i2] - i4 = perm_index[i3] - i5 = stride_index[i4] - @test Aconj[i1] == i0 - @test Aview[i2] == i0 - @test Aperm[i3] == i0 - @test A[i4] == i0 - @test i5 == i0 - @test composed[i] == i0 -end - -@test @inferred(ArrayInterface.known_offsets(stride_index)) === ArrayInterface.known_offsets(A) -@test @inferred(ArrayInterface.known_offset1(stride_index)) === ArrayInterface.known_offset1(A) -@test @inferred(ArrayInterface.known_strides(stride_index)) === ArrayInterface.known_strides(A) - -A = zeros(3, 4, 5); -A[:] = 1:60; -Aview = view(A, :, 2, :); - -Aview = view(A, 2, :, :); -=# From b309e1fd71051911219ba58e530ae59d8f94696a Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 5 Sep 2021 23:24:12 -0400 Subject: [PATCH 05/16] Fix typos --- test/array_index.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/array_index.jl b/test/array_index.jl index 8ad2fecc0..0a9175c6a 100644 --- a/test/array_index.jl +++ b/test/array_index.jl @@ -45,13 +45,13 @@ for i in eachindex(IndexCartesian(), Apperm) end idx = @inferred(ArrayInterface.ArrayIndex{1}(1:2)) -@inferred idx[@inferred(ArrayInterface.ArrayIndex{1}((1:2)'))] isa ArrayInterface.OffsetIndex{StaticInt{0}} +@test idx[@inferred(ArrayInterface.ArrayIndex{1}((1:2)'))] isa ArrayInterface.OffsetIndex{StaticInt{0}} @test @inferred(ArrayInterface.ArrayIndex{2}((1:2)'))[CartesianIndex(1, 2)] == 2 @test @inferred(ArrayInterface.ArrayIndex{1}(1:2)) isa ArrayInterface.OffsetIndex{StaticInt{0}} @test @inferred(ArrayInterface.ArrayIndex{1}((1:2)')) isa ArrayInterface.OffsetIndex{StaticInt{0}} @test @inferred(ArrayInterface.ArrayIndex{1}(PermutedDimsArray(1:2, (1,)))) isa ArrayInterface.OffsetIndex{StaticInt{0}} -@test @inferred(ArrayInterface.ArrayIndex{1}(reshape(1:10, 2, 5)) isa ArrayInterface.OffsetIndex{StaticInt{0}} -@test @inferred(ArrayInterface.ArrayIndex{2}(reshape(1:10, 2, 5)) isa ArrayInterface.StridIndex +@test @inferred(ArrayInterface.ArrayIndex{1}(reshape(1:10, 2, 5))) isa ArrayInterface.OffsetIndex{StaticInt{0}} +@test @inferred(ArrayInterface.ArrayIndex{2}(reshape(1:10, 2, 5))) isa ArrayInterface.StrideIndex ap_index = ArrayInterface.StrideIndex(Ap) @test @inferred(ArrayInterface.known_offsets(ap_index)) === ArrayInterface.known_offsets(Ap) From 3da21dde8587064a154a63ff80809b92ea223e37 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 5 Sep 2021 23:45:36 -0400 Subject: [PATCH 06/16] Fix + test PermutedIndex constructor --- src/array_index.jl | 7 ++++++- test/array_index.jl | 8 +++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/array_index.jl b/src/array_index.jl index f587624cf..2f5fb3b76 100644 --- a/src/array_index.jl +++ b/src/array_index.jl @@ -215,7 +215,12 @@ end Subtypes of `ArrayIndex` that is responsible for permuting each index prior to accessing parent indices. """ -struct PermutedIndex{N,I1,I2} <: ArrayIndex{N} end +struct PermutedIndex{N,I1,I2} <: ArrayIndex{N} + PermutedIndex{N,I1,I2}() where {N,I1,I2} = new{N,I1,I2}() + function PermutedIndex(p::Tuple{Vararg{StaticInt,N}}, ip::Tuple{Vararg{StaticInt}}) where {N} + PermutedIndex{N,known(p),known(ip)}() + end +end function Base.getindex(x::PermutedIndex{2,(2,1),(2,)}, i::AbstractCartesianIndex{2}) getfield(Tuple(i), 2) diff --git a/test/array_index.jl b/test/array_index.jl index 0a9175c6a..c39a3ee1b 100644 --- a/test/array_index.jl +++ b/test/array_index.jl @@ -35,11 +35,9 @@ idx = @inferred(idx[ArrayInterface.ArrayIndex{2}(Aview)]) for i in eachindex(IndexCartesian(), Aview) @test A[idx[i]] == Aview[i] end -idx = @inferred(idx[ArrayInterface.ArrayIndex{2}(Ap)]) -for i in eachindex(IndexCartesian(), Ap) - @test A[idx[i]] == Ap[i] -end -idx = @inferred(idx[ArrayInterface.ArrayIndex{2}(Apperm)]) + +idx_perm = @inferred(ArrayInterface.ArrayIndex{2}(Ap)[ArrayInterface.ArrayIndex{2}(Apperm)]) +idx = @inferred(idx[idx_perm]) for i in eachindex(IndexCartesian(), Apperm) @test A[idx[i]] == Apperm[i] end From adaad827dc5211948dda55543edc348755fe0adb Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 6 Sep 2021 02:32:48 -0400 Subject: [PATCH 07/16] Test vector transposed index --- test/array_index.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/array_index.jl b/test/array_index.jl index c39a3ee1b..4515dfdff 100644 --- a/test/array_index.jl +++ b/test/array_index.jl @@ -42,6 +42,13 @@ for i in eachindex(IndexCartesian(), Apperm) @test A[idx[i]] == Apperm[i] end +v = Vector{Int}(undef, 4); +vp = v' +vnot = @inferred(ArrayInterface.ArrayIndex{1}(v)) +vidx = @inferred(vnot[ArrayInterface.StrideIndex(v)]) +@test @inferred(vidx[ArrayInterface.ArrayIndex{2}(vp)]) isa ArrayInterface.StrideIndex{2,(2,1)} + + idx = @inferred(ArrayInterface.ArrayIndex{1}(1:2)) @test idx[@inferred(ArrayInterface.ArrayIndex{1}((1:2)'))] isa ArrayInterface.OffsetIndex{StaticInt{0}} @test @inferred(ArrayInterface.ArrayIndex{2}((1:2)'))[CartesianIndex(1, 2)] == 2 From d2fef1388e901596891eeeb857d682172fa84ed6 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Tue, 7 Sep 2021 07:47:30 -0400 Subject: [PATCH 08/16] Add docs for `ArrayIndex{N}` constructor --- src/array_index.jl | 16 ++++++++++++++++ test/array_index.jl | 1 + 2 files changed, 17 insertions(+) diff --git a/src/array_index.jl b/src/array_index.jl index 2f5fb3b76..cb1ebf286 100644 --- a/src/array_index.jl +++ b/src/array_index.jl @@ -482,6 +482,22 @@ Base.getindex(::OffsetIndex{StaticInt{0}}, i::StrideIndex) = i end ## DenseArray +""" + ArrayIndex{N}(A) -> index + +Constructs a subtype of `ArrayIndex` such that an `N` dimensional indexing argument may be +converted to an appropriate state for accessing the buffer of `A`. For example: + +```julia +julia> A = reshape(1:20, 4, 5); + +julia> index = ArrayInterface.ArrayIndex{2}(A); + +julia> ArrayInterface.buffer(A)[index[2, 2]] == A[2, 2] +true + +``` +""" ArrayIndex{N}(x::DenseArray) where {N} = StrideIndex(x) ArrayIndex{1}(x::DenseArray) = OffsetIndex(static(0)) diff --git a/test/array_index.jl b/test/array_index.jl index 4515dfdff..04e0ef285 100644 --- a/test/array_index.jl +++ b/test/array_index.jl @@ -59,6 +59,7 @@ idx = @inferred(ArrayInterface.ArrayIndex{1}(1:2)) @test @inferred(ArrayInterface.ArrayIndex{2}(reshape(1:10, 2, 5))) isa ArrayInterface.StrideIndex ap_index = ArrayInterface.StrideIndex(Ap) +@test @inferred(ndims(ap_index)) == ndims(Ap) @test @inferred(ArrayInterface.known_offsets(ap_index)) === ArrayInterface.known_offsets(Ap) @test @inferred(ArrayInterface.known_offset1(ap_index)) === ArrayInterface.known_offset1(Ap) @test @inferred(ArrayInterface.offsets(ap_index, 1)) === ArrayInterface.offset1(Ap) From 58b9d7bd26e13710c357408cde4838e47e53b861 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 20 Sep 2021 11:33:29 -0400 Subject: [PATCH 09/16] Get rid of empty line --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index f6826a503..adf26ee59 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,5 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" - version = "3.2" [deps] From d391ba55666c903dfc646d9e262d8398a1a23e58 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Fri, 1 Oct 2021 17:41:02 -0400 Subject: [PATCH 10/16] Experimental files for show --- Project.toml | 2 +- src/ArrayInterface.jl | 3 +- src/Experimental/Experimental.jl | 81 ++++++++++ src/Experimental/access_styles.jl | 47 ++++++ src/Experimental/layouts.jl | 117 ++++++++++++++ src/array_index.jl | 254 +++++++++++++----------------- src/axes.jl | 2 +- src/indexing.jl | 149 +++++------------- src/size.jl | 6 +- test/array_index.jl | 69 ++------ test/dimensions.jl | 3 + test/runtests.jl | 11 +- 12 files changed, 419 insertions(+), 325 deletions(-) create mode 100644 src/Experimental/Experimental.jl create mode 100644 src/Experimental/access_styles.jl create mode 100644 src/Experimental/layouts.jl diff --git a/Project.toml b/Project.toml index adf26ee59..ddfd3573a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "3.2" +version = "3.2.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index aa263b3a7..ab7b090a2 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -665,6 +665,7 @@ function _is_lazy_conjugate(::Type{T}, isconj) where {T <: Adjoint} end include("ranges.jl") +include("Experimental/Experimental.jl") include("axes.jl") include("size.jl") include("dimensions.jl") @@ -672,8 +673,6 @@ include("indexing.jl") include("stridelayout.jl") include("broadcast.jl") - - function __init__() @require SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin diff --git a/src/Experimental/Experimental.jl b/src/Experimental/Experimental.jl new file mode 100644 index 000000000..7341d8ba8 --- /dev/null +++ b/src/Experimental/Experimental.jl @@ -0,0 +1,81 @@ + +include("access_styles.jl") +include("layouts.jl") + +""" instantiate(lyt::Layouted) """ +@inline function instantiate(x::Layouted{S,P}) where {S,P} + lyt = _instantiate(P, layout(parent(x), S())) + return Layouted{typeof(AccessStyle(lyt))}( + parent(lyt), + combined_index(getfield(lyt, :indices), getfield(x, :indices)), + combined_transform(getfield(lyt, :f), getfield(x, :f)) + ) +end +@inline _instantiate(::Type{P1}, lyt::Layouted{S,P2}) where {P1,S,P2} = instantiate(lyt) +_instantiate(::Type{P}, lyt::Layouted{S,P}) where {P,S} = lyt + +@inline function instantiate(x::Layouted{AccessElement{N},P,<:AbstractCartesianIndex{N}}) where {P,N} + Layouted{AccessElement{N}}(instantiate(layout(parent(x), AccessElement{N}())), getfield(x, :indices), getfield(x, :f)) +end +@inline function instantiate(x::Layouted{AccessElement{1},P,<:CanonicalInt}) where {P} + Layouted{AccessElement{1}}(instantiate(layout(parent(x), AccessElement{1}())), getfield(x, :indices), getfield(x, :f)) +end +@inline function instantiate(x::Layouted{S,P,I}) where {S<:AccessIndices,P,I<:Tuple} + Layouted{S}( + instantiate(layout(parent(x), AccessElement{dynamic(ndims_index(I))}())), + getfield(x, :indices), + getfield(x, :f) + ) +end + +# combine element transforms between arrays +combined_transform(::typeof(identity), y) = y +@inline function combined_transform(x::ComposedFunction, y) + getfield(x, :outer) ∘ combined_transform(getfield(x, :inner), y) +end +@inline combined_transform(x, y) = _combined_transform(x, y) +_combined_transform(x, ::typeof(identity)) = x +@inline function _combined_transform(x, y::ComposedFunction) + combined_index(x, getfield(y, :outer)) ∘ getfield(y, :inner) +end + +function Base.showarg(io::IO, x::StrideIndex{N,R,C}, toplevel) where {N,R,C} + print(io, "StrideIndex{$N,$R,$C}(") + print(io, strides(x)) + print(io, ", ") + print(io, offsets(x)) + print(io, ")") +end +function Base.showarg(io::IO, x::SubIndex, toplevel) + print(io, "SubIndex{$(ndims(x))}(") + print_index(io, getfield(x, :indices)) + print(io, ")") +end +function Base.showarg(io::IO, x::LinearSubIndex, toplevel) + print(io, "LinearSubIndex(offset=$(getfield(x, :offset)),stride=$(getfield(x, :stride)))") +end +function Base.showarg(io::IO, x::CombinedIndex, toplevel) + print(io, "combine(") + print_index(io, x.i1) + print(io, ", ") + print_index(io, x.i2) + print(io, ")") +end + +function Base.showarg(io::IO, x::Layouted{S}, toplevel) where {S} + print(io, "Layouted{$S}(") + print_index(io, parent(x)) + print(io, ", ") + print_index(io, x.indices) + print(io, ", ") + print_index(io, x.f) + print(io, ")") +end + +Base.show(io::IO, ::MIME"text/plain", x::Layouted) = Base.showarg(io, x, true) +Base.show(io::IO, ::MIME"text/plain", x::StrideIndex) = Base.showarg(io, x, true) +Base.show(io::IO, ::MIME"text/plain", x::SubIndex) = Base.showarg(io, x, true) + +print_index(io, x::CartesianIndices) = print(io, "::CartesianIndices{$(ndims(x))}") +print_index(io, x) = Base.showarg(io, x, false) + diff --git a/src/Experimental/access_styles.jl b/src/Experimental/access_styles.jl new file mode 100644 index 000000000..10c0be4d0 --- /dev/null +++ b/src/Experimental/access_styles.jl @@ -0,0 +1,47 @@ + +""" + AccessStyle(I) + +`AccessStyle` specifies how the default index `I` accesses other collections. +""" +abstract type AccessStyle end + +struct AccessElement{N} <: AccessStyle end + +struct AccessUnkown{T} <: AccessStyle end + +struct AccessBoolean <: AccessStyle end + +struct AccessRange <: AccessStyle end + +struct AccessIndices{N} <: AccessStyle end + +# FIXME This should be lispy so we can have .. specialization +# _astyle(::Type{I}, i::StaticInt) where {I} = AccessStyle(_get_tuple(I, i)) +# AccessStyle(::Type{I}) where {I<:Tuple) = AccessIndices(eachop(_astyle, nstatic(Val(N)), I)) + +@generated function static_typed_tail(::Type{T}) where {T<:Tuple} + N = length(T.parameters) + out = Expr(:curly, :Tuple) + for i in 2:N + push!(out.args, T.parameters[i]) + end + return out +end + +AccessStyle(::Type{T}) where {T} = AccessUnkown{T}() +AccessStyle(@nospecialize(x::Type{<:Integer})) = AccessElement{1}() +AccessStyle(::Type{<:Union{OneTo,UnitRange,StepRange,OptionallyStaticRange}}) = AccessRange() +@inline AccessStyle(::Type{<:SubIndex{<:Any,I}}) where {I} = AccessElement{sum(dynamic(ndims_index(I)))}() +AccessStyle(x::Type{<:StrideIndex{N,I1,I2}}) where {N,I1,I2} = AccessElement{length(I2)}() +AccessStyle(x::Type{PermutedIndex{2,(2,1),(2,1)}}) = AccessElement{1}() +AccessStyle(x::Type{<:AbstractCartesianIndex{N}}) where {N} = AccessElement{N}() +# TODO should dig into parents +AccessStyle(x::Type{<:AbstractArray}) = AccessStyle(eltype(x)) + +AccessStyle(::Type{Tuple{I}}) where {I} = (AccessStyle(I),) +@inline function AccessStyle(x::Type{Tuple{I,Vararg{Any}}}) where {I} + (AccessStyle(I), AccessStyle(static_typed_tail(x))...) +end +AccessStyle(@nospecialize(x)) = AccessStyle(typeof(x)) + diff --git a/src/Experimental/layouts.jl b/src/Experimental/layouts.jl new file mode 100644 index 000000000..5d91ac8ff --- /dev/null +++ b/src/Experimental/layouts.jl @@ -0,0 +1,117 @@ + +struct Layouted{S<:AccessStyle,P,I,F} + parent::P + indices::I + f::F + + Layouted{S}(p::P, i::I, f::F) where {S,P,I,F} = new{S,P,I,F}(p, i, f) + Layouted{S}(p, i) where {S} = Layouted{S}(p, i, identity) +end + +AccessStyle(::Type{<:Layouted{S}}) where {S} = S() + +parent_type(::Type{<:Layouted{S,P}}) where {S,P} = P + +Base.parent(x::Layouted) = getfield(x, :parent) + +@inline function Base.getindex(x::Layouted{S,P,I}, i) where {S,P,I} + getfield(x, :f)(@inbounds(parent(x)[getfield(x, :indices)[i]])) +end +@inline function Base.getindex(x::Layouted{S,P,Nothing}, i) where {S,P} + getfield(x, :f)(@inbounds(parent(x)[i])) +end + +@inline function Base.setindex!(x::Layouted{S,P,I}, v, i) where {S,P,I} + @inbounds(Base.setindex!(parent(x), getfield(x, :f)(v), getfield(x, :indices)[i])) +end +@inline function Base.setindex!(x::Layouted{S,P,Nothing}, v, i) where {S,P} + @inbounds(Base.setindex!(parent(x), getfield(x, :f)(v), i)) +end + +""" + layout(x, access::AccessStyle) + +Returns a representation of `x`'s layout given a particular `AccessStyle`. +""" +layout(x, i::CanonicalInt) = Layouted{AccessElement{1}}(x, i) +layout(x, i::AbstractCartesianIndex{N}) where {N} = Layouted{AccessElement{N}}(x, i) +layout(x, i::Tuple{CanonicalInt}) = layout(x, getfield(i, 1)) +layout(x, i::Tuple{CanonicalInt,Vararg{CanonicalInt}}) = layout(x, NDIndex(i)) +layout(x, i::Tuple{Vararg{Any,N}}) where {N} = Layouted{AccessIndices{N}}(x, i) +layout(x, s::AccessStyle) = Layouted{typeof(s)}(x, nothing) + +## Base type ranges +@inline function layout(x::Union{UnitRange,OneTo,StepRange,OptionallyStaticRange}, ::AccessElement{1}) + Layouted{AccessElement{1}}(x, nothing, identity) +end + +## Array +layout(x::Array, ::AccessElement{1}) = Layouted{AccessElement{1}}(x, nothing) +@inline layout(x::Array, ::AccessElement) = Layouted{AccessElement{1}}(x, StrideIndex(x)) + +## ReshapedArray +layout(x::ReshapedArray, ::AccessElement{1}) = Layouted{AccessElement{1}}(parent(x), nothing) +@inline function layout(x::ReshapedArray{T,N}, ::AccessElement{N}) where {T,N} + Layouted{AccessElement{1}}(x, _to_linear(x)) +end + +## Transpose/Adjoint{Real} +@inline function layout(x::Union{Transpose{<:Any,<:AbstractMatrix},Adjoint{<:Real,<:AbstractMatrix}}, ::AccessElement{2}) + Layouted{AccessElement{2}}(parent(x), PermutedIndex{2,(2,1),(2,1)}()) +end +@inline function layout(x::Union{Transpose{<:Any,<:AbstractVector},Adjoint{<:Real,<:AbstractVector}}, ::AccessElement{2}) + Layouted{AccessElement{1}}(parent(x), PermutedIndex{2,(2,1),(2,)}()) +end +@inline function layout(x::Union{Transpose{<:Any,<:AbstractMatrix},Adjoint{<:Real,<:AbstractMatrix}}, ::AccessElement{1}) + Layouted{AccessElement{2}}(parent(x), combined_index(PermutedIndex{2,(2,1),(2,1)}(), _to_cartesian(x))) +end +@inline function layout(x::Union{Transpose{<:Any,<:AbstractVector},Adjoint{<:Real,<:AbstractVector}}, ::AccessElement{1}) + Layouted{AccessElement{1}}(parent(x), nothing) +end + +## Adjoint +@inline function layout(x::Adjoint{<:Any,<:AbstractMatrix}, ::AccessElement{2}) + Layouted{AccessElement{2}}(parent(x), PermutedIndex{2,(2,1),(2,1)}(), adjoint) +end +@inline function layout(x::Adjoint{<:Any,<:AbstractVector}, ::AccessElement{2}) + Layouted{AccessElement{1}}(parent(x), PermutedIndex{2,(2,1),(2,)}(), adjoint) +end +@inline function layout(x::Adjoint{<:Any,<:AbstractMatrix}, ::AccessElement{1}) + Layouted{AccessElement{2}}(parent(x), combined_index(PermutedIndex{2,(2,1),(2,1)}(), _to_cartesian(x)), adjoint) +end +@inline function layout(x::Adjoint{<:Any,<:AbstractVector}, ::AccessElement{1}) + Layouted{AccessElement{1}}(parent(x), nothing, adjoint) +end + +## PermutedDimsArray +@inline function layout(x::PermutedDimsArray{T,N,I1,I2}, ::AccessElement{1}) where {T,N,I1,I2} + if N === 1 + return Layouted{AccessElement{1}}(parent(x), nothing) + else + return Layouted{AccessElement{N}}(parent(x), combined_index(PermutedIndex{N,I1,I2}(), _to_cartesian(x))) + end +end +@inline function layout(x::PermutedDimsArray{T,N,I1,I2}, ::AccessElement{N}) where {T,N,I1,I2} + Layouted{AccessElement{N}}(parent(x), PermutedIndex{N,I1,I2}()) +end + +## SubArray +@inline function layout(x::Base.FastContiguousSubArray, ::AccessElement{1}) + Layouted{AccessElement{1}}(parent(x), OffsetIndex(getfield(x, :offset1))) +end +@inline function layout(x::Base.FastSubArray, ::AccessElement{1}) + Layouted{AccessElement{1}}(parent(x), LinearSubIndex(getfield(x, :offset1), getfield(x, :stride1))) +end +@inline function layout(x::SubArray{T,N}, ::AccessElement{1}) where {T,N} + if N === 1 + i = SubIndex{1}(getfield(x, :indices)) + return Layouted{typeof(AccessStyle(i))}(parent(x), i) + else + i = SubIndex{N}(getfield(x, :indices)) + return Layouted{typeof(AccessStyle(i))}(parent(x), combined_index(i, _to_cartesian(x))) + end +end +@inline function layout(x::SubArray{T,N,P,I}, ::AccessElement{N}) where {T,N,P,I} + Layouted{AccessElement{sum(dynamic(ndims_index(I)))}}(parent(x), SubIndex{N}(getfield(x, :indices))) +end + diff --git a/src/array_index.jl b/src/array_index.jl index cb1ebf286..e607badaa 100644 --- a/src/array_index.jl +++ b/src/array_index.jl @@ -184,10 +184,6 @@ function BandedBlockBandedMatrixIndex( rowindobj, colindobj end -Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1 -Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count -Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count - """ StrideIndex(x) @@ -222,13 +218,6 @@ struct PermutedIndex{N,I1,I2} <: ArrayIndex{N} end end -function Base.getindex(x::PermutedIndex{2,(2,1),(2,)}, i::AbstractCartesianIndex{2}) - getfield(Tuple(i), 2) -end -@inline function Base.getindex(x::PermutedIndex{N,I1,I2}, i::AbstractCartesianIndex{N}) where {N,I1,I2} - return NDIndex(permute(Tuple(i), Val(I2))) -end - """ SubIndex(indices) @@ -240,38 +229,6 @@ struct SubIndex{N,I} <: ArrayIndex{N} SubIndex{N}(inds::Tuple) where {N} = new{N,typeof(inds)}(inds) end -@inline function Base.getindex(x::SubIndex{N}, i::AbstractCartesianIndex{N}) where {N} - return NDIndex(_reindex(x.indices, Tuple(i))) -end -@generated function _reindex(subinds::S, inds::I) where {S,I} - inds_i = 1 - subinds_i = 1 - NS = known_length(S) - NI = known_length(I) - out = Expr(:tuple) - while inds_i <= NI - subinds_type = S.parameters[subinds_i] - if subinds_type <: Integer - push!(out.args, :(getfield(subinds, $subinds_i))) - subinds_i += 1 - elseif eltype(subinds_type) <: AbstractCartesianIndex - push!(out.args, :(Tuple(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))...)) - inds_i += 1 - subinds_i += 1 - else - push!(out.args, :(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))) - inds_i += 1 - subinds_i += 1 - end - end - if subinds_i <= NS - for i in subinds_i:NS - push!(out.args, :(getfield(subinds, $subinds_i))) - end - end - return Expr(:block, Expr(:meta, :inline), :($out)) -end - """ LinearSubIndex(offset, stride) @@ -286,58 +243,35 @@ end const OffsetIndex{O} = LinearSubIndex{O,StaticInt{1}} OffsetIndex(offset::CanonicalInt) = LinearSubIndex(offset, static(1)) -@inline function Base.getindex(x::LinearSubIndex, i::CanonicalInt) - getfield(x, :offset) + getfield(x, :stride) * i -end - -""" - ComposedIndex(i1, i2) - -A subtype of `ArrayIndex` that lazily combines index `i1` and `i2`. Indexing a -`ComposedIndex` whith `i` is equivalent to `i2[i1[i]]`. -""" -struct ComposedIndex{N,I1,I2} <: ArrayIndex{N} +struct CombinedIndex{N,I1,I2} <: ArrayIndex{N} i1::I1 i2::I2 - ComposedIndex(i1::I1, i2::I2) where {I1,I2} = new{ndims(I1),I1,I2}(i1, i2) + CombinedIndex(i1::I1, i2::I2) where {I1,I2} = new{ndims(I1),I1,I2}(i1, i2) end + # we should be able to assume that if `i1` was indexed without error than it's inbounds -@propagate_inbounds function Base.getindex(x::ComposedIndex) - ii = getfield(x, :i1)[] - @inbounds(getfield(x, :i2)[ii]) +@propagate_inbounds function Base.getindex(x::CombinedIndex) + i2 = getfield(x, :i1)[] + @inbounds(getfield(x, :i1)[ii]) end -@propagate_inbounds function Base.getindex(x::ComposedIndex, i::CanonicalInt) - ii = getfield(x, :i1)[i] - @inbounds(getfield(x, :i2)[ii]) +@propagate_inbounds function Base.getindex(x::CombinedIndex, i::CanonicalInt) + ii = getfield(x, :i2)[i] + @inbounds(getfield(x, :i1)[ii]) end -@propagate_inbounds function Base.getindex(x::ComposedIndex, i::AbstractCartesianIndex) - ii = getfield(x, :i1)[i] - @inbounds(getfield(x, :i2)[ii]) +@propagate_inbounds function Base.getindex(x::CombinedIndex, i::AbstractCartesianIndex) + ii = getfield(x, :i2)[i] + @inbounds(getfield(x, :i1)[ii]) end -Base.getindex(x::ArrayIndex, i::ArrayIndex) = ComposedIndex(i, x) -@inline function Base.getindex(x::ComposedIndex, i::ArrayIndex) - ComposedIndex(getfield(x, :i1)[i], getfield(x, :i2)) -end -@inline function Base.getindex(x::ArrayIndex, i::ComposedIndex) - ComposedIndex(getfield(i, :i1), x[getfield(i, :i2)]) -end -@inline function Base.getindex(x::ComposedIndex, i::ComposedIndex) - ComposedIndex(getfield(i, :i1), ComposedIndex(getfield(x, :i1)[getfield(i, :i2)], getfield(x, :i2))) -end +## Traits -@propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)] -@propagate_inbounds function Base.getindex(ind::BidiagonalIndex, i::Int) - @boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i)) - if ind.isup - ii = i + 1 - else - ii = i + 1 + 1 - end - convert(Int, floor(ii / 2)) -end +Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1 +Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count +Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count +## getindex +@propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)] @propagate_inbounds function Base.getindex(ind::TridiagonalIndex, i::Int) @boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i)) offsetu = ind.isrow ? 0 : 1 @@ -406,12 +340,84 @@ end end return Expr(:block, Expr(:meta, :inline), out) end +function Base.getindex(x::PermutedIndex{2,(2,1),(2,)}, i::AbstractCartesianIndex{2}) + getfield(Tuple(i), 2) +end +@inline function Base.getindex(x::PermutedIndex{N,I1,I2}, i::AbstractCartesianIndex{N}) where {N,I1,I2} + return NDIndex(permute(Tuple(i), Val(I2))) +end +@inline function Base.getindex(x::SubIndex{N}, i::AbstractCartesianIndex{N}) where {N} + return NDIndex(_reindex(x.indices, Tuple(i))) +end +@generated function _reindex(subinds::S, inds::I) where {S,I} + inds_i = 1 + subinds_i = 1 + NS = known_length(S) + NI = known_length(I) + out = Expr(:tuple) + while inds_i <= NI + subinds_type = S.parameters[subinds_i] + if subinds_type <: Integer + push!(out.args, :(getfield(subinds, $subinds_i))) + subinds_i += 1 + elseif eltype(subinds_type) <: AbstractCartesianIndex + push!(out.args, :(Tuple(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))...)) + inds_i += 1 + subinds_i += 1 + else + push!(out.args, :(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))) + inds_i += 1 + subinds_i += 1 + end + end + if subinds_i <= NS + for i in subinds_i:NS + push!(out.args, :(getfield(subinds, $subinds_i))) + end + end + return Expr(:block, Expr(:meta, :inline), :($out)) +end +@inline function Base.getindex(x::LinearSubIndex, i::CanonicalInt) + getfield(x, :offset) + getfield(x, :stride) * i +end +@propagate_inbounds function Base.getindex(ind::BidiagonalIndex, i::Int) + @boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i)) + if ind.isup + ii = i + 1 + else + ii = i + 1 + 1 + end + convert(Int, floor(ii / 2)) +end + +""" + combined_index(i1, i2) -@inline function Base.getindex(x::StrideIndex, i::SubIndex{N,I}) where {N,I} - _composed_sub_strides(stride_preserving_index(I), x, i) +Given two subtypes of `ArrayIndex`, combines a new instance that when indexed is equivalent +to `i1[i2[i]]`. Default behavior produces a `CombinedIndex`, but more `i1` and `i2` may be +consolidated into a more efficient representation. +""" +combined_index(::Nothing, y::ArrayIndex) = y +combined_index(x::ArrayIndex, ::Nothing) = x +combined_index(::Nothing, ::Nothing) = nothing +combined_index(x::ArrayIndex, y::ArrayIndex) = CombinedIndex(x, y) +@inline function combined_index(x::CombinedIndex, y::ArrayIndex) + CombinedIndex(getfield(x, :i1), combined_index(getfield(x, :i2), y)) +end +@inline function combined_index(x::ArrayIndex, y::CombinedIndex) + CombinedIndex(combined_index(x, getfield(y, :i1)), getfield(y, :i2)) +end +@inline function combined_index(x::CombinedIndex, y::CombinedIndex) + CombinedIndex( + getfield(x, :i1), + CombinedIndex(combined_index(getfield(x, :i2), getfield(y, :i1)), getfield(y, :i2)) + ) end -_composed_sub_strides(::False, x::StrideIndex, i::SubIndex) = ComposedIndex(i, x) -@inline function _composed_sub_strides(::True, x::StrideIndex{N,R,C}, i::SubIndex{Ns,I}) where {N,R,C,Ns,I<:Tuple{Vararg{Any,N}}} +@inline function combined_index(x::StrideIndex, y::SubIndex{N,I}) where {N,I} + _combined_sub_strides(stride_preserving_index(I), x, y) +end +_combined_sub_strides(::False, x::StrideIndex, i::SubIndex) = CombinedIndex(x, i) +@inline function _combined_sub_strides(::True, x::StrideIndex{N,R,C}, i::SubIndex{Ns,I}) where {N,R,C,Ns,I<:Tuple{Vararg{Any,N}}} c = static(C) if _get_tuple(I, c) <: AbstractUnitRange c2 = known(getfield(_from_sub_dims(I), C)) @@ -429,13 +435,13 @@ _composed_sub_strides(::False, x::StrideIndex, i::SubIndex) = ComposedIndex(i, x eachop(getmul, pdims, map(maybe_static_step, inds), s), permute(o, pdims) ) - return OffsetIndex(reduce_tup(+, map(*, map(_diff, inds, o), s)))[out] + return combined_index(OffsetIndex(reduce_tup(+, map(*, map(_diff, inds, o), s))), out) end @inline _diff(::Base.Slice, ::Any) = Zero() @inline _diff(x::AbstractRange, o) = static_first(x) - o @inline _diff(x::Integer, o) = x - o -@inline function Base.getindex(x::StrideIndex{1,R,C}, ::PermutedIndex{2,(2,1),(2,)}) where {R,C} +@inline function combined_index(x::StrideIndex{1,R,C}, ::PermutedIndex{2,(2,1),(2,)}) where {R,C} if C === nothing c2 = nothing elseif C === 1 @@ -446,7 +452,9 @@ end s = getfield(strides(x), 1) return StrideIndex{2,(2,1),c2}((s, s), (static(1), offset1(x))) end -@inline function Base.getindex(x::StrideIndex{N,R,C}, ::PermutedIndex{N,perm,iperm}) where {N,R,C,perm,iperm} + + +@inline function combined_index(x::StrideIndex{N,R,C}, ::PermutedIndex{N,perm,iperm}) where {N,R,C,perm,iperm} if C === nothing || C === -1 c2 = C else @@ -457,75 +465,25 @@ end permute(offsets(x), Val(perm)), ) end -@inline function Base.getindex(x::PermutedIndex, i::PermutedIndex) - PermutedIndex( - permute(to_parent_dims(x), to_parent_dims(i)), - permute(from_parent_dims(x), from_parent_dims(i)) - ) +@inline function combined_index(::PermutedIndex{<:Any,I11,I12},::PermutedIndex{<:Any,I21,I22}) where {I11,I12,I21,I22} + PermutedIndex(permute(static(I11), static(I21)), permute(static(I12), static(I22))) end - -@inline function Base.getindex(x::LinearSubIndex, i::LinearSubIndex) +@inline function combined_index(x::LinearSubIndex, i::LinearSubIndex) s = getfield(x, :stride) LinearSubIndex( getfield(x, :offset) + getfield(i, :offset) * s, getfield(i, :stride) * s ) end -Base.getindex(::OffsetIndex{StaticInt{0}}, i::StrideIndex) = i +combined_index(::OffsetIndex{StaticInt{0}}, y::StrideIndex) = y +combined_index(x::ArrayIndex, y::CartesianIndices) = CombinedIndex(x, y) +combined_index(x::CartesianIndices, y::ArrayIndex) = CombinedIndex(x, y) -## ArrayIndex constructorrs +## ArrayIndex constructors @inline _to_cartesian(a) = CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a)))) @inline function _to_linear(a) N = ndims(a) StrideIndex{N,ntuple(+, Val(N)),nothing}(size_to_strides(size(a), static(1)), offsets(a)) end -## DenseArray -""" - ArrayIndex{N}(A) -> index - -Constructs a subtype of `ArrayIndex` such that an `N` dimensional indexing argument may be -converted to an appropriate state for accessing the buffer of `A`. For example: - -```julia -julia> A = reshape(1:20, 4, 5); - -julia> index = ArrayInterface.ArrayIndex{2}(A); - -julia> ArrayInterface.buffer(A)[index[2, 2]] == A[2, 2] -true - -``` -""" -ArrayIndex{N}(x::DenseArray) where {N} = StrideIndex(x) -ArrayIndex{1}(x::DenseArray) = OffsetIndex(static(0)) - -ArrayIndex{1}(x::ReshapedArray) = OffsetIndex(static(0)) -ArrayIndex{N}(x::ReshapedArray) where {N} = _to_linear(x) - -ArrayIndex{1}(x::AbstractRange) = OffsetIndex(static(0)) - -## SubArray -ArrayIndex{N}(x::SubArray) where {N} = SubIndex{ndims(x)}(getfield(x, :indices)) -function ArrayIndex{1}(x::SubArray{<:Any,N}) where {N} - ComposedIndex(_to_cartesian(x), SubIndex{N}(getfield(x, :indices))) -end -ArrayIndex{1}(x::Base.FastContiguousSubArray) = OffsetIndex(getfield(x, :offset1)) -function ArrayIndex{1}(x::Base.FastSubArray) - LinearSubIndex(getfield(x, :offset1), getfield(x, :stride1)) -end - -## Permuted arrays -ArrayIndex{2}(::MatAdjTrans) = PermutedIndex{2,(2,1),(2,1)}() -ArrayIndex{2}(::VecAdjTrans) = PermutedIndex{2,(2,1),(2,)}() -ArrayIndex{1}(x::MatAdjTrans) = ComposedIndex(_to_cartesian(x), ArrayIndex{2}(x)) -ArrayIndex{1}(x::VecAdjTrans) = OffsetIndex(static(0)) # just unwrap permuting struct -ArrayIndex{1}(::PermutedDimsArray{<:Any,1}) = OffsetIndex(static(0)) -function ArrayIndex{N}(::PermutedDimsArray{<:Any,N,perm,iperm}) where {N,perm,iperm} - PermutedIndex{N,perm,iperm}() -end -function ArrayIndex{1}(x::PermutedDimsArray{<:Any,N,perm,iperm}) where {N,perm,iperm} - ComposedIndex(_to_cartesian(x), PermutedIndex{N,perm,iperm}()) -end - diff --git a/src/axes.jl b/src/axes.jl index 1d13eb5e6..6715abf5a 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -311,5 +311,5 @@ lazy_axes(x::LinearIndices) = axes(x) lazy_axes(x::CartesianIndices) = axes(x) @inline lazy_axes(x::MatAdjTrans) = reverse(lazy_axes(parent(x))) @inline lazy_axes(x::VecAdjTrans) = (LazyAxis{1}(x), first(lazy_axes(parent(x)))) -@inline lazy_axes(x::PermutedDimsArray) = permute(lazy_axes(parent(x)), to_parent_dims(A)) +@inline lazy_axes(x::PermutedDimsArray) = permute(lazy_axes(parent(x)), to_parent_dims(x)) diff --git a/src/indexing.jl b/src/indexing.jl index 8721e2b34..8bb0f339c 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -39,9 +39,6 @@ function canonical_convert(x::AbstractUnitRange) return OptionallyStaticUnitRange(static_first(x), static_last(x)) end -is_linear_indexing(A, args::Tuple{Arg}) where {Arg} = ndims_index(Arg) < 2 -is_linear_indexing(A, args::Tuple{Arg,Vararg{Any}}) where {Arg} = false - """ to_indices(A, inds::Tuple) -> Tuple @@ -65,9 +62,15 @@ end @boundscheck if !Base.checkbounds_indices(Bool, lazy_axes(A), inds) throw(BoundsError(A, inds)) end - return inds + if ndims(A) < length(inds) + # FIXME bad solution to trailing indices when canonical + return permute(inds, nstatic(Val(ndims(A)))) + else + return inds + end end end + @propagate_inbounds function _to_indices(::False, A, inds) if isone(sum(ndims_index(inds))) return (to_index(LazyAxis{:}(A), getfield(inds, 1)),) @@ -200,13 +203,6 @@ end return LogicalIndex{Int}(arg) end -# TODO delete this once the layout interface is working -_array_index(::IndexLinear, a, i::CanonicalInt) = i -@inline _array_index(::IndexStyle, a, i::CanonicalInt) = @inbounds(_to_cartesian(a)[i]) -_array_index(::IndexLinear, a, i::AbstractCartesianIndex{1}) = getfield(Tuple(i), 1) -@inline _array_index(::IndexLinear, a, i::AbstractCartesianIndex) = _to_linear(a)[i] -_array_index(::IndexStyle, a, i::AbstractCartesianIndex) = i - """ unsafe_reconstruct(A, data; kwargs...) @@ -299,6 +295,7 @@ end end to_axis(S::IndexLinear, axis, inds) = StaticInt(1):static_length(inds) + ################ ### getindex ### ################ @@ -322,59 +319,29 @@ function unsafe_getindex(a::A) where {A} parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A,))) return unsafe_getindex(parent(a)) end -function unsafe_getindex(a::A, i::CanonicalInt) where {A} - idx = _array_index(IndexStyle(A), a, i) - if idx === i - parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A, i))) - return unsafe_getindex(parent(a), i) - else - return unsafe_getindex(a, idx) - end -end -function unsafe_getindex(a::A, i::AbstractCartesianIndex) where {A} - idx = _array_index(IndexStyle(A), a, i) - if idx === i - parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A, i))) - return unsafe_getindex(parent(a), i) - else - return unsafe_getindex(a, idx) - end -end -function unsafe_getindex(a, i::CanonicalInt, ii::Vararg{CanonicalInt}) - unsafe_getindex(a, NDIndex(i, ii...)) -end -unsafe_getindex(a, i::Vararg{Any}) = unsafe_get_collection(a, i) - unsafe_getindex(A::Array) = Base.arrayref(false, A, 1) -unsafe_getindex(A::Array, i::CanonicalInt) = Base.arrayref(false, A, Int(i)) -unsafe_getindex(A::LinearIndices, i::CanonicalInt) = Int(i) - -unsafe_getindex(A::CartesianIndices, i::AbstractCartesianIndex) = CartesianIndex(i) - -unsafe_getindex(A::SubArray, i::CanonicalInt) = @inbounds(A[i]) -unsafe_getindex(A::SubArray, i::AbstractCartesianIndex) = @inbounds(A[i]) +unsafe_getindex(A::LinearIndices, i::CanonicalInt) = @inbounds(A[Int(i)]) +@inline function unsafe_getindex(A::LinearIndices, i::CanonicalInt, ii::Vararg{CanonicalInt}) + Int(@inbounds(_to_linear(A)[NDIndex(i, ii...)])) +end -# This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755. -#= - unsafe_get_collection(A, inds) +unsafe_getindex(A::CartesianIndices, i::AbstractCartesianIndex) = @inbounds(A[CartesianIndex(i)]) +unsafe_getindex(A::CartesianIndices, i::CanonicalInt) = @inbounds(A[CartesianIndex(i)]) +unsafe_getindex(A::CartesianIndices, i::CanonicalInt, ii::Vararg{CanonicalInt}) = CartesianIndex(i, ii...) -Returns a collection of `A` given `inds`. `inds` is assumed to have been bounds-checked. -=# -function unsafe_get_collection(A, inds) - axs = to_axes(A, inds) - dest = similar(A, axs) - if map(Base.unsafe_length, axes(dest)) == map(Base.unsafe_length, axs) - _unsafe_get_index!(dest, A, inds...) # usually a generated function, don't allow it to impact inference result - else - Base.throw_checksize_error(dest, axs) - end - return dest +@inline unsafe_getindex(a, i::Vararg{Any}) = _unsafe_getindex(layout(a, i)) +@inline function _unsafe_getindex(x::Layouted{S}) where {S<:AccessElement} + lyt = instantiate(x) + return getfield(lyt, :f)(@inbounds(parent(lyt)[getfield(lyt, :indices)])) end - -function _generate_unsafe_get_index!_body(N::Int) +@generated function _unsafe_getindex(x::Layouted{S}) where {N,S<:AccessIndices{N}} quote Compat.@inline() + lyt = instantiate(x) + buf = parent(lyt) + I = getfield(lyt, :indices) + dest = similar(parent(x), to_axes(parent(x), I)) D = eachindex(dest) Dy = iterate(D) @inbounds Base.Cartesian.@nloops $N j d -> I[d] begin @@ -382,26 +349,23 @@ function _generate_unsafe_get_index!_body(N::Int) # the optimizer is not clever enough to split the union without it Dy === nothing && return dest (idx, state) = Dy - dest[idx] = unsafe_getindex(src, NDIndex(Base.Cartesian.@ntuple($N, j))) + dest[idx] = buf[NDIndex(Base.Cartesian.@ntuple($N, j))] Dy = iterate(D, state) end return dest end end -@generated function _unsafe_get_index!(dest, src, I::Vararg{Any,N}) where {N} - return _generate_unsafe_get_index!_body(N) -end _ints2range(x::Integer) = x:x _ints2range(x::AbstractRange) = x -@inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N} +@inline function unsafe_getindex(A::CartesianIndices{N}, inds::Vararg{Any}) where {N} if (length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False() return Base._getindex(IndexStyle(A), A, inds...) else return CartesianIndices(to_axes(A, _ints2range.(inds))) end end -@inline function unsafe_get_collection(A::LinearIndices{N}, inds) where {N} +@inline function unsafe_getindex(A::LinearIndices{N}, inds::Vararg{Any}) where {N} if isone(sum(ndims_index(inds))) return @inbounds(eachindex(A)[first(inds)]) elseif stride_preserving_index(typeof(inds)) === True() @@ -435,48 +399,20 @@ function unsafe_setindex!(a::A, v) where {A} parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v))) return unsafe_setindex!(parent(a), v) end -function unsafe_setindex!(a::A, v, i::CanonicalInt) where {A} - idx = _array_index(IndexStyle(A), a, i) - if idx === i - parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v, i))) - return unsafe_setindex!(parent(a), v, i) - else - return unsafe_setindex!(a, v, idx) - end -end -function unsafe_setindex!(a::A, v, i::AbstractCartesianIndex) where {A} - idx = _array_index(IndexStyle(A), a, i) - if idx === i - parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v, i))) - return unsafe_setindex!(parent(a), v, i) - else - return unsafe_setindex!(a, v, idx) - end -end -function unsafe_setindex!(a, v, i::CanonicalInt, ii::Vararg{CanonicalInt}) - unsafe_setindex!(a, v, NDIndex(i, ii...)) +unsafe_setindex!(A::Array{T}, v) where {T} = Base.arrayset(false, A, convert(T, v)::T, 1) +@inline unsafe_setindex!(a, v, i::Vararg{Any}) = _unsafe_setindex!(layout(a, i), v) +@inline function _unsafe_setindex!(x::Layouted{S}, v) where {S<:AccessElement} + lyt = instantiate(x) + @inbounds(Base.setindex!(parent(lyt), getfield(lyt, :f)(v), getfield(lyt, :indices))) end -function unsafe_setindex!(A::Array{T}, v) where {T} - Base.arrayset(false, A, convert(T, v)::T, 1) -end -function unsafe_setindex!(A::Array{T}, v, i::CanonicalInt) where {T} - return Base.arrayset(false, A, convert(T, v)::T, Int(i)) -end - -unsafe_setindex!(a, v, i::Vararg{Any}) = unsafe_set_collection!(a, v, i) - -# This is based on Base._unsafe_setindex!. -#= - unsafe_set_collection!(A, val, inds) - -Sets `inds` of `A` to `val`. `inds` is assumed to have been bounds-checked. -=# -@inline unsafe_set_collection!(A, v, i) = _unsafe_setindex!(A, v, i...) - -function _generate_unsafe_setindex!_body(N::Int) +@generated function _unsafe_setindex!(x::Layouted{S}, v) where {N,S<:AccessIndices{N}} quote - x′ = Base.unalias(A, x) - Base.Cartesian.@nexprs $N d -> (I_d = Base.unalias(A, I[d])) + lyt = instantiate(x) + buf = parent(lyt) + I = getfield(lyt, :indices) + f = getfield(lyt, :f) + x′ = Base.unalias(buf, v) + Base.Cartesian.@nexprs $N d -> (I_d = Base.unalias(buf, I[d])) idxlens = Base.Cartesian.@ncall $N Base.index_lengths I Base.Cartesian.@ncall $N Base.setindex_shape_check x′ (d -> idxlens[d]) Xy = iterate(x′) @@ -485,14 +421,9 @@ function _generate_unsafe_setindex!_body(N::Int) # the optimizer that it does not need to emit error paths Xy === nothing && break (val, state) = Xy - unsafe_setindex!(A, val, NDIndex(Base.Cartesian.@ntuple($N, i))) + buf[NDIndex(Base.Cartesian.@ntuple($N, i))] = f(val) Xy = iterate(x′, state) end - A end end -@generated function _unsafe_setindex!(A, x, I::Vararg{Any,N}) where {N} - return _generate_unsafe_setindex!_body(N) -end - diff --git a/src/size.jl b/src/size.jl index eb338e5ec..37b7c6d8a 100644 --- a/src/size.jl +++ b/src/size.jl @@ -26,10 +26,8 @@ end size(x::SubArray) = eachop(_sub_size, to_parent_dims(x), x.indices) _sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = static_length(getfield(x, dim)) @inline size(B::VecAdjTrans) = (One(), length(parent(B))) -@inline size(B::MatAdjTrans) = permute(size(parent(B)), to_parent_dims(B)) -@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A} - return permute(size(parent(B)), to_parent_dims(B)) -end +@inline size(B::MatAdjTrans) = permute(size(parent(B)), (static(2), static(1))) +@inline size(B::PermutedDimsArray{T,N,I}) where {T,N,I} = permute(size(parent(B)), static(I)) function size(a::ReinterpretArray{T,N,S,A}) where {T,N,S,A} psize = size(parent(a)) if _is_reshaped(typeof(a)) diff --git a/test/array_index.jl b/test/array_index.jl index 04e0ef285..03792ce6c 100644 --- a/test/array_index.jl +++ b/test/array_index.jl @@ -1,71 +1,28 @@ -function test_array_index(x) +function test_layout(x) @testset "$x" begin - linear_idx = @inferred(ArrayInterface.ArrayIndex{1}(x)) - b = ArrayInterface.buffer(x) + linear_lyt = ArrayInterface.instantiate(ArrayInterface.layout(x, ArrayInterface.AccessElement{1}())) for i in eachindex(IndexLinear(), x) - @test b[linear_idx[i]] == x[i] + @test linear_lyt[i] == x[i] end - cartesian_idx = @inferred(ArrayInterface.ArrayIndex{ndims(x)}(x)) + cartesian_lyt = ArrayInterface.instantiate(ArrayInterface.layout(x, ArrayInterface.AccessElement{ndims(x)}())) for i in eachindex(IndexCartesian(), x) - @test b[cartesian_idx[i]] == x[i] + @test cartesian_lyt[i] == x[i] end end + return nothing end -A = zeros(3, 4, 5); -A[:] = 1:60; +A = rand(4,4,4); +A[:] .= eachindex(A); Aperm = PermutedDimsArray(A,(3,1,2)); Aview = @view(Aperm[:,1:2,1]); Ap = Aview'; Apperm = PermutedDimsArray(Ap, (2, 1)); -test_array_index(A) -test_array_index(Aperm) -test_array_index(Aview) -test_array_index(Ap) -test_array_index(view(A, :, :, 1)) # FastContiguousSubArray -test_array_index(view(A, 2, :, :)) # FastSubArray - -idx = @inferred(ArrayInterface.ArrayIndex{3}(A)[ArrayInterface.ArrayIndex{3}(Aperm)]) -for i in eachindex(IndexCartesian(), Aperm) - @test A[idx[i]] == Aperm[i] -end -idx = @inferred(idx[ArrayInterface.ArrayIndex{2}(Aview)]) -for i in eachindex(IndexCartesian(), Aview) - @test A[idx[i]] == Aview[i] -end - -idx_perm = @inferred(ArrayInterface.ArrayIndex{2}(Ap)[ArrayInterface.ArrayIndex{2}(Apperm)]) -idx = @inferred(idx[idx_perm]) -for i in eachindex(IndexCartesian(), Apperm) - @test A[idx[i]] == Apperm[i] -end - -v = Vector{Int}(undef, 4); -vp = v' -vnot = @inferred(ArrayInterface.ArrayIndex{1}(v)) -vidx = @inferred(vnot[ArrayInterface.StrideIndex(v)]) -@test @inferred(vidx[ArrayInterface.ArrayIndex{2}(vp)]) isa ArrayInterface.StrideIndex{2,(2,1)} - - -idx = @inferred(ArrayInterface.ArrayIndex{1}(1:2)) -@test idx[@inferred(ArrayInterface.ArrayIndex{1}((1:2)'))] isa ArrayInterface.OffsetIndex{StaticInt{0}} -@test @inferred(ArrayInterface.ArrayIndex{2}((1:2)'))[CartesianIndex(1, 2)] == 2 -@test @inferred(ArrayInterface.ArrayIndex{1}(1:2)) isa ArrayInterface.OffsetIndex{StaticInt{0}} -@test @inferred(ArrayInterface.ArrayIndex{1}((1:2)')) isa ArrayInterface.OffsetIndex{StaticInt{0}} -@test @inferred(ArrayInterface.ArrayIndex{1}(PermutedDimsArray(1:2, (1,)))) isa ArrayInterface.OffsetIndex{StaticInt{0}} -@test @inferred(ArrayInterface.ArrayIndex{1}(reshape(1:10, 2, 5))) isa ArrayInterface.OffsetIndex{StaticInt{0}} -@test @inferred(ArrayInterface.ArrayIndex{2}(reshape(1:10, 2, 5))) isa ArrayInterface.StrideIndex - -ap_index = ArrayInterface.StrideIndex(Ap) -@test @inferred(ndims(ap_index)) == ndims(Ap) -@test @inferred(ArrayInterface.known_offsets(ap_index)) === ArrayInterface.known_offsets(Ap) -@test @inferred(ArrayInterface.known_offset1(ap_index)) === ArrayInterface.known_offset1(Ap) -@test @inferred(ArrayInterface.offsets(ap_index, 1)) === ArrayInterface.offset1(Ap) -@test @inferred(ArrayInterface.offsets(ap_index, static(1))) === ArrayInterface.offset1(Ap) -@test @inferred(ArrayInterface.known_strides(ap_index)) === ArrayInterface.known_strides(Ap) -@test @inferred(ArrayInterface.contiguous_axis(ap_index)) == 1 -@test @inferred(ArrayInterface.contiguous_axis(ArrayInterface.StrideIndex{2,(1,2),nothing,NTuple{2,Int},NTuple{2,Int}})) == nothing -@test @inferred(ArrayInterface.stride_rank(ap_index)) == (1, 3) +test_layout(A) +test_layout(Aperm) +test_layout(Aview) +test_layout(Ap) +test_layout(Apperm) diff --git a/test/dimensions.jl b/test/dimensions.jl index c936b3e26..838774848 100644 --- a/test/dimensions.jl +++ b/test/dimensions.jl @@ -12,6 +12,9 @@ end ArrayInterface.parent_type(::Type{T}) where {P,T<:NamedDimsWrapper{<:Any,<:Any,<:Any,P}} = P ArrayInterface.dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L}} = static(L) Base.parent(x::NamedDimsWrapper) = x.parent +function ArrayInterface.layout(x::NamedDimsWrapper, s::ArrayInterface.AccessElement) + ArrayInterface.Layouted{typeof(s)}(parent(x), nothing) +end @testset "dimension permutations" begin a = ones(2, 2, 2) diff --git a/test/runtests.jl b/test/runtests.jl index fff1d634a..3e3536db8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -254,6 +254,9 @@ end ArrayInterface.parent_type(::Type{<:Wrapper{T,N,P}}) where {T,N,P} = P Base.parent(x::Wrapper) = x.parent ArrayInterface.device(::Type{T}) where {T<:Wrapper} = ArrayInterface.device(parent_type(T)) +function ArrayInterface.layout(x::Wrapper, s::ArrayInterface.AccessElement) + ArrayInterface.Layouted{typeof(s)}(parent(x), nothing) +end struct DenseWrapper{T,N,P<:AbstractArray{T,N}} <: DenseArray{T,N} end ArrayInterface.parent_type(::Type{DenseWrapper{T,N,P}}) where {T,N,P} = P @@ -668,10 +671,6 @@ end end end -@testset "ArrayIndex" begin - include("array_index.jl") -end - @testset "Reshaped views" begin u_base = randn(10, 10) u_view = view(u_base, 3, :) @@ -834,6 +833,10 @@ end end include("indexing.jl") include("dimensions.jl") +@testset "ArrayIndex" begin + include("array_index.jl") +end + @testset "broadcast" begin include("broadcast.jl") From 55ca0e53f17359056194adff79991f0a02cf3b65 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 4 Oct 2021 10:42:28 -0400 Subject: [PATCH 11/16] Get rid of experimental directory and add relayout funciton. --- src/ArrayInterface.jl | 17 ++- src/Experimental/Experimental.jl | 81 ----------- src/Experimental/access_styles.jl | 47 ------- src/Experimental/layouts.jl | 117 --------------- src/array_index.jl | 216 ++++++++++++++++++++-------- src/dimensions.jl | 45 ++++-- src/indexing.jl | 227 +++++++++++++++++++++++------- src/size.jl | 6 +- src/stridelayout.jl | 1 - test/array_index.jl | 8 +- test/dimensions.jl | 4 +- test/indexing.jl | 12 +- test/runtests.jl | 4 +- 13 files changed, 395 insertions(+), 390 deletions(-) delete mode 100644 src/Experimental/Experimental.jl delete mode 100644 src/Experimental/access_styles.jl delete mode 100644 src/Experimental/layouts.jl diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index ab7b090a2..64ebf8ea6 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -665,7 +665,6 @@ function _is_lazy_conjugate(::Type{T}, isconj) where {T <: Adjoint} end include("ranges.jl") -include("Experimental/Experimental.jl") include("axes.jl") include("size.jl") include("dimensions.jl") @@ -757,9 +756,21 @@ function __init__() @inline strides(B::StaticArrays.SizedArray{S,T,M,N,A}) where {S,T,M,N,A<:SubArray} = strides(B.data) parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N,A}}) where {S,T,M,N,A} = A else - parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N}}) where {S,T,M,N} = - Array{T,N} + parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N}}) where {S,T,M,N} = Array{T,N} end + + function static_size(x, inds) + StaticArrays.Size(ArrayInterface._mapsub(ArrayInterface.known_length, inds)) + end + + function ArrayInterface.relayout_constructor(::Type{<:StaticArrays.SArray}) + static_size + end + + function ArrayInterface.compose(x::NTuple{N}, y::StaticArrays.Size{S}) where {N,S} + StaticArrays.SArray{Tuple{S...},eltype(x),length(S),N}(x) + end + @require Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin function Adapt.adapt_storage( ::Type{<:StaticArrays.SArray{S}}, diff --git a/src/Experimental/Experimental.jl b/src/Experimental/Experimental.jl deleted file mode 100644 index 7341d8ba8..000000000 --- a/src/Experimental/Experimental.jl +++ /dev/null @@ -1,81 +0,0 @@ - -include("access_styles.jl") -include("layouts.jl") - -""" instantiate(lyt::Layouted) """ -@inline function instantiate(x::Layouted{S,P}) where {S,P} - lyt = _instantiate(P, layout(parent(x), S())) - return Layouted{typeof(AccessStyle(lyt))}( - parent(lyt), - combined_index(getfield(lyt, :indices), getfield(x, :indices)), - combined_transform(getfield(lyt, :f), getfield(x, :f)) - ) -end -@inline _instantiate(::Type{P1}, lyt::Layouted{S,P2}) where {P1,S,P2} = instantiate(lyt) -_instantiate(::Type{P}, lyt::Layouted{S,P}) where {P,S} = lyt - -@inline function instantiate(x::Layouted{AccessElement{N},P,<:AbstractCartesianIndex{N}}) where {P,N} - Layouted{AccessElement{N}}(instantiate(layout(parent(x), AccessElement{N}())), getfield(x, :indices), getfield(x, :f)) -end -@inline function instantiate(x::Layouted{AccessElement{1},P,<:CanonicalInt}) where {P} - Layouted{AccessElement{1}}(instantiate(layout(parent(x), AccessElement{1}())), getfield(x, :indices), getfield(x, :f)) -end -@inline function instantiate(x::Layouted{S,P,I}) where {S<:AccessIndices,P,I<:Tuple} - Layouted{S}( - instantiate(layout(parent(x), AccessElement{dynamic(ndims_index(I))}())), - getfield(x, :indices), - getfield(x, :f) - ) -end - -# combine element transforms between arrays -combined_transform(::typeof(identity), y) = y -@inline function combined_transform(x::ComposedFunction, y) - getfield(x, :outer) ∘ combined_transform(getfield(x, :inner), y) -end -@inline combined_transform(x, y) = _combined_transform(x, y) -_combined_transform(x, ::typeof(identity)) = x -@inline function _combined_transform(x, y::ComposedFunction) - combined_index(x, getfield(y, :outer)) ∘ getfield(y, :inner) -end - -function Base.showarg(io::IO, x::StrideIndex{N,R,C}, toplevel) where {N,R,C} - print(io, "StrideIndex{$N,$R,$C}(") - print(io, strides(x)) - print(io, ", ") - print(io, offsets(x)) - print(io, ")") -end -function Base.showarg(io::IO, x::SubIndex, toplevel) - print(io, "SubIndex{$(ndims(x))}(") - print_index(io, getfield(x, :indices)) - print(io, ")") -end -function Base.showarg(io::IO, x::LinearSubIndex, toplevel) - print(io, "LinearSubIndex(offset=$(getfield(x, :offset)),stride=$(getfield(x, :stride)))") -end -function Base.showarg(io::IO, x::CombinedIndex, toplevel) - print(io, "combine(") - print_index(io, x.i1) - print(io, ", ") - print_index(io, x.i2) - print(io, ")") -end - -function Base.showarg(io::IO, x::Layouted{S}, toplevel) where {S} - print(io, "Layouted{$S}(") - print_index(io, parent(x)) - print(io, ", ") - print_index(io, x.indices) - print(io, ", ") - print_index(io, x.f) - print(io, ")") -end - -Base.show(io::IO, ::MIME"text/plain", x::Layouted) = Base.showarg(io, x, true) -Base.show(io::IO, ::MIME"text/plain", x::StrideIndex) = Base.showarg(io, x, true) -Base.show(io::IO, ::MIME"text/plain", x::SubIndex) = Base.showarg(io, x, true) - -print_index(io, x::CartesianIndices) = print(io, "::CartesianIndices{$(ndims(x))}") -print_index(io, x) = Base.showarg(io, x, false) - diff --git a/src/Experimental/access_styles.jl b/src/Experimental/access_styles.jl deleted file mode 100644 index 10c0be4d0..000000000 --- a/src/Experimental/access_styles.jl +++ /dev/null @@ -1,47 +0,0 @@ - -""" - AccessStyle(I) - -`AccessStyle` specifies how the default index `I` accesses other collections. -""" -abstract type AccessStyle end - -struct AccessElement{N} <: AccessStyle end - -struct AccessUnkown{T} <: AccessStyle end - -struct AccessBoolean <: AccessStyle end - -struct AccessRange <: AccessStyle end - -struct AccessIndices{N} <: AccessStyle end - -# FIXME This should be lispy so we can have .. specialization -# _astyle(::Type{I}, i::StaticInt) where {I} = AccessStyle(_get_tuple(I, i)) -# AccessStyle(::Type{I}) where {I<:Tuple) = AccessIndices(eachop(_astyle, nstatic(Val(N)), I)) - -@generated function static_typed_tail(::Type{T}) where {T<:Tuple} - N = length(T.parameters) - out = Expr(:curly, :Tuple) - for i in 2:N - push!(out.args, T.parameters[i]) - end - return out -end - -AccessStyle(::Type{T}) where {T} = AccessUnkown{T}() -AccessStyle(@nospecialize(x::Type{<:Integer})) = AccessElement{1}() -AccessStyle(::Type{<:Union{OneTo,UnitRange,StepRange,OptionallyStaticRange}}) = AccessRange() -@inline AccessStyle(::Type{<:SubIndex{<:Any,I}}) where {I} = AccessElement{sum(dynamic(ndims_index(I)))}() -AccessStyle(x::Type{<:StrideIndex{N,I1,I2}}) where {N,I1,I2} = AccessElement{length(I2)}() -AccessStyle(x::Type{PermutedIndex{2,(2,1),(2,1)}}) = AccessElement{1}() -AccessStyle(x::Type{<:AbstractCartesianIndex{N}}) where {N} = AccessElement{N}() -# TODO should dig into parents -AccessStyle(x::Type{<:AbstractArray}) = AccessStyle(eltype(x)) - -AccessStyle(::Type{Tuple{I}}) where {I} = (AccessStyle(I),) -@inline function AccessStyle(x::Type{Tuple{I,Vararg{Any}}}) where {I} - (AccessStyle(I), AccessStyle(static_typed_tail(x))...) -end -AccessStyle(@nospecialize(x)) = AccessStyle(typeof(x)) - diff --git a/src/Experimental/layouts.jl b/src/Experimental/layouts.jl deleted file mode 100644 index 5d91ac8ff..000000000 --- a/src/Experimental/layouts.jl +++ /dev/null @@ -1,117 +0,0 @@ - -struct Layouted{S<:AccessStyle,P,I,F} - parent::P - indices::I - f::F - - Layouted{S}(p::P, i::I, f::F) where {S,P,I,F} = new{S,P,I,F}(p, i, f) - Layouted{S}(p, i) where {S} = Layouted{S}(p, i, identity) -end - -AccessStyle(::Type{<:Layouted{S}}) where {S} = S() - -parent_type(::Type{<:Layouted{S,P}}) where {S,P} = P - -Base.parent(x::Layouted) = getfield(x, :parent) - -@inline function Base.getindex(x::Layouted{S,P,I}, i) where {S,P,I} - getfield(x, :f)(@inbounds(parent(x)[getfield(x, :indices)[i]])) -end -@inline function Base.getindex(x::Layouted{S,P,Nothing}, i) where {S,P} - getfield(x, :f)(@inbounds(parent(x)[i])) -end - -@inline function Base.setindex!(x::Layouted{S,P,I}, v, i) where {S,P,I} - @inbounds(Base.setindex!(parent(x), getfield(x, :f)(v), getfield(x, :indices)[i])) -end -@inline function Base.setindex!(x::Layouted{S,P,Nothing}, v, i) where {S,P} - @inbounds(Base.setindex!(parent(x), getfield(x, :f)(v), i)) -end - -""" - layout(x, access::AccessStyle) - -Returns a representation of `x`'s layout given a particular `AccessStyle`. -""" -layout(x, i::CanonicalInt) = Layouted{AccessElement{1}}(x, i) -layout(x, i::AbstractCartesianIndex{N}) where {N} = Layouted{AccessElement{N}}(x, i) -layout(x, i::Tuple{CanonicalInt}) = layout(x, getfield(i, 1)) -layout(x, i::Tuple{CanonicalInt,Vararg{CanonicalInt}}) = layout(x, NDIndex(i)) -layout(x, i::Tuple{Vararg{Any,N}}) where {N} = Layouted{AccessIndices{N}}(x, i) -layout(x, s::AccessStyle) = Layouted{typeof(s)}(x, nothing) - -## Base type ranges -@inline function layout(x::Union{UnitRange,OneTo,StepRange,OptionallyStaticRange}, ::AccessElement{1}) - Layouted{AccessElement{1}}(x, nothing, identity) -end - -## Array -layout(x::Array, ::AccessElement{1}) = Layouted{AccessElement{1}}(x, nothing) -@inline layout(x::Array, ::AccessElement) = Layouted{AccessElement{1}}(x, StrideIndex(x)) - -## ReshapedArray -layout(x::ReshapedArray, ::AccessElement{1}) = Layouted{AccessElement{1}}(parent(x), nothing) -@inline function layout(x::ReshapedArray{T,N}, ::AccessElement{N}) where {T,N} - Layouted{AccessElement{1}}(x, _to_linear(x)) -end - -## Transpose/Adjoint{Real} -@inline function layout(x::Union{Transpose{<:Any,<:AbstractMatrix},Adjoint{<:Real,<:AbstractMatrix}}, ::AccessElement{2}) - Layouted{AccessElement{2}}(parent(x), PermutedIndex{2,(2,1),(2,1)}()) -end -@inline function layout(x::Union{Transpose{<:Any,<:AbstractVector},Adjoint{<:Real,<:AbstractVector}}, ::AccessElement{2}) - Layouted{AccessElement{1}}(parent(x), PermutedIndex{2,(2,1),(2,)}()) -end -@inline function layout(x::Union{Transpose{<:Any,<:AbstractMatrix},Adjoint{<:Real,<:AbstractMatrix}}, ::AccessElement{1}) - Layouted{AccessElement{2}}(parent(x), combined_index(PermutedIndex{2,(2,1),(2,1)}(), _to_cartesian(x))) -end -@inline function layout(x::Union{Transpose{<:Any,<:AbstractVector},Adjoint{<:Real,<:AbstractVector}}, ::AccessElement{1}) - Layouted{AccessElement{1}}(parent(x), nothing) -end - -## Adjoint -@inline function layout(x::Adjoint{<:Any,<:AbstractMatrix}, ::AccessElement{2}) - Layouted{AccessElement{2}}(parent(x), PermutedIndex{2,(2,1),(2,1)}(), adjoint) -end -@inline function layout(x::Adjoint{<:Any,<:AbstractVector}, ::AccessElement{2}) - Layouted{AccessElement{1}}(parent(x), PermutedIndex{2,(2,1),(2,)}(), adjoint) -end -@inline function layout(x::Adjoint{<:Any,<:AbstractMatrix}, ::AccessElement{1}) - Layouted{AccessElement{2}}(parent(x), combined_index(PermutedIndex{2,(2,1),(2,1)}(), _to_cartesian(x)), adjoint) -end -@inline function layout(x::Adjoint{<:Any,<:AbstractVector}, ::AccessElement{1}) - Layouted{AccessElement{1}}(parent(x), nothing, adjoint) -end - -## PermutedDimsArray -@inline function layout(x::PermutedDimsArray{T,N,I1,I2}, ::AccessElement{1}) where {T,N,I1,I2} - if N === 1 - return Layouted{AccessElement{1}}(parent(x), nothing) - else - return Layouted{AccessElement{N}}(parent(x), combined_index(PermutedIndex{N,I1,I2}(), _to_cartesian(x))) - end -end -@inline function layout(x::PermutedDimsArray{T,N,I1,I2}, ::AccessElement{N}) where {T,N,I1,I2} - Layouted{AccessElement{N}}(parent(x), PermutedIndex{N,I1,I2}()) -end - -## SubArray -@inline function layout(x::Base.FastContiguousSubArray, ::AccessElement{1}) - Layouted{AccessElement{1}}(parent(x), OffsetIndex(getfield(x, :offset1))) -end -@inline function layout(x::Base.FastSubArray, ::AccessElement{1}) - Layouted{AccessElement{1}}(parent(x), LinearSubIndex(getfield(x, :offset1), getfield(x, :stride1))) -end -@inline function layout(x::SubArray{T,N}, ::AccessElement{1}) where {T,N} - if N === 1 - i = SubIndex{1}(getfield(x, :indices)) - return Layouted{typeof(AccessStyle(i))}(parent(x), i) - else - i = SubIndex{N}(getfield(x, :indices)) - return Layouted{typeof(AccessStyle(i))}(parent(x), combined_index(i, _to_cartesian(x))) - end -end -@inline function layout(x::SubArray{T,N,P,I}, ::AccessElement{N}) where {T,N,P,I} - Layouted{AccessElement{sum(dynamic(ndims_index(I)))}}(parent(x), SubIndex{N}(getfield(x, :indices))) -end - diff --git a/src/array_index.jl b/src/array_index.jl index e607badaa..36ea3de8e 100644 --- a/src/array_index.jl +++ b/src/array_index.jl @@ -226,7 +226,8 @@ Subtype of `ArrayIndex` that provides a multidimensional view of another `ArrayI struct SubIndex{N,I} <: ArrayIndex{N} indices::I - SubIndex{N}(inds::Tuple) where {N} = new{N,typeof(inds)}(inds) + SubIndex{N}(inds::Tuple) where {N} = new{N,typeof(inds)}(Base.ensure_indexable(inds)) + SubIndex(x::SubArray{T,N,P,I}) where {T,N,P,I} = new{N,I}(getfield(x, :indices)) end """ @@ -240,38 +241,121 @@ struct LinearSubIndex{O<:CanonicalInt,S<:CanonicalInt} <: VectorIndex stride::S end +offset1(x::LinearSubIndex) = getfield(x, :offset) +stride1(x::LinearSubIndex) = getfield(x, :stride) + const OffsetIndex{O} = LinearSubIndex{O,StaticInt{1}} OffsetIndex(offset::CanonicalInt) = LinearSubIndex(offset, static(1)) -struct CombinedIndex{N,I1,I2} <: ArrayIndex{N} - i1::I1 - i2::I2 +""" + IdentityIndex{N} + +Used to specify that indices don't need any transformation. +""" +struct IdentityIndex{N} <: ArrayIndex{N} end + +""" + UnkownIndex{N} + +This default return type when calling `ArrayIndex{N}(x)`. +""" +struct UnkownIndex{N} <: ArrayIndex{N} end + - CombinedIndex(i1::I1, i2::I2) where {I1,I2} = new{ndims(I1),I1,I2}(i1, i2) +struct ComposedIndex{N,O,I} <: ArrayIndex{N} + outer::O + inner::I + + ComposedIndex(i1::I1, i2::I2) where {I1,I2} = new{ndims(I1),I1,I2}(i1, i2) end -# we should be able to assume that if `i1` was indexed without error than it's inbounds -@propagate_inbounds function Base.getindex(x::CombinedIndex) - i2 = getfield(x, :i1)[] - @inbounds(getfield(x, :i1)[ii]) +outer(x::ComposedIndex) = getfield(x, :outer) +inner(x::ComposedIndex) = getfield(x, :inner) + +@inline _to_cartesian(x) = CartesianIndices(indices(x, ntuple(+, Val(ndims(x))))) +@inline function _to_linear(x) + N = ndims(x) + StrideIndex{N,ntuple(+, Val(N)),nothing}(size_to_strides(size(x), static(1)), offsets(x)) +end + +""" + ArrayIndex{N}(A) + +Constructs a subtype of `ArrayIndex` such that an `N` dimensional indexing argument may be +converted to an appropriate state for accessing the buffer of `A`. For example: + +```julia +julia> A = reshape(1:20, 4, 5); + +julia> index = ArrayInterface.ArrayIndex{2}(A); + +julia> ArrayInterface.buffer(A)[index[2, 2]] == A[2, 2] +true + +``` +""" +ArrayIndex{N}(x) where {N} = UnkownIndex{N}() +ArrayIndex{N}(x::Array) where {N} = StrideIndex(x) +ArrayIndex{1}(x::Array) = OffsetIndex(static(0)) + +ArrayIndex{1}(x::ReshapedArray) = IdentityIndex{1}() +ArrayIndex{N}(x::ReshapedArray) where {N} = _to_linear(x) + +# TODO should we only define index constructors for explicit types? +ArrayIndex{1}(x::AbstractRange) = OffsetIndex(offset1(x) - static(1)) + +## SubArray +ArrayIndex{N}(x::SubArray) where {N} = SubIndex{ndims(x)}(getfield(x, :indices)) +@inline function ArrayIndex{1}(x::SubArray{T,N}) where {T,N} + if N === 1 + return SubIndex(x) + else + return compose(SubIndex(x), _to_cartesian(x)) + end end -@propagate_inbounds function Base.getindex(x::CombinedIndex, i::CanonicalInt) - ii = getfield(x, :i2)[i] - @inbounds(getfield(x, :i1)[ii]) +ArrayIndex{1}(x::Base.FastContiguousSubArray) = OffsetIndex(getfield(x, :offset1)) +function ArrayIndex{1}(x::Base.FastSubArray) + LinearSubIndex(getfield(x, :offset1), getfield(x, :stride1)) end -@propagate_inbounds function Base.getindex(x::CombinedIndex, i::AbstractCartesianIndex) - ii = getfield(x, :i2)[i] - @inbounds(getfield(x, :i1)[ii]) + +## PermutedDimsArray +@inline function ArrayIndex{1}(x::PermutedDimsArray{T,N,I1,I2}) where {T,N,I1,I2} + if N === 1 + return IdentityIndex{1}() + else + return compose(PermutedIndex{N,I1,I2}(), _to_cartesian(x)) + end end +@inline ArrayIndex{N}(x::PermutedDimsArray{T,N,I1,I2}) where {T,N,I1,I2} = PermutedIndex{N,I1,I2}() -## Traits +## Transpose/Adjoint{Real} +@inline function ArrayIndex{2}(x::Union{Transpose{<:Any,<:AbstractMatrix},Adjoint{<:Real,<:AbstractMatrix}}) + PermutedIndex{2,(2,1),(2,1)}() +end +@inline function ArrayIndex{2}(x::Union{Transpose{<:Any,<:AbstractVector},Adjoint{<:Real,<:AbstractVector}}) + PermutedIndex{2,(2,1),(2,)}() +end +@inline function ArrayIndex{1}(x::Union{Transpose{<:Any,<:AbstractMatrix},Adjoint{<:Real,<:AbstractMatrix}}) + compose(PermutedIndex{2,(2,1),(2,1)}(), _to_cartesian(x)) +end +@inline function ArrayIndex{1}(x::Union{Transpose{<:Any,<:AbstractVector},Adjoint{<:Real,<:AbstractVector}}) + IdentityIndex{1}() +end +## Traits Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1 Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count ## getindex @propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)] +Base.getindex(x::IdentityIndex, i::CanonicalInt) = 1 +Base.getindex(x::IdentityIndex, i::AbstractCartesianIndex) = i +# we should be able to assume that if `i1` was indexed without error than it's inbounds +@propagate_inbounds Base.getindex(x::ComposedIndex) = @inbounds(outer(x)[inner(x)[]]) +@propagate_inbounds Base.getindex(x::ComposedIndex, i::CanonicalInt) = @inbounds(outer(x)[inner(x)[i]]) +@propagate_inbounds Base.getindex(x::ComposedIndex, i::AbstractCartesianIndex) = @inbounds(outer(x)[inner(x)[i]]) + @propagate_inbounds function Base.getindex(ind::TridiagonalIndex, i::Int) @boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i)) offsetu = ind.isrow ? 0 : 1 @@ -328,7 +412,7 @@ end ind.reflocalinds[p][_i] + ind.refcoords[p] - 1 end -@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex) where {N} +@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex{N}) where {N} return _strides2int(offsets(x), strides(x), Tuple(i)) + static(1) end @generated function _strides2int(o::O, s::S, i::I) where {O,S,I} @@ -377,9 +461,7 @@ end end return Expr(:block, Expr(:meta, :inline), :($out)) end -@inline function Base.getindex(x::LinearSubIndex, i::CanonicalInt) - getfield(x, :offset) + getfield(x, :stride) * i -end +@inline Base.getindex(x::LinearSubIndex, i::CanonicalInt) = offset1(x) + stride1(x) * i @propagate_inbounds function Base.getindex(ind::BidiagonalIndex, i::Int) @boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i)) if ind.isup @@ -390,33 +472,30 @@ end convert(Int, floor(ii / 2)) end +const compose = ∘ + """ - combined_index(i1, i2) + compose(outer_index, inner_index) + outer_index ∘ inner_index Given two subtypes of `ArrayIndex`, combines a new instance that when indexed is equivalent -to `i1[i2[i]]`. Default behavior produces a `CombinedIndex`, but more `i1` and `i2` may be +to `i1[i2[i]]`. Default behavior produces a `ComposedIndex`, but more `i1` and `i2` may be consolidated into a more efficient representation. """ -combined_index(::Nothing, y::ArrayIndex) = y -combined_index(x::ArrayIndex, ::Nothing) = x -combined_index(::Nothing, ::Nothing) = nothing -combined_index(x::ArrayIndex, y::ArrayIndex) = CombinedIndex(x, y) -@inline function combined_index(x::CombinedIndex, y::ArrayIndex) - CombinedIndex(getfield(x, :i1), combined_index(getfield(x, :i2), y)) -end -@inline function combined_index(x::ArrayIndex, y::CombinedIndex) - CombinedIndex(combined_index(x, getfield(y, :i1)), getfield(y, :i2)) -end -@inline function combined_index(x::CombinedIndex, y::CombinedIndex) - CombinedIndex( - getfield(x, :i1), - CombinedIndex(combined_index(getfield(x, :i2), getfield(y, :i1)), getfield(y, :i2)) - ) +compose(x::ArrayIndex, y::ArrayIndex) = _compose(x, y) +_compose(x, y::IdentityIndex) = x +_compose(x, y) = ComposedIndex(x, y) +_compose(x, y::ComposedIndex) = ComposedIndex(compose(x, outer(y)), inner(y)) + +compose(::IdentityIndex, y::ArrayIndex) = y +@inline compose(x::ComposedIndex, y::ArrayIndex) = ComposedIndex(outer(x), compose(inner(x), y)) +@inline function compose(x::ComposedIndex, y::ComposedIndex) + ComposedIndex(outer(x), ComposedIndex(compose(inner(x), outer(y)), inner(y))) end -@inline function combined_index(x::StrideIndex, y::SubIndex{N,I}) where {N,I} +@inline function compose(x::StrideIndex, y::SubIndex{N,I}) where {N,I} _combined_sub_strides(stride_preserving_index(I), x, y) end -_combined_sub_strides(::False, x::StrideIndex, i::SubIndex) = CombinedIndex(x, i) +_combined_sub_strides(::False, x::StrideIndex, i::SubIndex) = ComposedIndex(x, i) @inline function _combined_sub_strides(::True, x::StrideIndex{N,R,C}, i::SubIndex{Ns,I}) where {N,R,C,Ns,I<:Tuple{Vararg{Any,N}}} c = static(C) if _get_tuple(I, c) <: AbstractUnitRange @@ -435,13 +514,13 @@ _combined_sub_strides(::False, x::StrideIndex, i::SubIndex) = CombinedIndex(x, i eachop(getmul, pdims, map(maybe_static_step, inds), s), permute(o, pdims) ) - return combined_index(OffsetIndex(reduce_tup(+, map(*, map(_diff, inds, o), s))), out) + return compose(OffsetIndex(reduce_tup(+, map(*, map(_diff, inds, o), s))), out) end @inline _diff(::Base.Slice, ::Any) = Zero() @inline _diff(x::AbstractRange, o) = static_first(x) - o @inline _diff(x::Integer, o) = x - o -@inline function combined_index(x::StrideIndex{1,R,C}, ::PermutedIndex{2,(2,1),(2,)}) where {R,C} +@inline function compose(x::StrideIndex{1,R,C}, ::PermutedIndex{2,(2,1),(2,)}) where {R,C} if C === nothing c2 = nothing elseif C === 1 @@ -453,37 +532,50 @@ end return StrideIndex{2,(2,1),c2}((s, s), (static(1), offset1(x))) end - -@inline function combined_index(x::StrideIndex{N,R,C}, ::PermutedIndex{N,perm,iperm}) where {N,R,C,perm,iperm} +@inline function compose(x::StrideIndex{N,R,C}, ::PermutedIndex{N,I1,I2}) where {N,R,C,I1,I2} if C === nothing || C === -1 c2 = C else - c2 = getfield(iperm, C) + c2 = getfield(I2, C) end - return StrideIndex{N,permute(R, Val(perm)),c2}( - permute(strides(x), Val(perm)), - permute(offsets(x), Val(perm)), + return StrideIndex{N,permute(R, Val(I1)),c2}( + permute(strides(x), Val(I1)), + permute(offsets(x), Val(I1)), ) end -@inline function combined_index(::PermutedIndex{<:Any,I11,I12},::PermutedIndex{<:Any,I21,I22}) where {I11,I12,I21,I22} +@inline function compose(x::PermutedIndex{<:Any,I11,I12},::PermutedIndex{<:Any,I21,I22}) where {I11,I12,I21,I22} PermutedIndex(permute(static(I11), static(I21)), permute(static(I12), static(I22))) end -@inline function combined_index(x::LinearSubIndex, i::LinearSubIndex) - s = getfield(x, :stride) - LinearSubIndex( - getfield(x, :offset) + getfield(i, :offset) * s, - getfield(i, :stride) * s - ) +@inline function compose(x::LinearSubIndex, y::LinearSubIndex) + LinearSubIndex(offset1(x) + offset1(y) * stride1(x), stride1(y) * stride1(x)) end -combined_index(::OffsetIndex{StaticInt{0}}, y::StrideIndex) = y +compose(::OffsetIndex{StaticInt{0}}, y::StrideIndex) = y +compose(x::ArrayIndex, y::CartesianIndices) = ComposedIndex(x, y) -combined_index(x::ArrayIndex, y::CartesianIndices) = CombinedIndex(x, y) -combined_index(x::CartesianIndices, y::ArrayIndex) = CombinedIndex(x, y) +function compose(x::AbstractArray{T,N}, ::PermutedIndex{N,I1,I2}) where {T,N,I1,I2} + PermutedDimsArray{T,N,I1,I2,typeof(x)}(x) +end +# TODO call to more direct constructors so that we don't repeat checks already performed +# when constructin SubIndex +compose(x::AbstractArray, y::SubIndex) = SubArray(x, getfield(y, :indices)) +compose(x::AbstractArray, y::ComposedIndex) = compose(compose(x, outer(y)), inner(y)) +compose(x::AbstractArray, y::ArrayIndex) = ComposedIndex(x, y) -## ArrayIndex constructors -@inline _to_cartesian(a) = CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a)))) -@inline function _to_linear(a) - N = ndims(a) - StrideIndex{N,ntuple(+, Val(N)),nothing}(size_to_strides(size(a), static(1)), offsets(a)) +## show(::IO, ::MIME, ::ArrayIndex) +function Base.show(io::IO, ::MIME"text/plain", @nospecialize(x::StrideIndex)) + print(io, "StrideIndex{$(ndims(x)), $(known(stride_rank(x))), $(known(contiguous_axis(x)))}($(strides(x)), $(offsets(x)))") +end +function Base.show(io::IO, ::MIME"text/plain", @nospecialize(x::SubIndex)) + print(io, "SubIndex{$(ndims(x))}($(x.indices))") end +function Base.show(io::IO, ::MIME"text/plain", @nospecialize(x::LinearSubIndex)) + print(io, "LinearSubIndex(offset=$(offset1(x)),stride=$(stride1(x)))") +end +function Base.show(io::IO, m::MIME"text/plain", @nospecialize(x::ComposedIndex)) + show(io, m, outer(x)) + print(io, " ∘ ") + show(io, m, inner(x)) + #print(io, "$(outer(x)) ∘ $(inner(x))") +end + diff --git a/src/dimensions.jl b/src/dimensions.jl index a4df908c9..425436741 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -23,18 +23,43 @@ end is_increasing(::Tuple{StaticInt{X}}) where {X} = True() #= - ndims_index(::Type{I})::StaticInt + index_ndims(::Type{I})::StaticInt -The number of dimensions an instance of `I` maps to when indexing an instance of `A`. +The number of dimensions an instance of `I` maps to when used as an index. =# -ndims_index(i) = ndims_index(typeof(i)) -ndims_index(::Type{I}) where {I} = static(1) -ndims_index(::Type{I}) where {N,I<:AbstractCartesianIndex{N}} = static(N) -ndims_index(::Type{I}) where {I<:AbstractArray} = ndims_index(eltype(I)) -ndims_index(::Type{I}) where {I<:AbstractArray{Bool}} = static(ndims(I)) -ndims_index(::Type{I}) where {N,I<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = static(N) -_ndims_index(::Type{I}, i::StaticInt) where {I} = ndims_index(_get_tuple(I, i)) -ndims_index(::Type{I}) where {N,I<:Tuple{Vararg{Any,N}}} = eachop(_ndims_index, nstatic(Val(N)), I) +index_ndims(x) = index_ndims(typeof(x)) +index_ndims(x::Type{I}) where {I} = static(1) +index_ndims(x::Type{I}) where {N,I<:Base.AbstractCartesianIndex{N}} = static(N) +index_ndims(x::Type{I}) where {I<:AbstractArray} = index_ndims(eltype(I)) +index_ndims(x::Type{I}) where {I<:AbstractArray{Bool}} = static(ndims(I)) +index_ndims(x::Type{I}) where {N,I<:Base.LogicalIndex{Int,<:AbstractArray{Bool,N}}} = static(1) +index_ndims(x::Type{I}) where {N,I<:Base.LogicalIndex{CartesianIndex{N},<:AbstractArray{Bool,N}}} = static(N) +index_ndims(x::Type{<:PermutedIndex{N,I1,I2}}) where {N,I1,I2} = static(length(I2)) +index_ndims(x::Type{<:SubIndex{N,I}}) where {N,I} = index_dimsum(I) +index_ndims(x::Type{<:IdentityIndex{N}}) where {N} = static(N) +index_ndims(x::Type{<:StrideIndex}) = static(1) +index_ndims(x::Type{<:LinearSubIndex}) = static(1) +index_ndims(x::Type{<:ComposedIndex{N,O,I}}) where {N,O,I} = index_ndims(O) +_index_ndims(::Type{I}, i::StaticInt) where {I} = index_ndims(_get_tuple(I, i)) +index_ndims(::Type{I}) where {N,I<:Tuple{Vararg{Any,N}}} = eachop(_index_ndims, nstatic(Val(N)), I) + +# index_dimsum(x)::StaticInt - returns the total number of dimension that `x` maps to +@inline index_dimsum(x) = sum(index_ndims(x)) + +_mapsub(f, ::Tuple{}) = () +@inline function _mapsub(f, x::Tuple{I,Vararg{Any}}) where {I} + if (I<:AbstractCartesianIndex) || (I<:Integer) + return _mapsub(f, tail(x)) + else + return (f(getfield(x, 1)), _mapsub(f, tail(x))...) + end +end +_mapsub(f, ::Type{Tuple{}}) = () +_mapsub(f, ::Type{I}) where {I} = __mapsub(f, I, _to_sub_dims(I)) +@inline function __mapsub(f, ::Type{I}, x::Tuple{StaticInt{N},Vararg{Any}}) where {I,N} + (f(_get_tuple(I, static(N))), __mapsub(f, I, tail(x))...) +end +__mapsub(f, ::Type{I}, ::Tuple{}) where {I} = () """ from_parent_dims(::Type{T}) -> Tuple{Vararg{Union{Int,StaticInt}}} diff --git a/src/indexing.jl b/src/indexing.jl index 8bb0f339c..9f3995936 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -39,6 +39,9 @@ function canonical_convert(x::AbstractUnitRange) return OptionallyStaticUnitRange(static_first(x), static_last(x)) end +is_linear_indexing(A, args::Tuple{Arg}) where {Arg} = index_ndims(Arg) < 2 +is_linear_indexing(A, args::Tuple{Arg,Vararg{Any}}) where {Arg} = false + """ to_indices(A, inds::Tuple) -> Tuple @@ -53,7 +56,7 @@ on a call to [`is_canonical`](@ref), then they each are checked at the axis leve to_indices(A, lazy_axes(A), axes(getfield(inds, 1))) end @propagate_inbounds function _to_indices(::True, A, inds) - if isone(sum(ndims_index(inds))) + if isone(sum(index_ndims(inds))) @boundscheck if !checkindex(Bool, eachindex(IndexLinear(), A), getfield(inds, 1)) throw(BoundsError(A, inds)) end @@ -72,7 +75,7 @@ end end @propagate_inbounds function _to_indices(::False, A, inds) - if isone(sum(ndims_index(inds))) + if isone(sum(index_ndims(inds))) return (to_index(LazyAxis{:}(A), getfield(inds, 1)),) else return to_indices(A, lazy_axes(A), inds) @@ -82,7 +85,7 @@ end to_indices(A, axs, (Tuple(getfield(inds, 1))..., tail(inds)...)) end @propagate_inbounds function to_indices(A, axs, inds::Tuple{I,Vararg{Any}}) where {I} - _to_indices(ndims_index(I), A, axs, inds) + _to_indices(index_ndims(I), A, axs, inds) end @propagate_inbounds function _to_indices(::StaticInt{1}, A, axs, inds) @@ -244,7 +247,7 @@ indices calling [`to_axis`](@ref). @inline function to_axes(A, inds::Tuple) if ndims(A) === 1 return (to_axis(axes(A, 1), first(inds)),) - elseif isone(sum(ndims_index(inds))) + elseif isone(sum(index_ndims(inds))) return (to_axis(eachindex(IndexLinear(), A), first(inds)),) else return to_axes(A, axes(A), inds) @@ -252,7 +255,7 @@ indices calling [`to_axis`](@ref). end # drop this dimension to_axes(A, a::Tuple, i::Tuple{<:Integer,Vararg{Any}}) = to_axes(A, tail(a), tail(i)) -to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(ndims_index(I), A, a, i) +to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(index_ndims(I), A, a, i) function _to_axes(::StaticInt{1}, A, axs::Tuple, inds::Tuple) return (to_axis(first(axs), first(inds)), to_axes(A, tail(axs), tail(inds))...) end @@ -314,50 +317,25 @@ end @propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i) @propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i) -## unsafe_getindex ## +## unsafe_getindex function unsafe_getindex(a::A) where {A} parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A,))) return unsafe_getindex(parent(a)) end unsafe_getindex(A::Array) = Base.arrayref(false, A, 1) - -unsafe_getindex(A::LinearIndices, i::CanonicalInt) = @inbounds(A[Int(i)]) -@inline function unsafe_getindex(A::LinearIndices, i::CanonicalInt, ii::Vararg{CanonicalInt}) - Int(@inbounds(_to_linear(A)[NDIndex(i, ii...)])) -end - -unsafe_getindex(A::CartesianIndices, i::AbstractCartesianIndex) = @inbounds(A[CartesianIndex(i)]) -unsafe_getindex(A::CartesianIndices, i::CanonicalInt) = @inbounds(A[CartesianIndex(i)]) -unsafe_getindex(A::CartesianIndices, i::CanonicalInt, ii::Vararg{CanonicalInt}) = CartesianIndex(i, ii...) - -@inline unsafe_getindex(a, i::Vararg{Any}) = _unsafe_getindex(layout(a, i)) -@inline function _unsafe_getindex(x::Layouted{S}) where {S<:AccessElement} - lyt = instantiate(x) - return getfield(lyt, :f)(@inbounds(parent(lyt)[getfield(lyt, :indices)])) +@inline function unsafe_getindex(A, inds::Vararg{CanonicalInt,N}) where {N} + buf, lyt = layout(A, static(N)) + @inbounds(buf[lyt[inds...]]) end -@generated function _unsafe_getindex(x::Layouted{S}) where {N,S<:AccessIndices{N}} - quote - Compat.@inline() - lyt = instantiate(x) - buf = parent(lyt) - I = getfield(lyt, :indices) - dest = similar(parent(x), to_axes(parent(x), I)) - D = eachindex(dest) - Dy = iterate(D) - @inbounds Base.Cartesian.@nloops $N j d -> I[d] begin - # This condition is never hit, but at the moment - # the optimizer is not clever enough to split the union without it - Dy === nothing && return dest - (idx, state) = Dy - dest[idx] = buf[NDIndex(Base.Cartesian.@ntuple($N, j))] - Dy = iterate(D, state) - end - return dest - end +@inline function unsafe_getindex(A, inds::Vararg{Any}) + buf, lyt = layout(A, index_dimsum(inds)) + return relayout(getlayout(device(buf), buf, lyt, inds), A, inds) end +## CartesianIndices/LinearIndices _ints2range(x::Integer) = x:x _ints2range(x::AbstractRange) = x +unsafe_getindex(A::CartesianIndices, i::Vararg{CanonicalInt,N}) where {N} = @inbounds(A[CartesianIndex(i)]) @inline function unsafe_getindex(A::CartesianIndices{N}, inds::Vararg{Any}) where {N} if (length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False() return Base._getindex(IndexStyle(A), A, inds...) @@ -365,8 +343,16 @@ _ints2range(x::AbstractRange) = x return CartesianIndices(to_axes(A, _ints2range.(inds))) end end + +function unsafe_getindex(A::LinearIndices, i::Vararg{CanonicalInt,N}) where {N} + if N === 1 + return Int(@inbounds(i[1])) + else + return Int(@inbounds(_to_linear(A)[NDIndex(i...)])) + end +end @inline function unsafe_getindex(A::LinearIndices{N}, inds::Vararg{Any}) where {N} - if isone(sum(ndims_index(inds))) + if isone(sum(index_ndims(inds))) return @inbounds(eachindex(A)[first(inds)]) elseif stride_preserving_index(typeof(inds)) === True() return LinearIndices(to_axes(A, _ints2range.(inds))) @@ -400,19 +386,158 @@ function unsafe_setindex!(a::A, v) where {A} return unsafe_setindex!(parent(a), v) end unsafe_setindex!(A::Array{T}, v) where {T} = Base.arrayset(false, A, convert(T, v)::T, 1) -@inline unsafe_setindex!(a, v, i::Vararg{Any}) = _unsafe_setindex!(layout(a, i), v) -@inline function _unsafe_setindex!(x::Layouted{S}, v) where {S<:AccessElement} - lyt = instantiate(x) - @inbounds(Base.setindex!(parent(lyt), getfield(lyt, :f)(v), getfield(lyt, :indices))) +@inline function unsafe_setindex!(A, v, i::Vararg{CanonicalInt,N}) where {N} + buf, lyt = layout(A, static(N)) + setlayout!(device(buf), buf, lyt, v, i) +end + +## layouts - TODO finalize `layout(x, access)` design +@inline layout(x, ::StaticInt{N}) where {N} = _layout(x, buffer(x), ArrayIndex{N}(x)) +@inline function _layout(x::X, y::Y, index::ArrayIndex{N}) where {X,Y,N} + b, i = layout(y, index_dimsum(index)) + return b, compose(i, index) +end +# end recursion b/c no new buffer +_layout(x::X, y::X, i::ArrayIndex) where {X} = x, i +# no new buffer and unkown index transformation, s +_layout(x::X, y::X, ::UnkownIndex{N}) where {X,N} = x, IdentityIndex{N}() +# new buffer, but don't know how to transform indices properly +_layout(x::X, y::Y, ::UnkownIndex{N}) where {X,Y,N} = x, IdentityIndex{N}() + +""" + relayout_constructor(::Type{T}) -> Function + +Returns a function that construct a new layout for wrapping sub indices of an array. +This method is called in the context of the indexing arguments and at the array's top level. +Therefore, in the call `relayout_constructor(T)(A, inds) -> layout` the array `A` may be a wrapper +around instance of `T`. + +It is assumed that the return of this function can appropriately recompose an layouted array +via `buffer ∘ layout` +""" +relayout_constructor(::Type{T}) where {T} = nothing + +@inline function _relayout_constructors(::Type{T}) where {T} + if parent_type(T) <: T + return (relayout_constructor(T),) + else + return (relayout_constructor(T), _relayout_constructors(parent_type(T))...) + end +end + +""" + relayout(dest, A, inds) + +Derives the function from [`relayout_constructor`](@ref) for each nested parent type of `A`, +which are then used to construct a layout given the arguments `A` and `inds`, and recompose +`dest`. If `relayout_constructor` returns `nothing` then it is not used to in the +recomposing stage. + + +For example, if `A` had the parent type `B` and `B` had the parent type `C` the following +steps would occure to to derive new layouts: +``` + A--relayout_constructor(A)--> rc_a--> rc_a(A, inds)--> lyt_a + \ + parent_type(A) -> B--relayout_constructor(B)--> rc_b--> rc_b(A, inds)--> lyt_b + \ + parent_type(B)--> C --relayout_constructor(C)--> nothing +``` +These results would finally be called as `dest ∘ lyt_b ∘ lyt_a` +""" +relayout(dest, A, inds) = _relayout(_relayout_constructors(typeof(A)), dest, A, inds) +@generated function _relayout(fxns::F, B, A, inds) where {F} + N = length(F.parameters) + bexpr = :B + for i in N:-1:1 + if !(F.parameters[i] <: Nothing) + bexpr = :(compose($bexpr, getfield(fxns, $i)(A, inds))) + end + end + Expr(:block, Expr(:meta, :inline), bexpr) +end + +@generated getlayout(::CPUTuple, buf::B, lyt::L, inds::I) where {B,L,I} = _tup_lyt(B, L, I) +@generated getlayout(::CPUIndex, buf::B, lyt::L, inds::I) where {B,L,I} = _idx_lyt(B, L, I) +@generated getlayout(::CPUPointer, buf::B, lyt::L, inds::I) where {B,L,I} = _ptr_lyt(B, L, I) + +## CPUTuple +function _tup_lyt(B::Type, L::Type, I::Type) + N = length(I.parameters) + s = Vector{Int}(undef, N) + o = Vector{Int}(undef, N) + static_check = true + @inbounds for i in 1:N + s_i = ArrayInterface.known_length(I.parameters[i]) + if s_i === nothing + static_check = false + break + else + s[i] = s_i + end + o_i = ArrayInterface.known_offset1(I.parameters[i]) + if o_i === nothing + static_check = false + break + else + o[i] = o_i + end + end + if static_check + t = Expr(:tuple) + foreach(i->push!(t.args, :(buf[lyt[$(i...)]])), Iterators.product(map((o_i, s_i) -> o_i:(o_i + s_i -1), o, s)...)) + return t + else # don't know size and offsets so we can't compose tuple statically + return _idx_lyt(B, L, I) + end +end + +## CPUPointer +function _ptr_lyt(B::Type, L::Type, I::Type) + if known(index_ndims(I)) === 1 + _idx_lyt(B, L, I) + else # cant use pointer b/c layout doesn't converge to an integer + _idx_lyt(B, L, I) + end +end + +## CPUIndex +function _idx_lyt(B::Type, L::Type, I::Type) + T = eltype(B) + N = length(I.parameters) + quote + Compat.@inline() + dest = Array{$T}(undef, _mapsub(length, inds)) + D = eachindex(dest) + Dy = iterate(D) + @inbounds Base.Cartesian.@nloops $N j d -> inds[d] begin + # This condition is never hit, but at the moment + # the optimizer is not clever enough to split the union without it + Dy === nothing && return dest + (idx, state) = Dy + dest[idx] = buf[lyt[NDIndex(Base.Cartesian.@ntuple($N, j))]] + Dy = iterate(D, state) + end + return dest + end +end + +function unsafe_setindex!(A, v, inds::Vararg{Any,N}) where {N} + buf, lyt = layout(A, index_dimsum(inds)) + return setlayout!(device(buf), buf, lyt, v, inds) end -@generated function _unsafe_setindex!(x::Layouted{S}, v) where {N,S<:AccessIndices{N}} + +function setlayout!(::AbstractDevice, buf::B, lyt::L, v, inds::Tuple{Vararg{CanonicalInt}}) where {B,L} + @inbounds(Base.setindex!(buf, v, lyt[inds...])) +end +@generated function setlayout!(::AbstractDevice, buf::B, lyt::L, v, inds::Tuple{Vararg{Any,N}}) where {B,L,N} + _setlayout!(N) +end + +function _setlayout!(N::Int) quote - lyt = instantiate(x) - buf = parent(lyt) - I = getfield(lyt, :indices) - f = getfield(lyt, :f) x′ = Base.unalias(buf, v) - Base.Cartesian.@nexprs $N d -> (I_d = Base.unalias(buf, I[d])) + Base.Cartesian.@nexprs $N d -> (I_d = Base.unalias(buf, inds[d])) idxlens = Base.Cartesian.@ncall $N Base.index_lengths I Base.Cartesian.@ncall $N Base.setindex_shape_check x′ (d -> idxlens[d]) Xy = iterate(x′) @@ -421,7 +546,7 @@ end # the optimizer that it does not need to emit error paths Xy === nothing && break (val, state) = Xy - buf[NDIndex(Base.Cartesian.@ntuple($N, i))] = f(val) + buf[lyt[NDIndex(Base.Cartesian.@ntuple($N, i))]] = val Xy = iterate(x′, state) end end diff --git a/src/size.jl b/src/size.jl index 37b7c6d8a..821063db9 100644 --- a/src/size.jl +++ b/src/size.jl @@ -23,8 +23,10 @@ function size(a::A) where {A} end end -size(x::SubArray) = eachop(_sub_size, to_parent_dims(x), x.indices) -_sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = static_length(getfield(x, dim)) +size(x::SubArray) = _subsize(x.indices) +@inline _subsize(x::Tuple) = _mapsub(static_length, x) +Base.size(x::SubIndex) = size(x) +size(x::SubIndex) = _subsize(getfield(x, :indices)) @inline size(B::VecAdjTrans) = (One(), length(parent(B))) @inline size(B::MatAdjTrans) = permute(size(parent(B)), (static(2), static(1))) @inline size(B::PermutedDimsArray{T,N,I}) where {T,N,I} = permute(size(parent(B)), static(I)) diff --git a/src/stridelayout.jl b/src/stridelayout.jl index 6a7ba5869..c79c54638 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -35,7 +35,6 @@ function known_offsets(::Type{T}, dim::Integer) where {T} return known_offsets(T)[dim] end end - known_offsets(x) = known_offsets(typeof(x)) function known_offsets(::Type{T}) where {T} return eachop(_known_offsets, nstatic(Val(ndims(T))), axes_types(T)) diff --git a/test/array_index.jl b/test/array_index.jl index 03792ce6c..6e36479a0 100644 --- a/test/array_index.jl +++ b/test/array_index.jl @@ -1,13 +1,13 @@ function test_layout(x) @testset "$x" begin - linear_lyt = ArrayInterface.instantiate(ArrayInterface.layout(x, ArrayInterface.AccessElement{1}())) + linbuf, linlyt = ArrayInterface.layout(x, static(1)) for i in eachindex(IndexLinear(), x) - @test linear_lyt[i] == x[i] + @test linbuf[linlyt[i]] == x[i] end - cartesian_lyt = ArrayInterface.instantiate(ArrayInterface.layout(x, ArrayInterface.AccessElement{ndims(x)}())) + carbuf, carlyt = ArrayInterface.layout(x, static(ndims(x))) for i in eachindex(IndexCartesian(), x) - @test cartesian_lyt[i] == x[i] + @test carbuf[carlyt[i]] == x[i] end end return nothing diff --git a/test/dimensions.jl b/test/dimensions.jl index 838774848..0233c26cf 100644 --- a/test/dimensions.jl +++ b/test/dimensions.jl @@ -12,9 +12,7 @@ end ArrayInterface.parent_type(::Type{T}) where {P,T<:NamedDimsWrapper{<:Any,<:Any,<:Any,P}} = P ArrayInterface.dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L}} = static(L) Base.parent(x::NamedDimsWrapper) = x.parent -function ArrayInterface.layout(x::NamedDimsWrapper, s::ArrayInterface.AccessElement) - ArrayInterface.Layouted{typeof(s)}(parent(x), nothing) -end +ArrayInterface.ArrayIndex{N}(x::NamedDimsWrapper) where {N} = ArrayInterface.IdentityIndex{N}() @testset "dimension permutations" begin a = ones(2, 2, 2) diff --git a/test/indexing.jl b/test/indexing.jl index 93dae57fc..fbd799b7c 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -9,12 +9,12 @@ @test @inferred(ArrayInterface.canonicalize(Int32(2):Int32(1):Int32(3))) isa ArrayInterface.OptionallyStaticStepRange{Int,Int,Int} end -@testset "ndims_index" begin - @test @inferred(ArrayInterface.ndims_index((1, CartesianIndex(1,2)))) === static((1, 2)) - @test @inferred(ArrayInterface.ndims_index((1, [CartesianIndex(1,2), CartesianIndex(1,3)]))) === static((1, 2)) - @test @inferred(ArrayInterface.ndims_index((1, CartesianIndex((2,2))))) === static((1, 2)) - @test @inferred(ArrayInterface.ndims_index((CartesianIndex((2,2)), :, :))) === static((2, 1, 1)) - @test @inferred(ArrayInterface.ndims_index(Vector{Int})) === static(1) +@testset "index_ndims" begin + @test @inferred(ArrayInterface.index_ndims((1, CartesianIndex(1,2)))) === static((1, 2)) + @test @inferred(ArrayInterface.index_ndims((1, [CartesianIndex(1,2), CartesianIndex(1,3)]))) === static((1, 2)) + @test @inferred(ArrayInterface.index_ndims((1, CartesianIndex((2,2))))) === static((1, 2)) + @test @inferred(ArrayInterface.index_ndims((CartesianIndex((2,2)), :, :))) === static((2, 1, 1)) + @test @inferred(ArrayInterface.index_ndims(Vector{Int})) === static(1) end @testset "to_index" begin diff --git a/test/runtests.jl b/test/runtests.jl index 3e3536db8..e12324b6f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -254,9 +254,7 @@ end ArrayInterface.parent_type(::Type{<:Wrapper{T,N,P}}) where {T,N,P} = P Base.parent(x::Wrapper) = x.parent ArrayInterface.device(::Type{T}) where {T<:Wrapper} = ArrayInterface.device(parent_type(T)) -function ArrayInterface.layout(x::Wrapper, s::ArrayInterface.AccessElement) - ArrayInterface.Layouted{typeof(s)}(parent(x), nothing) -end +ArrayInterface.ArrayIndex{N}(x::Wrapper) where {N} = ArrayInterface.IdentityIndex{N}() struct DenseWrapper{T,N,P<:AbstractArray{T,N}} <: DenseArray{T,N} end ArrayInterface.parent_type(::Type{DenseWrapper{T,N,P}}) where {T,N,P} = P From aa14e0e194d27dde7003219e7329f61c924a4060 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 4 Oct 2021 11:14:56 -0400 Subject: [PATCH 12/16] Fix docstring --- src/indexing.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/indexing.jl b/src/indexing.jl index 9f3995936..101619f10 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -436,13 +436,13 @@ recomposing stage. For example, if `A` had the parent type `B` and `B` had the parent type `C` the following steps would occure to to derive new layouts: -``` - A--relayout_constructor(A)--> rc_a--> rc_a(A, inds)--> lyt_a - \ - parent_type(A) -> B--relayout_constructor(B)--> rc_b--> rc_b(A, inds)--> lyt_b - \ - parent_type(B)--> C --relayout_constructor(C)--> nothing -``` + + A--relayout_constructor(A)--> rc_a--> rc_a(A, inds)--> lyt_a + \ + parent_type(A) -> B--relayout_constructor(B)--> rc_b--> rc_b(A, inds)--> lyt_b + \ + parent_type(B)--> C --relayout_constructor(C)--> nothing + These results would finally be called as `dest ∘ lyt_b ∘ lyt_a` """ relayout(dest, A, inds) = _relayout(_relayout_constructors(typeof(A)), dest, A, inds) From af9276676cd527c60409c02fce001de9dc1041a1 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 4 Oct 2021 11:37:18 -0400 Subject: [PATCH 13/16] Delete bad escape sequences --- src/indexing.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/indexing.jl b/src/indexing.jl index 101619f10..62214386f 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -433,15 +433,11 @@ which are then used to construct a layout given the arguments `A` and `inds`, an `dest`. If `relayout_constructor` returns `nothing` then it is not used to in the recomposing stage. - -For example, if `A` had the parent type `B` and `B` had the parent type `C` the following -steps would occure to to derive new layouts: - A--relayout_constructor(A)--> rc_a--> rc_a(A, inds)--> lyt_a - \ - parent_type(A) -> B--relayout_constructor(B)--> rc_b--> rc_b(A, inds)--> lyt_b - \ - parent_type(B)--> C --relayout_constructor(C)--> nothing + | + parent_type(A) -> B--relayout_constructor(B)--> rc_b--> rc_b(A, inds)--> lyt_b + | + parent_type(B)--> C --relayout_constructor(C)--> nothing These results would finally be called as `dest ∘ lyt_b ∘ lyt_a` """ From b42476648283aef3ed26f694c5f32c115d773eb3 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 4 Oct 2021 18:50:53 -0400 Subject: [PATCH 14/16] Remove unnecessary custom bounds checking --- src/indexing.jl | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/indexing.jl b/src/indexing.jl index 62214386f..83ec3a1dd 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -152,7 +152,7 @@ end return LogicalIndex{Int}(arg) end @propagate_inbounds function to_index(::IndexLinear, x, arg::AbstractArray{<:AbstractCartesianIndex}) - @boundscheck _multi_check_index(axes(x), arg) || throw(BoundsError(x, arg)) + @boundscheck Base.checkindex(Bool, axes(x), arg) || throw(BoundsError(x, arg)) return arg end @propagate_inbounds function to_index(::IndexLinear, x, arg::LogicalIndex) @@ -181,11 +181,11 @@ to_index(::IndexCartesian, x, arg::Colon) = CartesianIndices(x) to_index(::IndexCartesian, x, arg::CartesianIndices{0}) = arg to_index(::IndexCartesian, x, arg::AbstractCartesianIndex) = arg function to_index(::IndexCartesian, x, arg) - @boundscheck _multi_check_index(axes(x), arg) || throw(BoundsError(x, arg)) + @boundscheck Base.checkindex(Bool, axes(x), arg) || throw(BoundsError(x, arg)) return arg end @propagate_inbounds function to_index(::IndexCartesian, x, arg::AbstractArray{<:AbstractCartesianIndex}) - @boundscheck _multi_check_index(axes(x), arg) || throw(BoundsError(x, arg)) + @boundscheck Base.checkindex(Bool, axes(x), arg) || throw(BoundsError(x, arg)) return arg end @propagate_inbounds function to_index(::IndexCartesian, x, arg::AbstractArray{Bool}) @@ -193,14 +193,6 @@ end return LogicalIndex(arg) end -function _multi_check_index(axs::Tuple, arg::AbstractArray{T}) where {T<:AbstractCartesianIndex} - b = true - for i in arg - b &= Base.checkbounds_indices(Bool, axs, (i,)) - end - return b -end - @propagate_inbounds function to_index(::IndexCartesian, x, arg::Union{Array{Bool}, BitArray}) @boundscheck checkbounds(x, arg) return LogicalIndex{Int}(arg) From 9f631add90b027067e67280adbe00405cb188dfc Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Tue, 5 Oct 2021 10:41:02 -0400 Subject: [PATCH 15/16] Update docs --- docs/src/api.md | 9 ++++++++ src/array_index.jl | 36 +++++------------------------ src/indexing.jl | 55 +++++++++++++++++++-------------------------- test/array_index.jl | 8 +++---- 4 files changed, 42 insertions(+), 66 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 5cecb20a3..16afe058e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -44,6 +44,7 @@ ArrayInterface.axes_types ArrayInterface.broadcast_axis ArrayInterface.buffer ArrayInterface.canonicalize +ArrayInterface.compose ArrayInterface.deleteat ArrayInterface.dense_dims ArrayInterface.findstructralnz @@ -58,6 +59,8 @@ ArrayInterface.offset1 ArrayInterface.offsets ArrayInterface.parent_type ArrayInterface.reduce_tup +ArrayInterface.relayout +ArrayInterface.relayout_constructor ArrayInterface.restructure ArrayInterface.safevec ArrayInterface.setindex! @@ -78,9 +81,15 @@ ArrayInterface.zeromatrix ```@docs ArrayInterface.ArrayIndex ArrayInterface.BroadcastAxis +ArrayInterface.ComposedIndex ArrayInterface.LazyAxis +ArrayInterface.LinearSubIndex +ArrayInterface.IdentityIndex ArrayInterface.OptionallyStaticStepRange ArrayInterface.OptionallyStaticUnitRange +ArrayInterface.PermutedIndex +ArrayInterface.SubIndex ArrayInterface.StrideIndex +ArrayInterface.UnkownIndex ``` diff --git a/src/array_index.jl b/src/array_index.jl index 36ea3de8e..012fb912c 100644 --- a/src/array_index.jl +++ b/src/array_index.jl @@ -261,7 +261,12 @@ This default return type when calling `ArrayIndex{N}(x)`. """ struct UnkownIndex{N} <: ArrayIndex{N} end +""" + ComposedIndex(outer, inner) +A subtype of `ArrayIndex` that lazily combines index `outer` and `inner`. Indexing a +`ComposedIndex` whith `i` is equivalent to `outer[inner[i]]`. +""" struct ComposedIndex{N,O,I} <: ArrayIndex{N} outer::O inner::I @@ -431,35 +436,7 @@ end return NDIndex(permute(Tuple(i), Val(I2))) end @inline function Base.getindex(x::SubIndex{N}, i::AbstractCartesianIndex{N}) where {N} - return NDIndex(_reindex(x.indices, Tuple(i))) -end -@generated function _reindex(subinds::S, inds::I) where {S,I} - inds_i = 1 - subinds_i = 1 - NS = known_length(S) - NI = known_length(I) - out = Expr(:tuple) - while inds_i <= NI - subinds_type = S.parameters[subinds_i] - if subinds_type <: Integer - push!(out.args, :(getfield(subinds, $subinds_i))) - subinds_i += 1 - elseif eltype(subinds_type) <: AbstractCartesianIndex - push!(out.args, :(Tuple(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))...)) - inds_i += 1 - subinds_i += 1 - else - push!(out.args, :(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))) - inds_i += 1 - subinds_i += 1 - end - end - if subinds_i <= NS - for i in subinds_i:NS - push!(out.args, :(getfield(subinds, $subinds_i))) - end - end - return Expr(:block, Expr(:meta, :inline), :($out)) + return NDIndex(Base.reindex(getfield(x, :indices), Tuple(i))) end @inline Base.getindex(x::LinearSubIndex, i::CanonicalInt) = offset1(x) + stride1(x) * i @propagate_inbounds function Base.getindex(ind::BidiagonalIndex, i::Int) @@ -575,7 +552,6 @@ function Base.show(io::IO, m::MIME"text/plain", @nospecialize(x::ComposedIndex)) show(io, m, outer(x)) print(io, " ∘ ") show(io, m, inner(x)) - #print(io, "$(outer(x)) ∘ $(inner(x))") end diff --git a/src/indexing.jl b/src/indexing.jl index 83ec3a1dd..7f7e4693f 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -23,7 +23,7 @@ end """ canonicalize(x) -Checks if `x` is in a canonical form for indexing. If `x` is already in a canonical form +Checks if `x` is in canonical form for indexing. If `x` is already in a canonical form then it is returned unchanged. If `x` is not in a canonical form then it is passed to `canonical_convert`. """ @@ -50,12 +50,12 @@ indexing form, and that they are inbounds. Unless all indices in `inds` return ` on a call to [`is_canonical`](@ref), then they each are checked at the axis level with [`to_index`](@ref). """ -@propagate_inbounds to_indices(A, ::Tuple{}) = to_indices(A, lazy_axes(A), ()) -@propagate_inbounds to_indices(A, inds::Tuple) = _to_indices(is_canonical(inds), A, inds) -@propagate_inbounds function to_indices(A, inds::Tuple{LinearIndices}) +@inline @propagate_inbounds to_indices(A, ::Tuple{}) = to_indices(A, lazy_axes(A), ()) +@inline @propagate_inbounds to_indices(A, inds::Tuple) = _to_indices(is_canonical(inds), A, inds) +@inline @propagate_inbounds function to_indices(A, inds::Tuple{LinearIndices}) to_indices(A, lazy_axes(A), axes(getfield(inds, 1))) end -@propagate_inbounds function _to_indices(::True, A, inds) +@inline @propagate_inbounds function _to_indices(::True, A, inds) if isone(sum(index_ndims(inds))) @boundscheck if !checkindex(Bool, eachindex(IndexLinear(), A), getfield(inds, 1)) throw(BoundsError(A, inds)) @@ -74,26 +74,26 @@ end end end -@propagate_inbounds function _to_indices(::False, A, inds) +@inline @propagate_inbounds function _to_indices(::False, A, inds) if isone(sum(index_ndims(inds))) return (to_index(LazyAxis{:}(A), getfield(inds, 1)),) else return to_indices(A, lazy_axes(A), inds) end end -@propagate_inbounds function to_indices(A, axs, inds::Tuple{<:AbstractCartesianIndex,Vararg{Any}}) +@inline @propagate_inbounds function to_indices(A, axs, inds::Tuple{<:AbstractCartesianIndex,Vararg{Any}}) to_indices(A, axs, (Tuple(getfield(inds, 1))..., tail(inds)...)) end -@propagate_inbounds function to_indices(A, axs, inds::Tuple{I,Vararg{Any}}) where {I} +@inline @propagate_inbounds function to_indices(A, axs, inds::Tuple{I,Vararg{Any}}) where {I} _to_indices(index_ndims(I), A, axs, inds) end -@propagate_inbounds function _to_indices(::StaticInt{1}, A, axs, inds) +@inline @propagate_inbounds function _to_indices(::StaticInt{1}, A, axs, inds) (to_index(_maybe_first(axs), getfield(inds, 1)), to_indices(A, _maybe_tail(axs), _maybe_tail(inds))...) end -@propagate_inbounds function _to_indices(::StaticInt{N}, A, axs, inds) where {N} +@inline @propagate_inbounds function _to_indices(::StaticInt{N}, A, axs, inds) where {N} axsfront, axstail = Base.IteratorsMD.split(axs, Val(N)) if IndexStyle(A) === IndexLinear() index = to_index(LinearIndices(axsfront), getfield(inds, 1)) @@ -103,14 +103,14 @@ end return (index, to_indices(A, axstail, _maybe_tail(inds))...) end # When used as indices themselves, CartesianIndices can simply become its tuple of ranges -@propagate_inbounds function to_indices(A, axs, inds::Tuple{CartesianIndices, Vararg{Any}}) +@inline @propagate_inbounds function to_indices(A, axs, inds::Tuple{CartesianIndices, Vararg{Any}}) to_indices(A, axs, (axes(getfield(inds, 1))..., tail(inds)...)) end # but preserve CartesianIndices{0} as they consume a dimension. -@propagate_inbounds function to_indices(A, axs, inds::Tuple{CartesianIndices{0},Vararg{Any}}) +@inline @propagate_inbounds function to_indices(A, axs, inds::Tuple{CartesianIndices{0},Vararg{Any}}) (getfield(inds, 1), to_indices(A, _maybe_tail(axs), tail(inds))...) end -@propagate_inbounds function to_indices(A, axs, ::Tuple{}) +@inline @propagate_inbounds function to_indices(A, axs, ::Tuple{}) @boundscheck if length(getfield(axs, 1)) != 1 error("Cannot drop dimension of size $(length(first(axs))).") end @@ -141,7 +141,8 @@ function to_index(s, axis, arg) throw(ArgumentError("invalid index: IndexStyle $s does not support indices of " * "type $(typeof(arg)) for instances of type $(typeof(axis)).")) end -to_index(::IndexLinear, axis, arg::Colon) = indices(axis) # Colons get converted to slices by `indices` +# Colons get converted to slices by `indices` +@inline to_index(::IndexLinear, axis, arg::Colon) = indices(axis) to_index(::IndexLinear, axis, arg::CartesianIndices{0}) = arg to_index(::IndexLinear, axis, arg::CartesianIndices{1}) = axes(arg, 1) @propagate_inbounds function to_index(::IndexLinear, axis, arg::AbstractCartesianIndex{1}) @@ -369,13 +370,13 @@ Store the given values at the given key or index within a collection. end end @propagate_inbounds function setindex!(A, val; kwargs...) - return unsafe_setindex!(A, val, to_indices(A, order_named_inds(dimnames(A), values(kwargs)))...) + unsafe_setindex!(A, val, to_indices(A, order_named_inds(dimnames(A), values(kwargs)))...) end ## unsafe_setindex! ## function unsafe_setindex!(a::A, v) where {A} parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v))) - return unsafe_setindex!(parent(a), v) + unsafe_setindex!(parent(a), v) end unsafe_setindex!(A::Array{T}, v) where {T} = Base.arrayset(false, A, convert(T, v)::T, 1) @inline function unsafe_setindex!(A, v, i::Vararg{CanonicalInt,N}) where {N} @@ -445,11 +446,9 @@ relayout(dest, A, inds) = _relayout(_relayout_constructors(typeof(A)), dest, A, Expr(:block, Expr(:meta, :inline), bexpr) end -@generated getlayout(::CPUTuple, buf::B, lyt::L, inds::I) where {B,L,I} = _tup_lyt(B, L, I) -@generated getlayout(::CPUIndex, buf::B, lyt::L, inds::I) where {B,L,I} = _idx_lyt(B, L, I) -@generated getlayout(::CPUPointer, buf::B, lyt::L, inds::I) where {B,L,I} = _ptr_lyt(B, L, I) ## CPUTuple +@generated getlayout(::CPUTuple, buf::B, lyt::L, inds::I) where {B,L,I} = _tup_lyt(B, L, I) function _tup_lyt(B::Type, L::Type, I::Type) N = length(I.parameters) s = Vector{Int}(undef, N) @@ -480,22 +479,14 @@ function _tup_lyt(B::Type, L::Type, I::Type) end end -## CPUPointer -function _ptr_lyt(B::Type, L::Type, I::Type) - if known(index_ndims(I)) === 1 - _idx_lyt(B, L, I) - else # cant use pointer b/c layout doesn't converge to an integer - _idx_lyt(B, L, I) - end -end - ## CPUIndex -function _idx_lyt(B::Type, L::Type, I::Type) - T = eltype(B) +function getlayout(::AbstractDevice, buf::B, lyt::L, inds::I) where {B,L,I} + _idx_lyt(similar(buf, Base.index_shape(inds...)), buf, lyt, inds) +end +@generated function _idx_lyt(dest, src, lyt, inds::I) where {I} N = length(I.parameters) quote Compat.@inline() - dest = Array{$T}(undef, _mapsub(length, inds)) D = eachindex(dest) Dy = iterate(D) @inbounds Base.Cartesian.@nloops $N j d -> inds[d] begin @@ -503,7 +494,7 @@ function _idx_lyt(B::Type, L::Type, I::Type) # the optimizer is not clever enough to split the union without it Dy === nothing && return dest (idx, state) = Dy - dest[idx] = buf[lyt[NDIndex(Base.Cartesian.@ntuple($N, j))]] + dest[idx] = src[lyt[NDIndex(Base.Cartesian.@ntuple($N, j))]] Dy = iterate(D, state) end return dest diff --git a/test/array_index.jl b/test/array_index.jl index 6e36479a0..3a28a8f90 100644 --- a/test/array_index.jl +++ b/test/array_index.jl @@ -1,13 +1,13 @@ function test_layout(x) @testset "$x" begin - linbuf, linlyt = ArrayInterface.layout(x, static(1)) + linbuf, linlyt = @inferred(ArrayInterface.layout(x, static(1))) for i in eachindex(IndexLinear(), x) - @test linbuf[linlyt[i]] == x[i] + @test @inferred(linbuf[linlyt[i]]) == x[i] end - carbuf, carlyt = ArrayInterface.layout(x, static(ndims(x))) + carbuf, carlyt = @inferred(ArrayInterface.layout(x, static(ndims(x)))) for i in eachindex(IndexCartesian(), x) - @test carbuf[carlyt[i]] == x[i] + @test @inferred(carbuf[carlyt[i]]) == x[i] end end return nothing From fc95c21e8fbdee3237cfaa4292199e37b99f63b7 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Fri, 15 Oct 2021 11:48:34 -0400 Subject: [PATCH 16/16] Get rid of stail indexing code --- docs/src/api.md | 3 -- src/indexing.jl | 108 ++++------------------------------------------- test/indexing.jl | 32 -------------- 3 files changed, 9 insertions(+), 134 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 16afe058e..c236f16db 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -66,13 +66,10 @@ ArrayInterface.safevec ArrayInterface.setindex! ArrayInterface.size ArrayInterface.strides -ArrayInterface.to_axes -ArrayInterface.to_axis ArrayInterface.to_dims ArrayInterface.to_index ArrayInterface.to_indices ArrayInterface.to_parent_dims -ArrayInterface.unsafe_reconstruct ArrayInterface.zeromatrix ``` diff --git a/src/indexing.jl b/src/indexing.jl index 7f7e4693f..2e8cec7ab 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -199,99 +199,6 @@ end return LogicalIndex{Int}(arg) end -""" - unsafe_reconstruct(A, data; kwargs...) - -Reconstruct `A` given the values in `data`. New methods using `unsafe_reconstruct` -should only dispatch on `A`. -""" -function unsafe_reconstruct(axis::OneTo, data; kwargs...) - if axis === data - return axis - else - return OneTo(data) - end -end -function unsafe_reconstruct(axis::UnitRange, data; kwargs...) - if axis === data - return axis - else - return UnitRange(first(data), last(data)) - end -end -function unsafe_reconstruct(axis::OptionallyStaticUnitRange, data; kwargs...) - if axis === data - return axis - else - return OptionallyStaticUnitRange(static_first(data), static_last(data)) - end -end -function unsafe_reconstruct(A::AbstractUnitRange, data; kwargs...) - return static_first(data):static_last(data) -end - -""" - to_axes(A, inds) -> Tuple - -Construct new axes given the corresponding `inds` constructed after -`to_indices(A, args) -> inds`. This method iterates through each pair of axes and -indices calling [`to_axis`](@ref). -""" -@inline function to_axes(A, inds::Tuple) - if ndims(A) === 1 - return (to_axis(axes(A, 1), first(inds)),) - elseif isone(sum(index_ndims(inds))) - return (to_axis(eachindex(IndexLinear(), A), first(inds)),) - else - return to_axes(A, axes(A), inds) - end -end -# drop this dimension -to_axes(A, a::Tuple, i::Tuple{<:Integer,Vararg{Any}}) = to_axes(A, tail(a), tail(i)) -to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(index_ndims(I), A, a, i) -function _to_axes(::StaticInt{1}, A, axs::Tuple, inds::Tuple) - return (to_axis(first(axs), first(inds)), to_axes(A, tail(axs), tail(inds))...) -end -@propagate_inbounds function _to_axes(::StaticInt{N}, A, axs::Tuple, inds::Tuple) where {N} - axes_front, axes_tail = Base.IteratorsMD.split(axs, Val(N)) - if IndexStyle(A) === IndexLinear() - axis = to_axis(LinearIndices(axes_front), getfield(inds, 1)) - else - axis = to_axis(CartesianIndices(axes_front), getfield(inds, 1)) - end - return (axis, to_axes(A, axes_tail, tail(inds))...) -end -to_axes(A, ::Tuple{Ax,Vararg{Any}}, ::Tuple{}) where {Ax} = () -to_axes(A, ::Tuple{}, ::Tuple{}) = () - -""" - to_axis(old_axis, index) -> new_axis - -Construct an `new_axis` for a newly constructed array that corresponds to the -previously executed `to_index(old_axis, arg) -> index`. `to_axis` assumes that -`index` has already been confirmed to be in bounds. The underlying indices of -`new_axis` begins at one and extends the length of `index` (i.e., one-based indexing). -""" -@inline function to_axis(axis, inds) - if !can_change_size(axis) && - (known_length(inds) !== nothing && known_length(axis) === known_length(inds)) - return axis - else - return to_axis(IndexStyle(axis), axis, inds) - end -end - -# don't need to check size b/c slice means it's the entire axis -@inline function to_axis(axis, inds::Slice) - if can_change_size(axis) - return copy(axis) - else - return axis - end -end -to_axis(S::IndexLinear, axis, inds) = StaticInt(1):static_length(inds) - - ################ ### getindex ### ################ @@ -300,8 +207,8 @@ to_axis(S::IndexLinear, axis, inds) = StaticInt(1):static_length(inds) Retrieve the value(s) stored at the given key or index within a collection. Creating another instance of `ArrayInterface.getindex` should only be done by overloading `A`. -Changing indexing based on a given argument from `args` should be done through, -[`to_index`](@ref), or [`to_axis`](@ref). +Changing indexing based on a given argument from `args` should be done through +[`to_index`](@ref). """ @propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args)...) @propagate_inbounds function getindex(A; kwargs...) @@ -326,14 +233,17 @@ end end ## CartesianIndices/LinearIndices -_ints2range(x::Integer) = x:x -_ints2range(x::AbstractRange) = x +# TODO replace _ints2range with something that actually indexes each axis +_ints2range(::Tuple{}) = () +@inline _ints2range(x::Tuple{Any,Vararg{Any}}) = (getfield(x, 1), _ints2range(tail(x))...) +@inline _ints2range(x::Tuple{<:Integer,Vararg{Any}}) = _ints2range(tail(x)) + unsafe_getindex(A::CartesianIndices, i::Vararg{CanonicalInt,N}) where {N} = @inbounds(A[CartesianIndex(i)]) @inline function unsafe_getindex(A::CartesianIndices{N}, inds::Vararg{Any}) where {N} if (length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False() return Base._getindex(IndexStyle(A), A, inds...) else - return CartesianIndices(to_axes(A, _ints2range.(inds))) + return CartesianIndices(_ints2range(inds)) end end @@ -348,7 +258,7 @@ end if isone(sum(index_ndims(inds))) return @inbounds(eachindex(A)[first(inds)]) elseif stride_preserving_index(typeof(inds)) === True() - return LinearIndices(to_axes(A, _ints2range.(inds))) + return LinearIndices(_ints2range(inds)) else return Base._getindex(IndexStyle(A), A, inds...) end diff --git a/test/indexing.jl b/test/indexing.jl index fbd799b7c..0b16d4462 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -37,20 +37,6 @@ end @test_throws ArgumentError ArrayInterface.to_index(axis, error) end -@testset "unsafe_reconstruct" begin - one_to = Base.OneTo(10) - opt_ur = StaticInt(1):10 - ur = 1:10 - @test @inferred(ArrayInterface.unsafe_reconstruct(one_to, opt_ur)) === one_to - @test @inferred(ArrayInterface.unsafe_reconstruct(one_to, one_to)) === one_to - - @test @inferred(ArrayInterface.unsafe_reconstruct(opt_ur, opt_ur)) === opt_ur - @test @inferred(ArrayInterface.unsafe_reconstruct(opt_ur, one_to)) === opt_ur - - @test @inferred(ArrayInterface.unsafe_reconstruct(ur, ur)) === ur - @test @inferred(ArrayInterface.unsafe_reconstruct(ur, one_to)) === ur -end - @testset "to_indices" begin a = ones(2,2,1) v = ones(2) @@ -82,24 +68,6 @@ end # FIXME @test_throws ErrorException ArrayInterface.to_indices(ones(2,2,2), (1, 1)) end -@testset "to_axes" begin - A = ones(3, 3) - axis = StaticInt(1):StaticInt(3) - inds = StaticInt(1):StaticInt(2) - multi_inds = [CartesianIndex(1, 1), CartesianIndex(1, 2)] - - @test @inferred(ArrayInterface.to_axes(A, (axis, axis), (inds, inds))) === (inds, inds) - # vector indexing - @test @inferred(ArrayInterface.to_axes(ones(3), (axis,), (inds,))) === (inds,) - # linear indexing - @test @inferred(ArrayInterface.to_axes(A, (axis, axis), (inds,))) === (inds,) - # multidim arg - @test @inferred(ArrayInterface.to_axes(A, (axis, axis), (multi_inds,))) === (static(1):2,) - - @test ArrayInterface.to_axis(axis, axis) === axis - @test ArrayInterface.to_axis(axis, ArrayInterface.indices(axis)) === axis -end - @testset "0-dimensional" begin x = Array{Int,0}(undef) ArrayInterface.setindex!(x, 1)