diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index e1b92ea5..14476500 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -664,6 +664,9 @@ Base.map(f, A::RecursiveArrayTools.AbstractVectorOfArray) = map(f, A.u) function Base.mapreduce(f, op, A::AbstractVectorOfArray; kwargs...) mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...) end +function Base.mapreduce(f, op, A::AbstractVectorOfArray{T,1,<:AbstractVector{T}}; kwargs...) where {T} + mapreduce(f, op, A.u; kwargs...) +end ## broadcasting diff --git a/test/interface_tests.jl b/test/interface_tests.jl index b056e43d..3aba249a 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -74,6 +74,13 @@ for i in 1:ndims(arrvb) @test mapreduce(string, *, arrvb; dims=i) == mapreduce(string, *, testvb; dims=i) end +# Test when ndims == 1 +testvb = VectorOfArray(collect(1.0:0.1:2.0)) +arrvb = Array(testvb) +@test sum(arrvb) == sum(testvb) +@test prod(arrvb) == prod(testvb) +@test mapreduce(string, *, arrvb) == mapreduce(string, *, testvb) + # view testvc = VectorOfArray([rand(1:10, 3, 3) for _ in 1:3]) arrvc = Array(testvc)