Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursive indexing of concatenated arrays #133

Merged
merged 6 commits into from
Sep 2, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 60 additions & 48 deletions src/lazyconcat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,29 @@ function ==(a::Vcat{T,1,II}, b::Vcat{T,1,II}) where {T,II}
all(arguments(a) .== arguments(b))
end

@propagate_inbounds @inline function vcat_getindex(f, k::Integer)
@propagate_inbounds @inline vcat_getindex(f, idx::Vararg{Integer}) =
vcat_getindex_recursive(f, idx, f.args...)

@propagate_inbounds @inline function vcat_getindex_recursive(
f, idx::NTuple{1}, A, args...)
k, = idx
T = eltype(f)
κ = k
for A in f.args
n = length(A)
κ ≤ n && return convert(T,A[κ])::T
κ -= n
end
throw(BoundsError(f, k))
n = length(A)
k ≤ n && return convert(T, A[k])::T
vcat_getindex_recursive(f, (k - n, ), args...)
end

@propagate_inbounds @inline function vcat_getindex(f, k::Integer, j::Integer)
@propagate_inbounds @inline function vcat_getindex_recursive(
f, idx::NTuple{2}, A, args...)
k, j = idx
T = eltype(f)
κ = k
for A in f.args
n = size(A,1)
κ ≤ n && return convert(T,A[κ,j])::T
κ -= n
end
throw(BoundsError(f, (k,j)))
n = size(A, 1)
k ≤ n && return convert(T, A[k, j])::T
vcat_getindex_recursive(f, (k - n, j), args...)
end

@inline vcat_getindex_recursive(f, idx) = throw(BoundsError(f, idx))

@propagate_inbounds @inline getindex(f::Vcat{<:Any,1}, k::Integer) = vcat_getindex(f, k)
@propagate_inbounds @inline getindex(f::Vcat{<:Any,2}, k::Integer, j::Integer) = vcat_getindex(f, k, j)
getindex(f::Applied{DefaultArrayApplyStyle,typeof(vcat)}, k::Integer)= vcat_getindex(f, k)
Expand All @@ -77,24 +78,27 @@ getindex(f::Applied{<:Any,typeof(vcat)}, k::Integer, j::Integer)= vcat_getindex(
copy(f::Vcat) = Vcat(map(copy, f.args)...)
map(::typeof(copy), f::Vcat) = Vcat(map.(copy, f.args)...)

@propagate_inbounds @inline function setindex!(f::Vcat{T,1}, v, k::Integer) where T
κ = k
for A in f.args
n = length(A)
κ ≤ n && return setindex!(A, v, κ)
κ -= n
end
throw(BoundsError(f, k))
@propagate_inbounds @inline function vcat_setindex_recursive!(
f::Vcat{T,1} where T, v, idx::NTuple{1}, A, args...)
k, = idx
n = length(A)
k ≤ n && return setindex!(A, v, idx...)
vcat_setindex_recursive!(f, v, (k - n, ), args...)
end

@propagate_inbounds @inline function setindex!(f::Vcat{T,2}, v, k::Integer, j::Integer) where T
κ = k
for A in f.args
n = size(A,1)
κ ≤ n && return setindex!(A, v, κ, j)
κ -= n
end
throw(BoundsError(f, (k,j)))
@propagate_inbounds @inline function vcat_setindex_recursive!(
f::Vcat{T,2} where T, v, idx::NTuple{2}, A, args...)
k, j = idx
n = size(A, 1)
k ≤ n && return setindex!(A, v, idx...)
vcat_setindex_recursive!(f, v, (k - n, j), args...)
end

@inline vcat_setindex_recursive!(f, v, idx) = throw(BoundsError(f, idx))

@propagate_inbounds @inline function setindex!(
f::Vcat{T,N}, v, idx::Vararg{Integer,N}) where {T,N}
vcat_setindex_recursive!(f, v, idx, f.args...)
end

reverse(f::Vcat{<:Any,1}) = Vcat((reverse(itr) for itr in reverse(f.args))...)
Expand All @@ -119,32 +123,40 @@ ndims(::Applied{<:Any,typeof(hcat)}) = 2
size(f::Applied{<:Any,typeof(hcat)}) = (size(f.args[1],1), +(map(a -> size(a,2), f.args)...))
Base.IndexStyle(::Type{<:Hcat}) where T = Base.IndexCartesian()

function hcat_getindex(f, k::Integer, j::Integer)
@inline hcat_getindex(f, k::Integer, j::Integer) =
hcat_getindex_recursive(f, (k, j), f.args...)

@inline function hcat_getindex_recursive(
f, idx::NTuple{2}, A, args...)
k, j = idx
T = eltype(f)
ξ = j
for A in f.args
n = size(A,2)
ξ ≤ n && return T(A[k,ξ])::T
ξ -= n
end
throw(BoundsError(f, (k,j)))
n = size(A, 2)
j ≤ n && return convert(T, A[k, j])::T
hcat_getindex_recursive(f, (k, j - n), args...)
end

@inline hcat_getindex_recursive(f, idx) = throw(BoundsError(f, idx))

getindex(f::Hcat, k::Integer, j::Integer) = hcat_getindex(f, k, j)
getindex(f::Applied{DefaultArrayApplyStyle,typeof(hcat)}, k::Integer, j::Integer)= hcat_getindex(f, k, j)
getindex(f::Applied{<:Any,typeof(hcat)}, k::Integer, j::Integer)= hcat_getindex(f, k, j)

# since its mutable we need to make a copy
copy(f::Hcat) = Hcat(map(copy, f.args)...)

@inline function hcat_setindex_recursive!(
f, v, idx::NTuple{2}, A, args...)
k, j = idx
T = eltype(f)
n = size(A, 2)
j ≤ n && return setindex!(A, v, k, j)
hcat_setindex_recursive!(f, v, (k, j - n), args...)
end

@inline hcat_setindex_recursive!(f, v, idx) = throw(BoundsError(f, idx))

function setindex!(f::Hcat{T}, v, k::Integer, j::Integer) where T
ξ = j
for A in f.args
n = size(A,2)
ξ ≤ n && return setindex!(A, v, k, ξ)
ξ -= n
end
throw(BoundsError(f, (k,j)))
hcat_setindex_recursive!(f, v, (k, j), f.args...)
end


Expand Down Expand Up @@ -957,4 +969,4 @@ function sub_paddeddata(::TriangularLayout{'U','N'}, S::SubArray{<:Any,1,<:Abstr
P = parent(S)
(kr,j) = parentindices(S)
view(triangulardata(P), kr ∩ (1:j), j)
end
end