Skip to content

Commit

Permalink
fix grid chunks for iteration perf (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaqz authored Sep 21, 2023
1 parent 64e1901 commit 1fa1e1c
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/chunks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ end
# The grid offset of an `IrregularChunks` is always reported as zero.
function grid_offset(r::IrregularChunks)
    return 0
end
# Largest chunk length: the maximum difference between consecutive entries of `r.offsets`.
function max_chunksize(r::IrregularChunks)
    chunk_lengths = diff(r.offsets)
    return maximum(chunk_lengths)
end

struct GridChunks{N} <: AbstractArray{NTuple{N,UnitRange{Int64}},N}
chunks::Tuple{Vararg{ChunkType,N}}
struct GridChunks{N,C<:Tuple{Vararg{<:ChunkType,N}}} <: AbstractArray{NTuple{N,UnitRange{Int64}},N}
chunks::C
end
# Varargs convenience constructor: the per-dimension chunk descriptions are
# collected into a tuple and forwarded to the tuple-accepting constructor.
function GridChunks(ct::ChunkType...)
    return GridChunks(ct)
end
# Construct from an object `a` by forwarding `size(a)` together with `chunksize`
# and `offset`. When `offset` is not given it defaults to one zero per entry of
# `size(a)`.
function GridChunks(a, chunksize; offset=(_ -> 0).(size(a)))
    return GridChunks(size(a), chunksize; offset=offset)
end
Expand Down
3 changes: 2 additions & 1 deletion src/diskarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ is_batch_arg(::AbstractRange) = false
function getindex_disk(a, i...)
checkscalar(i)
if any(is_batch_arg, i)
batchgetindex(a, i...) else
batchgetindex(a, i...)
else
inds, trans = interpret_indices_disk(a, i)
data = Array{eltype(a)}(undef, map(length, inds)...)
readblock!(a, data, inds...)
Expand Down
60 changes: 34 additions & 26 deletions src/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end
# Number of indices produced: the product, over all entries of `b.c.chunks`,
# of `last(last(entry))`.
function Base.length(b::BlockedIndices)
    extents = map(last ∘ last, b.c.chunks)
    return prod(extents)
end
# Iterator trait: `BlockedIndices` declares that it has a known element type.
function Base.IteratorEltype(::Type{<:BlockedIndices})
    return Base.HasEltype()
end
# Iterator trait: a `BlockedIndices` wrapping `GridChunks{N}` has an `N`-dimensional shape.
function Base.IteratorSize(::Type{<:BlockedIndices{<:GridChunks{N}}}) where {N}
    return Base.HasShape{N}()
end
Base.size(b::BlockedIndices) = last.(last.(b.c.chunks))
# Shape of the iterator: for every entry of `b.c.chunks`, take `last` of the
# entry and then `last` of that result. `last ∘ last` is the composition of the
# two calls (the `∘` operator is required; `last last` does not parse).
# The return-type annotation makes the result an all-`Int` tuple.
Base.size(b::BlockedIndices)::NTuple{<:Any,Int} = map(last ∘ last, b.c.chunks)
# Elements are `CartesianIndex`es with as many dimensions as the wrapped chunk grid `b.c`.
function Base.eltype(b::BlockedIndices)
    return CartesianIndex{ndims(b.c)}
end
function Base.iterate(a::BlockedIndices)
chunkiter = Iterators.Stateful(a.c)
Expand All @@ -29,40 +29,48 @@ function Base.iterate(::BlockedIndices, i)
innerinds = Iterators.Stateful(CartesianIndices(first(ii)))
r = iterate(innerinds)
r === nothing && return nothing
return first(r), (chunkiter, innerinds)
else
return first(r), (chunkiter, innerinds)
end
return first(r), (chunkiter, innerinds)
end

# Implementation macros
# Advance element-wise iteration over `a` by one step.
#
# `i` is the state tuple `(buffer, bi, bstate)` returned by the previous step:
# the currently loaded data buffer, the blocked-indices iterator, and that
# iterator's own state. Returns `nothing` when `bi` is exhausted, otherwise
# `(value, newstate)` with `value` asserted to be a `T` and `newstate` asserted
# to have the same type `I` as the incoming state.
@noinline function _iterate_disk(a::AbstractArray{T}, i::I) where {T,I<:Tuple{A,B,C}} where {A,B,C}
    # Typed destructuring keeps every piece of the state at its incoming type.
    buffer::A, bi::B, prevstate::C = i

    # The state holds stateful iterators that `iterate(bi, …)` advances in
    # place, so the remaining length must be recorded before stepping.
    (chunks_before, _) = prevstate
    remaining_before = length(chunks_before)

    step = iterate(bi, prevstate)
    step === nothing && return nothing

    idx, nextstate = step
    (chunks_after, inner) = nextstate
    if length(chunks_after) !== remaining_before
        # The remaining length changed during the step, so read the block
        # described by the new inner index range from `a` and wrap it so it
        # can be indexed with `a`'s own indices.
        block = inner.itr.indices
        buffer = OffsetArray(a[block...], inner.itr)
    end
    return buffer[idx]::T, (buffer, bi, nextstate)::I
end
# Start element-wise iteration over `a`.
#
# Builds a `BlockedIndices` over `eachchunk(a)`, reads the block that contains
# the first index, and returns `(value, state)` where `state` is
# `(buffer, blocked_indices, blocked_indices_state)`. Returns `nothing` when
# there is nothing to iterate.
@noinline function _iterate_disk(a)
    blocked = BlockedIndices(eachchunk(a))
    start = iterate(blocked)
    start === nothing && return nothing

    idx, (chunkiter, inner) = start
    # `inner.itr` is the index range of the current block; its `indices` are
    # the per-dimension ranges used to read that block from `a`.
    block = inner.itr.indices
    buffer = OffsetArray(a[block...], inner.itr)
    return buffer[idx], (buffer, blocked, (chunkiter, inner))
end

# Generate `Base.eachindex` and `Base.iterate` methods for the type given by `t`.
# `t` is escaped so that it is resolved in the scope of the macro caller.
macro implement_iteration(t)
t = esc(t)
quote
# `eachindex` yields the blocked index iterator built from the array's chunks.
Base.eachindex(a::$t) = BlockedIndices(eachchunk(a))
# Initial step: read the block containing the first index and return the first
# value together with the state `(datacur, bi, (chunkiter, innerinds))`.
function Base.iterate(a::$t)
bi = BlockedIndices(eachchunk(a))
it = iterate(bi)
isnothing(it) && return nothing
innernow, (chunkiter, innerinds) = it
curchunk = innerinds.itr.indices
datacur = OffsetArray(a[curchunk...], innerinds.itr)
return datacur[innernow], (datacur, bi, (chunkiter, innerinds))
end
# Subsequent steps: advance the blocked index iterator and reload `datacur`
# only when `length(chunkiter)` changed across the step.
function Base.iterate(a::$t, i)
datacur, bi, bstate = i
(chunkiter, innerinds) = bstate
# Recorded before stepping so it can be compared with the length afterwards.
cistateold = length(chunkiter)
biter = iterate(bi, bstate)
if biter === nothing
return nothing
end
innernow, bstatenew = biter
(chunkiter, innerinds) = bstatenew
if length(chunkiter) !== cistateold
curchunk = innerinds.itr.indices
datacur = OffsetArray(a[curchunk...], innerinds.itr)
end
return datacur[innernow], (datacur, bi, bstatenew)
end
# NOTE(review): the two one-line methods below have the same signatures as the
# two `function Base.iterate` definitions above and forward to `_iterate_disk`.
# This looks like the before/after halves of a diff shown together — confirm
# which pair is meant to remain.
Base.iterate(a::$t) = _iterate_disk(a)
Base.iterate(a::$t, i) = _iterate_disk(a, i)
end
end

0 comments on commit 1fa1e1c

Please sign in to comment.