Skip to content

Commit

Permalink
cumsum fixes (fixes JuliaLang#18363 and JuliaLang#18336)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevengj committed Sep 5, 2016
1 parent 79a7172 commit a0f30e6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
15 changes: 8 additions & 7 deletions base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,6 @@ ctranspose{T<:Real}(A::AbstractVecOrMat{T}) = transpose(A)
transpose(x::AbstractVector) = [ transpose(v) for i=of_indices(x, OneTo(1)), v in x ]
ctranspose{T}(x::AbstractVector{T}) = T[ ctranspose(v) for i=of_indices(x, OneTo(1)), v in x ]

_cumsum_type{T<:Number}(v::AbstractArray{T}) = typeof(+zero(T))
_cumsum_type(v) = typeof(v[1]+v[1])

for (f, f!, fp, op) = ((:cumsum, :cumsum!, :cumsum_pairwise!, :+),
(:cumprod, :cumprod!, :cumprod_pairwise!, :*) )
# in-place cumsum of c = s+v[range(i1,n)], using pairwise summation
Expand All @@ -472,12 +469,16 @@ for (f, f!, fp, op) = ((:cumsum, :cumsum!, :cumsum_pairwise!, :+),
@eval function ($f!)(result::AbstractVector, v::AbstractVector)
n = length(v)
if n == 0; return result; end
($fp)(v, result, $(op==:+ ? :(zero(first(v))) : :(one(first(v)))), first(indices(v,1)), n)
li = linearindices(v)
li != linearindices(result) && throw(BoundsError())
i1 = first(li)
@inbounds result[i1] = v1 = v[i1]
n == 1 && return result
($fp)(v, result, v1, i1+1, n-1)
return result
end

@eval function ($f)(v::AbstractVector)
c = $(op===:+ ? (:(similar(v,_cumsum_type(v)))) : (:(similar(v))))
return ($f!)(c, v)
@eval function ($f){T}(v::AbstractVector{T})
return ($f!)(similar(v), v)
end
end
2 changes: 1 addition & 1 deletion base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ julia> cumsum(a,2)
4 9 15
```
"""
cumsum(A::AbstractArray, axis::Integer=1) = cumsum!(similar(A, Base._cumsum_type(A)), A, axis)
cumsum(A::AbstractArray, axis::Integer=1) = cumsum!(similar(A), A, axis)
cumsum!(B, A::AbstractArray) = cumsum!(B, A, 1)
"""
cumprod(A, dim=1)
Expand Down
9 changes: 9 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1689,6 +1689,15 @@ end
@test cumsum([1 2; 3 4], 2) == [1 3; 3 7]
@test cumsum([1 2; 3 4], 3) == [1 2; 3 4]

# issue #18363
@test_throws BoundsError cumsum!([0,0], 1:4)
@test cumsum(Any[]) == Any[] && isa(cumsum(Any[]), Vector{Any})
@test cumsum(Any[1, 2.3]) == [1, 3.3]

#issue #18336
@test cumsum([-0.0, -0.0])[1] === cumsum([-0.0, -0.0])[2] === -0.0
@test cumprod(-0.0im + (0:0))[1] === Complex(0.0, -0.0)

module TestNLoops15895

using Base.Cartesian
Expand Down

0 comments on commit a0f30e6

Please sign in to comment.