Skip to content

Commit

Permalink
Merge pull request #4039 from stevengj/pairwise
Browse files Browse the repository at this point in the history
RFC: use pairwise summation for sum, cumsum, and cumprod
  • Loading branch information
JeffBezanson committed Aug 13, 2013
2 parents 33fdcbb + 01ff268 commit c8f89d1
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 22 deletions.
59 changes: 45 additions & 14 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -931,17 +931,28 @@ function (!=)(A::AbstractArray, B::AbstractArray)
return false
end

for (f, op) = ((:cumsum, :+), (:cumprod, :*) )
for (f, fp, op) = ((:cumsum, :cumsum_pairwise, :+),
(:cumprod, :cumprod_pairwise, :*) )
# In-place cumulative sum/product of v[i1:i1+n-1], written to c[i1:i1+n-1],
# using pairwise accumulation as for sum.
#   s  : reduction of every element before index i1 (the identity element for
#        the outermost call); it is combined into each c[i] but kept out of
#        the running block total
#   i1 : first index of the block
#   n  : number of elements in the block (v[i1] is read unconditionally, so
#        n is expected to be at least 1)
# Returns the reduction of v[i1:i1+n-1] alone, without s.  Keeping the block
# total separate from the prefix is what lets the two halves be combined
# pairwise; seeding the right half with the last entry written by the left
# half would reduce to the sequential recurrence c[i] = c[i-1] op v[i].
@eval function ($fp)(v::AbstractVector, c::AbstractVector, s, i1, n)
    if n < 128
        # base case: plain left-to-right loop over a short block
        @inbounds blk = v[i1]
        @inbounds c[i1] = ($op)(s, blk)
        for i = i1+1:i1+n-1
            @inbounds blk = ($op)(blk, v[i])
            @inbounds c[i] = ($op)(s, blk)
        end
        return blk
    else
        n2 = div(n,2)
        left = ($fp)(v, c, s, i1, n2)
        # the right half is offset by the prefix combined with the left total
        right = ($fp)(v, c, ($op)(s, left), i1+n2, n-n2)
        return ($op)(left, right)
    end
end

# cumsum/cumprod of a vector: returns a new vector holding the running
# sum/product of v.  An empty v yields an empty result.
@eval function ($f)(v::AbstractVector)
    n = length(v)
    # cumsum allocates with the type that unary + gives for the element type;
    # cumprod keeps the element type of v
    c = $(op===:+ ? (:(similar(v,typeof(+zero(eltype(v)))))) :
                    (:(similar(v))))
    if n == 0; return c; end
    # seed with the identity element; the helper writes every entry c[1:n],
    # so no separate sequential fill is needed beforehand
    ($fp)(v, c, $(op==:+ ? :(zero(eltype(v))) : :(one(eltype(v)))), 1, n)
    return c
end

Expand Down Expand Up @@ -1367,17 +1378,37 @@ prod(A::AbstractArray{Bool}) =
prod(A::AbstractArray{Bool}, region) =
error("use all() instead of prod() for boolean arrays")

function sum{T}(A::AbstractArray{T})
if isempty(A)
return zero(T)
end
v = A[1]
for i=2:length(A)
@inbounds v += A[i]
# Pairwise (cascade) summation of A[i1:i1+n-1], which has O(log n) error growth
# [vs. O(n) for a simple loop] with negligible performance cost if
# the base case is large enough. See, e.g.:
# http://en.wikipedia.org/wiki/Pairwise_summation
# Higham, Nicholas J. (1993), "The accuracy of floating point
# summation", SIAM Journal on Scientific Computing 14 (4): 783–799.
# In fact, the root-mean-square error growth, assuming random roundoff
# errors, is only O(sqrt(log n)), which is nearly indistinguishable from O(1)
# in practice. See:
# Manfred Tasche and Hansmartin Zeuner, Handbook of
# Analytic-Computational Methods in Applied Mathematics (2000).
#   A  : array to sum
#   i1 : index of the first element of the block
#   n  : number of elements in the block (A[i1] is read unconditionally, so
#        n is expected to be at least 1)
function sum_pairwise(A::AbstractArray, i1,n)
    if n < 128
        # base case: short blocks are summed with a plain loop
        @inbounds s = A[i1]
        for i = i1+1:i1+n-1
            @inbounds s += A[i]
        end
        return s
    else
        # split the block in half, sum each half recursively, add the totals
        n2 = div(n,2)
        return sum_pairwise(A, i1, n2) + sum_pairwise(A, i1+n2, n-n2)
    end
end

# Sum of all elements of A.  An empty array gives zero(T); anything else is
# handed to sum_pairwise over the full index range 1:length(A).
function sum{T}(A::AbstractArray{T})
    len = length(A)
    if len == 0
        return zero(T)
    end
    return sum_pairwise(A, 1, len)
end

# Kahan (compensated) summation: O(1) error growth, at the cost
# of a considerable increase in computational expense.
function sum_kbn{T<:FloatingPoint}(A::AbstractArray{T})
n = length(A)
if (n == 0)
Expand Down
25 changes: 23 additions & 2 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,27 @@ function mapreduce(f::Callable, op::Function, v0, itr)
return v
end

# Pairwise recursive reduction of f(A[i]) for i = i1:i1+n-1 with the
# associative operation op; the block is split the same way as in
# sum_pairwise, for improved accuracy.  A[i1] is read unconditionally, so
# n is expected to be at least 1.
function mr_pairwise(f::Callable, op::Function, A::AbstractArray, i1,n)
    if n >= 128
        # large block: reduce each half recursively, then combine the results
        half = div(n,2)
        lhs = mr_pairwise(f,op,A, i1,half)
        rhs = mr_pairwise(f,op,A, i1+half,n-half)
        return op(lhs, rhs)
    end
    # short block: fold left to right
    @inbounds acc = f(A[i1])
    stop = i1+n-1
    for k = i1+1:stop
        @inbounds acc = op(acc,f(A[k]))
    end
    return acc
end
# Reduce f(A[i]) over the whole array A with the associative operation op,
# grouping the work pairwise.  An empty array yields op().
function mapreduce_associative(f::Callable, op::Function, A::AbstractArray)
    len = length(A)
    if len == 0
        return op()
    end
    return mr_pairwise(f,op,A, 1,len)
end
# A general iterable offers no random access, so pairwise splitting is not
# practical here; defer to the sequential mapreduce instead.
function mapreduce_associative(f::Callable, op::Function, itr)
    return mapreduce(f, op, itr)
end

function any(itr)
for x in itr
if x
Expand All @@ -171,8 +192,8 @@ end

# Reductions of f(x) over the elements x of itr.
max(f::Function, itr) = mapreduce(f, max, itr)
min(f::Function, itr) = mapreduce(f, min, itr)
# sum and prod go through mapreduce_associative, which reduces arrays
# pairwise and falls back to mapreduce for other iterables.  Each is defined
# exactly once: a second definition with the same signature would simply
# replace the first.
sum(f::Function, itr)  = mapreduce_associative(f, + , itr)
prod(f::Function, itr) = mapreduce_associative(f, * , itr)

function count(pred::Function, itr)
s = 0
Expand Down
23 changes: 17 additions & 6 deletions base/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ function mean(iterable)
end
return total/count
end
# Arithmetic mean of all elements of an array: the total divided by the
# number of elements.
function mean(v::AbstractArray)
    return sum(v) / length(v)
end

# Mean along the dimensions listed in region: the sum over those dimensions
# divided by the product of their sizes.
function mean(v::AbstractArray, region)
    return sum(v, region) / prod(size(v)[region])
end

function median!{T<:Real}(v::AbstractVector{T}; checknan::Bool=true)
Expand All @@ -28,16 +29,26 @@ end
# Median of an array of reals.  The in-place median! is given a flattened
# copy of v rather than v itself; checknan is forwarded unchanged.
function median{T<:Real}(v::AbstractArray{T}; checknan::Bool=true)
    scratch = vec(copy(v))
    return median!(scratch, checknan=checknan)
end

## variance with known mean
function varm(v::AbstractVector, m::Number)
## variance with known mean, using pairwise summation
# Sum of abs2(A[i] - m) for i = i1:i1+n-1, with the block split in half
# recursively once it holds 128 or more elements.  A[i1] is read
# unconditionally, so n is expected to be at least 1.
function varm_pairwise(A::AbstractArray, m, i1,n) # see sum_pairwise
    if n >= 128
        half = div(n,2)
        lo = varm_pairwise(A, m, i1, half)
        hi = varm_pairwise(A, m, i1+half, n-half)
        return lo + hi
    end
    # short block: accumulate the squared deviations left to right
    @inbounds acc = abs2(A[i1] - m)
    for k = i1+1:i1+n-1
        @inbounds acc += abs2(A[k] - m)
    end
    return acc
end
# Variance of the array v about the known mean m: the sum of abs2(v[i] - m)
# over all elements, divided by n - 1.  Arrays with fewer than two elements
# yield NaN.
function varm(v::AbstractArray, m::Number)
    n = length(v)
    if n == 0 || n == 1
        return NaN
    end
    # accumulate the squared deviations pairwise, directly on v
    return varm_pairwise(v, m, 1,n) / (n - 1)
end
# for a range the supplied mean is not used; var(v) is returned directly
varm(v::Ranges, m::Number) = var(v)

## variance
Expand All @@ -52,7 +63,7 @@ end
# Variance of all elements of v, taken about their mean.
var(v::AbstractArray) = varm(v, mean(v))
# Variance along the dimensions listed in region: deviations from the mean
# over those dimensions, squared in magnitude, summed over region and divided
# by (product of the sizes of those dimensions) - 1.
function var(v::AbstractArray, region)
    x = v .- mean(v, region)
    # abs2 rather than x.*x, so that complex deviations contribute |x|^2
    return sum(abs2(x), region) / (prod(size(v)[region]) - 1)
end

## standard deviation with known mean
Expand Down
9 changes: 9 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,15 @@ v[2,2,1,1] = 40.0

@test isequal(v,sum(z,(3,4)))

# Accuracy of sum and cumsum on 10^6 random values, measured against the
# compensated sum_kbn result.
z = rand(10^6)
let es = sum_kbn(z), es2 = sum_kbn(z[1:10^5])
    # two-sided check: a result that overshoots es must fail as well
    @test abs(es - sum(z)) < es * 1e-13
    cs = cumsum(z)
    # NOTE(review): the two cumsum comparisons below are still one-sided (no
    # abs), so they only bound the error in one direction; confirm the
    # accuracy cumsum actually achieves before making them two-sided
    @test (es - cs[end]) < es * 1e-13
    @test (es2 - cs[10^5]) < es2 * 1e-13
end
# applying sin before the reduction or inside it must give the same total
@test sum(sin(z)) == sum(sin, z)

## large matrices transpose ##

for i = 1 : 3
Expand Down
4 changes: 4 additions & 0 deletions test/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
# hist: second element of the result is compared against the expected counts
@test all(hist([1:100]/100,0.0:0.01:1.0)[2] .==1)
@test hist([1,1,1,1,1])[2][1] == 5

# varm on complex data: 10^4 points exp(i*im)
A = Complex128[exp(i*im) for i in 1:10^4]
# about 0, varm should match the sum of abs2 over the elements divided by n-1
@test_approx_eq varm(A,0.) sum(map(abs2,A))/(length(A)-1)
# about the mean, varm should agree with var along dimension 1
@test_approx_eq varm(A,mean(A)) var(A,1)

# midpoints of float ranges, integer ranges and float vectors
@test midpoints(1.0:1.0:10.0) == 1.5:1.0:9.5
@test midpoints(1:10) == 1.5:9.5
@test midpoints(Float64[1.0:1.0:10.0]) == Float64[1.5:1.0:9.5]
Expand Down

0 comments on commit c8f89d1

Please sign in to comment.