diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 7732b08f..733d06e1 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -299,6 +299,18 @@ _valof(::Val{D}) where D = D _accumulate(op, a, _maybe_val(dims), init) @inline function _accumulate(op::F, a::StaticArray, dims::Union{Val,Colon}, init) where {F} + if isempty(a) + if init isa _InitialValue + # Deliberately not using `return_type` here, since this `eltype` is + # exact for the singleton element case (i.e., `op` will not be called). + return similar_type(a)() + else + # Using the type that _would_ be used if `size(a, dims) == 1`: + T = return_type(op, Tuple{typeof(init), eltype(a)}) + return similar_type(a, T)() + end + end + # Adjoin the initial value to `op`: rf(x, y) = x isa _InitialValue ? y : op(x, y) diff --git a/test/accumulate.jl b/test/accumulate.jl index cdeca68a..ca0fb156 100644 --- a/test/accumulate.jl +++ b/test/accumulate.jl @@ -5,25 +5,34 @@ using StaticArrays, Test # label, T ("SVector", SVector), ("MVector", MVector), - ("SizedVector", SizedVector{3}), + ("SizedVector", SizedVector), ] - a = T(SA[1, 2, 3]) - @test cumsum(a) == cumsum(collect(a)) - @test cumsum(a) isa similar_type(a) - @inferred cumsum(a) + @testset "$label" for (label, a) in [ + ("[1, 2, 3]", T{3}(SA[1, 2, 3])), + ("[]", T{0,Int}(())), + ] + @test cumsum(a) == cumsum(collect(a)) + @test cumsum(a) isa similar_type(a) + @inferred cumsum(a) + end end @testset "cumsum(::$label; dims=2)" for (label, T) in [ # label, T ("SMatrix", SMatrix), ("MMatrix", MMatrix), - ("SizedMatrix", SizedMatrix{3,2}), + ("SizedMatrix", SizedMatrix), ] - a = T(SA[1 2; 3 4; 5 6]) - @test cumsum(a; dims = 2) == cumsum(collect(a); dims = 2) - @test cumsum(a; dims = 2) isa similar_type(a) - v"1.1" <= VERSION < v"1.2" && continue - @inferred cumsum(a; dims = Val(2)) + @testset "$label" for (label, a) in [ + ("[1 2; 3 4; 5 6]", T{3,2}(SA[1 2; 3 4; 5 6])), + ("0 x 2 matrix", T{0,2,Float64}()), + ("2 x 0 matrix", T{2,0,Float64}()), + ] + @test cumsum(a; dims = 2) == cumsum(collect(a); dims = 2) + @test cumsum(a; dims = 2) isa similar_type(a) + v"1.1" <= VERSION < v"1.2" && continue + @inferred cumsum(a; dims = Val(2)) + end end @testset "cumsum(a::SArray; dims=$i); ndims(a) = $d" for d in 1:4, i in 1:d @@ -40,4 +49,11 @@ using StaticArrays, Test @test cumprod(a)::SArray == cumprod(collect(a)) @inferred cumprod(a) end + + @testset "empty vector with init" begin + a = SA{Int}[] + right(_, x) = x + @test accumulate(right, a; init = Val(1)) === SA{Int}[] + @inferred accumulate(right, a; init = Val(1)) + end end