Skip to content

Commit

Permalink
Use inference to determine element types
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Feb 15, 2020
1 parent 3db20ef commit 6bf5e05
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
12 changes: 12 additions & 0 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,18 @@ _valof(::Val{D}) where D = D
_accumulate(op, a, _maybe_val(dims), init)

@inline function _accumulate(op::F, a::StaticArray, dims::Union{Val,Colon}, init) where {F}
if isempty(a)
if init isa _InitialValue
# Deliberately not using `return_type` here, since this `eltype` is
# exact for the singleton element case (i.e., `op` will not be called).
return similar_type(a)()
else
# Using the type that _would_ be used if `size(a, dims) == 1`:
T = return_type(op, Tuple{typeof(init), eltype(a)})
return similar_type(a, T)()
end
end

# Adjoin the initial value to `op`:
rf(x, y) = x isa _InitialValue ? y : op(x, y)

Expand Down
38 changes: 27 additions & 11 deletions test/accumulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,34 @@ using StaticArrays, Test
# label, T
("SVector", SVector),
("MVector", MVector),
("SizedVector", SizedVector{3}),
("SizedVector", SizedVector),
]
a = T(SA[1, 2, 3])
@test cumsum(a) == cumsum(collect(a))
@test cumsum(a) isa similar_type(a)
@inferred cumsum(a)
@testset "$label" for (label, a) in [
("[1, 2, 3]", T{3}(SA[1, 2, 3])),
("[]", T{0,Int}(())),
]
@test cumsum(a) == cumsum(collect(a))
@test cumsum(a) isa similar_type(a)
@inferred cumsum(a)
end
end

@testset "cumsum(::$label; dims=2)" for (label, T) in [
# label, T
("SMatrix", SMatrix),
("MMatrix", MMatrix),
("SizedMatrix", SizedMatrix{3,2}),
("SizedMatrix", SizedMatrix),
]
a = T(SA[1 2; 3 4; 5 6])
@test cumsum(a; dims = 2) == cumsum(collect(a); dims = 2)
@test cumsum(a; dims = 2) isa similar_type(a)
v"1.1" <= VERSION < v"1.2" && continue
@inferred cumsum(a; dims = Val(2))
@testset "$label" for (label, a) in [
("[1 2; 3 4; 5 6]", T{3,2}(SA[1 2; 3 4; 5 6])),
("0 x 2 matrix", T{0,2,Float64}()),
("2 x 0 matrix", T{2,0,Float64}()),
]
@test cumsum(a; dims = 2) == cumsum(collect(a); dims = 2)
@test cumsum(a; dims = 2) isa similar_type(a)
v"1.1" <= VERSION < v"1.2" && continue
@inferred cumsum(a; dims = Val(2))
end
end

@testset "cumsum(a::SArray; dims=$i); ndims(a) = $d" for d in 1:4, i in 1:d
Expand All @@ -40,4 +49,11 @@ using StaticArrays, Test
@test cumprod(a)::SArray == cumprod(collect(a))
@inferred cumprod(a)
end

@testset "empty vector with init" begin
a = SA{Int}[]
right(_, x) = x
@test accumulate(right, a; init = Val(1)) === SA{Int}[]
@inferred accumulate(right, a; init = Val(1))
end
end

0 comments on commit 6bf5e05

Please sign in to comment.