Skip to content

Commit

Permalink
remove evaluate function for functor types
Browse files Browse the repository at this point in the history
  • Loading branch information
jakebolewski committed Oct 23, 2014
1 parent f104352 commit 0c631f8
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 88 deletions.
147 changes: 69 additions & 78 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,38 @@

abstract Func{N}

type IdFun <: Func{1} end
type AbsFun <: Func{1} end
type Abs2Fun <: Func{1} end
type ExpFun <: Func{1} end
type LogFun <: Func{1} end

type AddFun <: Func{2} end
type MulFun <: Func{2} end
type AndFun <: Func{2} end
type OrFun <: Func{2} end
type MaxFun <: Func{2} end
type MinFun <: Func{2} end

evaluate(::IdFun, x) = x
evaluate(::AbsFun, x) = abs(x)
evaluate(::Abs2Fun, x) = abs2(x)
evaluate(::ExpFun, x) = exp(x)
evaluate(::LogFun, x) = log(x)
evaluate(f::Callable, x) = f(x)

evaluate(::AddFun, x, y) = x + y
evaluate(::MulFun, x, y) = x * y
evaluate(::AndFun, x, y) = x & y
evaluate(::OrFun, x, y) = x | y
evaluate(::MaxFun, x, y) = scalarmax(x, y)
evaluate(::MinFun, x, y) = scalarmin(x, y)
evaluate(f::Callable, x, y) = f(x, y)
immutable IdFun <: Func{1} end
call(::IdFun, x) = x

immutable AbsFun <: Func{1} end
call(::AbsFun, x) = abs(x)

immutable Abs2Fun <: Func{1} end
call(::Abs2Fun, x) = abs2(x)

immutable ExpFun <: Func{1} end
call(::ExpFun, x) = exp(x)

immutable LogFun <: Func{1} end
call(::LogFun, x) = log(x)

immutable AddFun <: Func{2} end
call(::AddFun, x, y) = x + y

immutable MulFun <: Func{2} end
call(::MulFun, x, y) = x * y

immutable AndFun <: Func{2} end
call(::AndFun, x, y) = x & y

immutable OrFun <: Func{2} end
call(::OrFun, x, y) = x | y

immutable MaxFun <: Func{2} end
call(::MaxFun, x, y) = scalarmax(x,y)

immutable MinFun <: Func{2} end
call(::MinFun, x, y) = scalarmin(x, y)

###### Generic (map)reduce functions ######

Expand Down Expand Up @@ -68,10 +72,10 @@ function mapfoldl_impl(f, op, v0, itr, i)
return v0
else
(x, i) = next(itr, i)
v = evaluate(op, v0, evaluate(f, x))
v = op(v0, f(x))
while !done(itr, i)
(x, i) = next(itr, i)
v = evaluate(op, v, evaluate(f, x))
(x, i) = next(itr, i)
v = op(v, f(x))
end
return v
end
Expand All @@ -93,7 +97,7 @@ function mapfoldl(f, op, itr)
return Base.mr_empty(f, op, eltype(itr))
end
(x, i) = next(itr, i)
v0 = evaluate(f, x)
v0 = f(x)
mapfoldl_impl(f, op, v0, itr, i)
end

Expand All @@ -107,17 +111,17 @@ function mapfoldr_impl(f, op, v0, itr, i::Integer)
return v0
else
x = itr[i]
v = evaluate(op, evaluate(f, x), v0)
v = op(f(x), v0)
while i > 1
x = itr[i -= 1]
v = evaluate(op, evaluate(f, x), v)
v = op(f(x), v)
end
return v
end
end

mapfoldr(f, op, v0, itr) = mapfoldr_impl(f, op, v0, itr, endof(itr))
mapfoldr(f, op, itr) = (i = endof(itr); mapfoldr_impl(f, op, evaluate(f, itr[i]), itr, i-1))
mapfoldr(f, op, itr) = (i = endof(itr); mapfoldr_impl(f, op, f(itr[i]), itr, i-1))

foldr(op, v0, itr) = mapfoldr(IdFun(), op, v0, itr)
foldr(op, itr) = mapfoldr(IdFun(), op, itr)
Expand All @@ -126,12 +130,12 @@ foldr(op, itr) = mapfoldr(IdFun(), op, itr)

# mapreduce_***_impl require ifirst < ilast
function mapreduce_seq_impl(f, op, A::AbstractArray, ifirst::Int, ilast::Int)
@inbounds fx1 = r_promote(op, evaluate(f, A[ifirst]))
@inbounds fx2 = evaluate(f, A[ifirst+=1])
@inbounds v = evaluate(op, fx1, fx2)
@inbounds fx1 = r_promote(op, f(A[ifirst]))
@inbounds fx2 = f(A[ifirst+=1])
@inbounds v = op(fx1, fx2)
while ifirst < ilast
@inbounds fx = evaluate(f, A[ifirst+=1])
v = evaluate(op, v, fx)
@inbounds fx = f(A[ifirst+=1])
v = op(v, fx)
end
return v
end
Expand All @@ -143,7 +147,7 @@ function mapreduce_pairwise_impl(f, op, A::AbstractArray, ifirst::Int, ilast::In
imid = (ifirst + ilast) >>> 1
v1 = mapreduce_pairwise_impl(f, op, A, ifirst, imid, blksize)
v2 = mapreduce_pairwise_impl(f, op, A, imid+1, ilast, blksize)
return evaluate(op, v1, v2)
return op(v1, v2)
end
end

Expand All @@ -169,15 +173,15 @@ function _mapreduce{T}(f, op, A::AbstractArray{T})
if n == 0
return mr_empty(f, op, T)
elseif n == 1
return r_promote(op, evaluate(f, A[1]))
return r_promote(op, f(A[1]))
elseif n < 16
@inbounds fx1 = r_promote(op, evaluate(f, A[1]))
@inbounds fx2 = r_promote(op, evaluate(f, A[2]))
s = evaluate(op, fx1, fx2)
@inbounds fx1 = r_promote(op, f(A[1]))
@inbounds fx2 = r_promote(op, f(A[2]))
s = op(fx1, fx2)
i = 2
while i < n
@inbounds fx = evaluate(f, A[i+=1])
s = evaluate(op, s, fx)
@inbounds fx = f(A[i+=1])
s = op(s, fx)
end
return s
else
Expand All @@ -186,7 +190,7 @@ function _mapreduce{T}(f, op, A::AbstractArray{T})
end

mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, A)
mapreduce(f, op, a::Number) = evaluate(f, a)
mapreduce(f, op, a::Number) = f(a)

function mapreduce(f, op::Function, A::AbstractArray)
is(op, +) ? _mapreduce(f, AddFun(), A) :
Expand All @@ -207,9 +211,9 @@ reduce(op, a::Number) = a

function mapreduce_seq_impl(f, op::AddFun, a::AbstractArray, ifirst::Int, ilast::Int)
@inbounds begin
s = r_promote(op, evaluate(f, a[ifirst])) + evaluate(f, a[ifirst+1])
s = r_promote(op, f(a[ifirst])) + f(a[ifirst+1])
@simd for i = ifirst+2:ilast
s += evaluate(f, a[i])
s += f(a[i])
end
end
s
Expand Down Expand Up @@ -266,14 +270,14 @@ prod(A::AbstractArray{Bool}) =

function mapreduce_impl(f, op::MaxFun, A::AbstractArray, first::Int, last::Int)
# locate the first non NaN number
v = evaluate(f, A[first])
v = f(A[first])
i = first + 1
while v != v && i <= last
@inbounds v = evaluate(f, A[i])
@inbounds v = f(A[i])
i += 1
end
while i <= last
@inbounds x = evaluate(f, A[i])
@inbounds x = f(A[i])
if x > v
v = x
end
Expand All @@ -284,14 +288,14 @@ end

function mapreduce_impl(f, op::MinFun, A::AbstractArray, first::Int, last::Int)
# locate the first non NaN number
v = evaluate(f, A[first])
v = f(A[first])
i = first + 1
while v != v && i <= last
@inbounds v = evaluate(f, A[i])
@inbounds v = f(A[i])
i += 1
end
while i <= last
@inbounds x = evaluate(f, A[i])
@inbounds x = f(A[i])
if x < v
v = x
end
Expand Down Expand Up @@ -339,28 +343,22 @@ end

function mapfoldl(f, ::AndFun, itr)
for x in itr
if !evaluate(f, x)
return false
end
!f(x) && return false
end
return true
end

function mapfoldl(f, ::OrFun, itr)
for x in itr
if evaluate(f, x)
return true
end
f(x) && return true
end
return false
end

function mapreduce_impl(f, op::AndFun, A::AbstractArray, ifirst::Int, ilast::Int)
while ifirst <= ilast
@inbounds x = A[ifirst]
if !evaluate(f, x)
return false
end
!f(x) && return false
ifirst += 1
end
return true
Expand All @@ -369,9 +367,7 @@ end
function mapreduce_impl(f, op::OrFun, A::AbstractArray, ifirst::Int, ilast::Int)
while ifirst <= ilast
@inbounds x = A[ifirst]
if evaluate(f, x)
return true
end
f(x) && return true
ifirst += 1
end
return false
Expand All @@ -390,8 +386,8 @@ immutable EqX{T} <: Func{1}
x::T
end
EqX{T}(x::T) = EqX{T}(x)
evaluate(f::EqX, y) = (y == f.x)

call(f::EqX, y) = f.x == y
in(x, itr) = any(EqX(x), itr)

const = in
Expand All @@ -401,9 +397,7 @@ const ∈ = in

function contains(eq::Function, itr, x)
for y in itr
if eq(y, x)
return true
end
eq(y, x) && return true
end
return false
end
Expand All @@ -414,25 +408,22 @@ end
function count(pred::Union(Function,Func{1}), itr)
n = 0
for x in itr
if evaluate(pred, x)
n += 1
end
pred(x) && (n += 1)
end
return n
end

function count(pred::Union(Function,Func{1}), a::AbstractArray)
n = 0
for i = 1:length(a)
@inbounds if evaluate(pred, a[i])
@inbounds if pred(a[i])
n += 1
end
end
return n
end

type NotEqZero <: Func{1} end
evaluate(NotEqZero, x) = (x != 0)
immutable NotEqZero <: Func{1} end
call(::NotEqZero, x) = x != 0

countnz(a) = count(NotEqZero(), a)

18 changes: 9 additions & 9 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ promote_union(T) = T
function reducedim_init{S}(f, op::AddFun, A::AbstractArray{S}, region)
T = promote_union(S)
if method_exists(zero, (Type{T},))
x = evaluate(f, zero(T))
x = f(zero(T))
z = zero(x) + zero(x)
Tr = typeof(z) == typeof(x) && !isbits(T) ? T : typeof(z)
else
Expand All @@ -81,7 +81,7 @@ end
function reducedim_init{S}(f, op::MulFun, A::AbstractArray{S}, region)
T = promote_union(S)
if method_exists(zero, (Type{T},))
x = evaluate(f, zero(T))
x = f(zero(T))
z = one(x) * one(x)
Tr = typeof(z) == typeof(x) && !isbits(T) ? T : typeof(z)
else
Expand All @@ -91,10 +91,10 @@ function reducedim_init{S}(f, op::MulFun, A::AbstractArray{S}, region)
return reducedim_initarray(A, region, z, Tr)
end

reducedim_init{T}(f, op::MaxFun, A::AbstractArray{T}, region) = reducedim_initarray0(A, region, typemin(evaluate(f, zero(T))))
reducedim_init{T}(f, op::MinFun, A::AbstractArray{T}, region) = reducedim_initarray0(A, region, typemax(evaluate(f, zero(T))))
reducedim_init{T}(f, op::MaxFun, A::AbstractArray{T}, region) = reducedim_initarray0(A, region, typemin(f(zero(T))))
reducedim_init{T}(f, op::MinFun, A::AbstractArray{T}, region) = reducedim_initarray0(A, region, typemax(f(zero(T))))
reducedim_init{T}(f::Union(AbsFun,Abs2Fun), op::MaxFun, A::AbstractArray{T}, region) =
reducedim_initarray(A, region, zero(evaluate(f, zero(T))))
reducedim_initarray(A, region, zero(f(zero(T))))

reducedim_init(f, op::AndFun, A::AbstractArray, region) = reducedim_initarray(A, region, true)
reducedim_init(f, op::OrFun, A::AbstractArray, region) = reducedim_initarray(A, region, false)
Expand Down Expand Up @@ -173,16 +173,16 @@ end
@nloops N i d->(d>1? (1:size(A,d)) : (1:1)) d->(j_d = sizeR_d==1 ? 1 : i_d) begin
@inbounds r = (@nref N R j)
for i_1 = 1:sizA1
@inbounds v = evaluate(f, (@nref N A i))
r = evaluate(op, r, v)
@inbounds v = f(@nref N A i)
r = op(r, v)
end
@inbounds (@nref N R j) = r
end
else
# general implementation
@nloops N i A d->(j_d = sizeR_d==1 ? 1 : i_d) begin
@inbounds v = evaluate(f, (@nref N A i))
@inbounds (@nref N R j) = evaluate(op, (@nref N R j), v)
@inbounds v = f(@nref N A i)
@inbounds (@nref N R j) = op((@nref N R j), v)
end
end
return R
Expand Down
2 changes: 1 addition & 1 deletion base/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ varzm{T}(A::AbstractArray{T}, region; corrected::Bool=true) =
immutable CentralizedAbs2Fun{T<:Number} <: Func{1}
m::T
end
evaluate(f::CentralizedAbs2Fun, x) = abs2(x - f.m)
call(f::CentralizedAbs2Fun, x) = abs2(x - f.m)
centralize_sumabs2(A::AbstractArray, m::Number, ifirst::Int, ilast::Int) =
mapreduce_impl(CentralizedAbs2Fun(m), AddFun(), A, ifirst, ilast)

Expand Down

0 comments on commit 0c631f8

Please sign in to comment.