From e97ab9b42760d33a6c88796c6541c1c8dcf684e2 Mon Sep 17 00:00:00 2001 From: Anshul Singhvi Date: Tue, 22 Oct 2024 17:17:52 -0700 Subject: [PATCH 1/4] Implement `collect_similar` like `collect` for DiskGenerators --- src/generator.jl | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/generator.jl b/src/generator.jl index be47e62..8619deb 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -47,6 +47,31 @@ function Base.collect(itr::DiskGenerator{<:AbstractArray{<:Any,N}}) where {N} return dest end +# Warning: this is not public API! +function Base.collect_similar(A::AbstractArray, itr::DiskGenerator{<:AbstractArray{<:Any,N}}) where {N} + y = iterate(itr) + shp = axes(itr.iter) + if y === nothing + et = Base.@default_eltype(itr) + return similar(A, et, shp) + end + v1, st = y + dest = similar(A, typeof(v1), shp) + i = y + for I in eachindex(itr.iter) + if i isa Nothing # Mainly to keep JET clean + error( + "Should not be reached: iterator is shorter than its `eachindex` iterator" + ) + else + dest[I] = first(i) + i = iterate(itr, last(i)) + end + end + return dest + +end + macro implement_generator(t) t = esc(t) quote From 9a33e9afc33904484ffe846560db578ac867e2ac Mon Sep 17 00:00:00 2001 From: Anshul Singhvi Date: Tue, 22 Oct 2024 17:26:01 -0700 Subject: [PATCH 2/4] Add a test --- test/runtests.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index db8dd1c..5c9a487 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -953,3 +953,17 @@ end @test getindex_count(A) == 0 end +@testset "Map over indices correctly" begin + # This is a regression test for issue #144 + # `map` should always work over the correct indices, + # especially since we overload generators to `DiskGenerator`. 
+ + data = [i+j for i in 1:200, j in 1:100] + da = AccessCountDiskArray(data, chunksize=(10,10)) + @test map(identity, da) == data + @test all(map(identity, da) .== data) + + # Make sure that type inference works + @inferred Matrix{Int} map(identity, da) + @inferred Matrix{Float64} map(x -> x * 5.0, da) +end From 2df5c295fc8156f2d3b327713c8985c5a0046a05 Mon Sep 17 00:00:00 2001 From: Anshul Singhvi Date: Tue, 22 Oct 2024 20:49:31 -0700 Subject: [PATCH 3/4] Add a flexible approach that should work with irregular chunks (.1ms) --- src/generator.jl | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/src/generator.jl b/src/generator.jl index 8619deb..1b55597 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -49,23 +49,39 @@ end # Warning: this is not public API! function Base.collect_similar(A::AbstractArray, itr::DiskGenerator{<:AbstractArray{<:Any,N}}) where {N} + input = itr.iter # this is known to be an array y = iterate(itr) - shp = axes(itr.iter) + shp = axes(input) if y === nothing et = Base.@default_eltype(itr) return similar(A, et, shp) end v1, st = y - dest = similar(A, typeof(v1), shp) + dest = similar(A, typeof(v1), shp)# TODO: should this be `Base.return_type(itr.f, Tuple{eltype(input)})`? i = y - for I in eachindex(itr.iter) - if i isa Nothing # Mainly to keep JET clean - error( - "Should not be reached: iterator is shorter than its `eachindex` iterator" - ) - else - dest[I] = first(i) - i = iterate(itr, last(i)) + # If the array is chunked, read each chunk and apply the function + # via broadcasting. + if DiskArrays.haschunks(input) isa DiskArrays.Chunked + # TODO: change this if DiskArrays ever supports uneven chunks + chunks = eachchunk(input) + value_holder = Matrix{eltype(v1)}(undef, DiskArrays.max_chunksize(chunks)...) + output_holder = Matrix{typeof(v1)}(undef, DiskArrays.max_chunksize(chunks)...) 
+ for chunk_inds in chunks + this_chunk_size = map(x -> 1:length(x), chunk_inds) + DiskArrays.readblock!(input, value_holder, chunk_inds...) + output_holder[this_chunk_size...] .= itr.f.(view(value_holder, this_chunk_size...)) + dest[chunk_inds...] .= view(output_holder, this_chunk_size...) + end + else # iterate as normal array + for I in eachindex(itr.iter) + if i isa Nothing # Mainly to keep JET clean + error( + "Should not be reached: iterator is shorter than its `eachindex` iterator" + ) + else + dest[I] = first(i) + i = iterate(itr, last(i)) + end end end return dest From f7d69aa7649179b8af14498cdb5e08f140039c88 Mon Sep 17 00:00:00 2001 From: Anshul Singhvi Date: Tue, 22 Oct 2024 20:50:30 -0700 Subject: [PATCH 4/4] cut runtime and alloc amount (but NOT allocs) (.09ms) --- src/generator.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/generator.jl b/src/generator.jl index 1b55597..a9ea1b4 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -64,13 +64,8 @@ function Base.collect_similar(A::AbstractArray, itr::DiskGenerator{<:AbstractArr if DiskArrays.haschunks(input) isa DiskArrays.Chunked # TODO: change this if DiskArrays ever supports uneven chunks chunks = eachchunk(input) - value_holder = Matrix{eltype(v1)}(undef, DiskArrays.max_chunksize(chunks)...) - output_holder = Matrix{typeof(v1)}(undef, DiskArrays.max_chunksize(chunks)...) for chunk_inds in chunks - this_chunk_size = map(x -> 1:length(x), chunk_inds) - DiskArrays.readblock!(input, value_holder, chunk_inds...) - output_holder[this_chunk_size...] .= itr.f.(view(value_holder, this_chunk_size...)) - dest[chunk_inds...] .= view(output_holder, this_chunk_size...) + dest[chunk_inds...] .= itr.f.(input[chunk_inds...]) end else # iterate as normal array for I in eachindex(itr.iter)