Skip to content

Commit

Permalink
fix: handle edge case for VoA mapreduce
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 5, 2024
1 parent 84121f2 commit 2a55225
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,9 @@ Base.map(f, A::RecursiveArrayTools.AbstractVectorOfArray) = map(f, A.u)
function Base.mapreduce(f, op, A::AbstractVectorOfArray; kwargs...)
mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...)
end
function Base.mapreduce(f, op, A::AbstractVectorOfArray{T,1,<:AbstractVector{T}}; kwargs...) where {T}
mapreduce(f, op, A.u; kwargs...)
end

## broadcasting

Expand Down
7 changes: 7 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ for i in 1:ndims(arrvb)
@test mapreduce(string, *, arrvb; dims=i) == mapreduce(string, *, testvb; dims=i)
end

# Test when ndims == 1
testvb = VectorOfArray(collect(1.0:0.1:2.0))
arrvb = Array(testvb)
@test sum(arrvb) == sum(testvb)
@test prod(arrvb) == prod(testvb)
@test mapreduce(string, *, arrvb) == mapreduce(string, *, testvb)

# view
testvc = VectorOfArray([rand(1:10, 3, 3) for _ in 1:3])
arrvc = Array(testvc)
Expand Down

0 comments on commit 2a55225

Please sign in to comment.