Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove evaluate function for functor types #8790

Merged
merged 1 commit into from
Oct 24, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 69 additions & 78 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,38 @@

abstract Func{N}

type IdFun <: Func{1} end
type AbsFun <: Func{1} end
type Abs2Fun <: Func{1} end
type ExpFun <: Func{1} end
type LogFun <: Func{1} end

type AddFun <: Func{2} end
type MulFun <: Func{2} end
type AndFun <: Func{2} end
type OrFun <: Func{2} end
type MaxFun <: Func{2} end
type MinFun <: Func{2} end

evaluate(::IdFun, x) = x
evaluate(::AbsFun, x) = abs(x)
evaluate(::Abs2Fun, x) = abs2(x)
evaluate(::ExpFun, x) = exp(x)
evaluate(::LogFun, x) = log(x)
evaluate(f::Callable, x) = f(x)

evaluate(::AddFun, x, y) = x + y
evaluate(::MulFun, x, y) = x * y
evaluate(::AndFun, x, y) = x & y
evaluate(::OrFun, x, y) = x | y
evaluate(::MaxFun, x, y) = scalarmax(x, y)
evaluate(::MinFun, x, y) = scalarmin(x, y)
evaluate(f::Callable, x, y) = f(x, y)
immutable IdFun <: Func{1} end
call(::IdFun, x) = x

immutable AbsFun <: Func{1} end
call(::AbsFun, x) = abs(x)

immutable Abs2Fun <: Func{1} end
call(::Abs2Fun, x) = abs2(x)

immutable ExpFun <: Func{1} end
call(::ExpFun, x) = exp(x)

immutable LogFun <: Func{1} end
call(::LogFun, x) = log(x)

immutable AddFun <: Func{2} end
call(::AddFun, x, y) = x + y

immutable MulFun <: Func{2} end
call(::MulFun, x, y) = x * y

immutable AndFun <: Func{2} end
call(::AndFun, x, y) = x & y

immutable OrFun <: Func{2} end
call(::OrFun, x, y) = x | y

immutable MaxFun <: Func{2} end
call(::MaxFun, x, y) = scalarmax(x,y)

immutable MinFun <: Func{2} end
call(::MinFun, x, y) = scalarmin(x, y)

###### Generic (map)reduce functions ######

Expand Down Expand Up @@ -68,10 +72,10 @@ function mapfoldl_impl(f, op, v0, itr, i)
return v0
else
(x, i) = next(itr, i)
v = evaluate(op, v0, evaluate(f, x))
v = op(v0, f(x))
while !done(itr, i)
(x, i) = next(itr, i)
v = evaluate(op, v, evaluate(f, x))
(x, i) = next(itr, i)
v = op(v, f(x))
end
return v
end
Expand All @@ -93,7 +97,7 @@ function mapfoldl(f, op, itr)
return Base.mr_empty(f, op, eltype(itr))
end
(x, i) = next(itr, i)
v0 = evaluate(f, x)
v0 = f(x)
mapfoldl_impl(f, op, v0, itr, i)
end

Expand All @@ -107,17 +111,17 @@ function mapfoldr_impl(f, op, v0, itr, i::Integer)
return v0
else
x = itr[i]
v = evaluate(op, evaluate(f, x), v0)
v = op(f(x), v0)
while i > 1
x = itr[i -= 1]
v = evaluate(op, evaluate(f, x), v)
v = op(f(x), v)
end
return v
end
end

mapfoldr(f, op, v0, itr) = mapfoldr_impl(f, op, v0, itr, endof(itr))
mapfoldr(f, op, itr) = (i = endof(itr); mapfoldr_impl(f, op, evaluate(f, itr[i]), itr, i-1))
mapfoldr(f, op, itr) = (i = endof(itr); mapfoldr_impl(f, op, f(itr[i]), itr, i-1))

foldr(op, v0, itr) = mapfoldr(IdFun(), op, v0, itr)
foldr(op, itr) = mapfoldr(IdFun(), op, itr)
Expand All @@ -126,12 +130,12 @@ foldr(op, itr) = mapfoldr(IdFun(), op, itr)

# mapreduce_***_impl require ifirst < ilast
function mapreduce_seq_impl(f, op, A::AbstractArray, ifirst::Int, ilast::Int)
@inbounds fx1 = r_promote(op, evaluate(f, A[ifirst]))
@inbounds fx2 = evaluate(f, A[ifirst+=1])
@inbounds v = evaluate(op, fx1, fx2)
@inbounds fx1 = r_promote(op, f(A[ifirst]))
@inbounds fx2 = f(A[ifirst+=1])
@inbounds v = op(fx1, fx2)
while ifirst < ilast
@inbounds fx = evaluate(f, A[ifirst+=1])
v = evaluate(op, v, fx)
@inbounds fx = f(A[ifirst+=1])
v = op(v, fx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seems te be an @inbounds missing in line 137 in comparison to old line 133.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

end
return v
end
Expand All @@ -143,7 +147,7 @@ function mapreduce_pairwise_impl(f, op, A::AbstractArray, ifirst::Int, ilast::In
imid = (ifirst + ilast) >>> 1
v1 = mapreduce_pairwise_impl(f, op, A, ifirst, imid, blksize)
v2 = mapreduce_pairwise_impl(f, op, A, imid+1, ilast, blksize)
return evaluate(op, v1, v2)
return op(v1, v2)
end
end

Expand All @@ -169,15 +173,15 @@ function _mapreduce{T}(f, op, A::AbstractArray{T})
if n == 0
return mr_empty(f, op, T)
elseif n == 1
return r_promote(op, evaluate(f, A[1]))
return r_promote(op, f(A[1]))
elseif n < 16
@inbounds fx1 = r_promote(op, evaluate(f, A[1]))
@inbounds fx2 = r_promote(op, evaluate(f, A[2]))
s = evaluate(op, fx1, fx2)
@inbounds fx1 = r_promote(op, f(A[1]))
@inbounds fx2 = r_promote(op, f(A[2]))
s = op(fx1, fx2)
i = 2
while i < n
@inbounds fx = evaluate(f, A[i+=1])
s = evaluate(op, s, fx)
@inbounds fx = f(A[i+=1])
s = op(s, fx)
end
return s
else
Expand All @@ -186,7 +190,7 @@ function _mapreduce{T}(f, op, A::AbstractArray{T})
end

mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, A)
mapreduce(f, op, a::Number) = evaluate(f, a)
mapreduce(f, op, a::Number) = f(a)

function mapreduce(f, op::Function, A::AbstractArray)
is(op, +) ? _mapreduce(f, AddFun(), A) :
Expand All @@ -207,9 +211,9 @@ reduce(op, a::Number) = a

function mapreduce_seq_impl(f, op::AddFun, a::AbstractArray, ifirst::Int, ilast::Int)
@inbounds begin
s = r_promote(op, evaluate(f, a[ifirst])) + evaluate(f, a[ifirst+1])
s = r_promote(op, f(a[ifirst])) + f(a[ifirst+1])
@simd for i = ifirst+2:ilast
s += evaluate(f, a[i])
s += f(a[i])
end
end
s
Expand Down Expand Up @@ -266,14 +270,14 @@ prod(A::AbstractArray{Bool}) =

function mapreduce_impl(f, op::MaxFun, A::AbstractArray, first::Int, last::Int)
# locate the first non NaN number
v = evaluate(f, A[first])
v = f(A[first])
i = first + 1
while v != v && i <= last
@inbounds v = evaluate(f, A[i])
@inbounds v = f(A[i])
i += 1
end
while i <= last
@inbounds x = evaluate(f, A[i])
@inbounds x = f(A[i])
if x > v
v = x
end
Expand All @@ -284,14 +288,14 @@ end

function mapreduce_impl(f, op::MinFun, A::AbstractArray, first::Int, last::Int)
# locate the first non NaN number
v = evaluate(f, A[first])
v = f(A[first])
i = first + 1
while v != v && i <= last
@inbounds v = evaluate(f, A[i])
@inbounds v = f(A[i])
i += 1
end
while i <= last
@inbounds x = evaluate(f, A[i])
@inbounds x = f(A[i])
if x < v
v = x
end
Expand Down Expand Up @@ -339,28 +343,22 @@ end

function mapfoldl(f, ::AndFun, itr)
for x in itr
if !evaluate(f, x)
return false
end
!f(x) && return false
end
return true
end

function mapfoldl(f, ::OrFun, itr)
for x in itr
if evaluate(f, x)
return true
end
f(x) && return true
end
return false
end

function mapreduce_impl(f, op::AndFun, A::AbstractArray, ifirst::Int, ilast::Int)
while ifirst <= ilast
@inbounds x = A[ifirst]
if !evaluate(f, x)
return false
end
!f(x) && return false
ifirst += 1
end
return true
Expand All @@ -369,9 +367,7 @@ end
function mapreduce_impl(f, op::OrFun, A::AbstractArray, ifirst::Int, ilast::Int)
while ifirst <= ilast
@inbounds x = A[ifirst]
if evaluate(f, x)
return true
end
f(x) && return true
ifirst += 1
end
return false
Expand All @@ -390,8 +386,8 @@ immutable EqX{T} <: Func{1}
x::T
end
EqX{T}(x::T) = EqX{T}(x)
evaluate(f::EqX, y) = (y == f.x)

call(f::EqX, y) = f.x == y
in(x, itr) = any(EqX(x), itr)

const ∈ = in
Expand All @@ -401,9 +397,7 @@ const ∈ = in

function contains(eq::Function, itr, x)
for y in itr
if eq(y, x)
return true
end
eq(y, x) && return true
end
return false
end
Expand All @@ -414,25 +408,22 @@ end
function count(pred::Union(Function,Func{1}), itr)
n = 0
for x in itr
if evaluate(pred, x)
n += 1
end
pred(x) && (n += 1)
end
return n
end

function count(pred::Union(Function,Func{1}), a::AbstractArray)
n = 0
for i = 1:length(a)
@inbounds if evaluate(pred, a[i])
@inbounds if pred(a[i])
n += 1
end
end
return n
end

type NotEqZero <: Func{1} end
evaluate(NotEqZero, x) = (x != 0)
immutable NotEqZero <: Func{1} end
call(::NotEqZero, x) = x != 0

countnz(a) = count(NotEqZero(), a)

18 changes: 9 additions & 9 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ promote_union(T) = T
function reducedim_init{S}(f, op::AddFun, A::AbstractArray{S}, region)
T = promote_union(S)
if method_exists(zero, (Type{T},))
x = evaluate(f, zero(T))
x = f(zero(T))
z = zero(x) + zero(x)
Tr = typeof(z) == typeof(x) && !isbits(T) ? T : typeof(z)
else
Expand All @@ -81,7 +81,7 @@ end
function reducedim_init{S}(f, op::MulFun, A::AbstractArray{S}, region)
T = promote_union(S)
if method_exists(zero, (Type{T},))
x = evaluate(f, zero(T))
x = f(zero(T))
z = one(x) * one(x)
Tr = typeof(z) == typeof(x) && !isbits(T) ? T : typeof(z)
else
Expand All @@ -91,10 +91,10 @@ function reducedim_init{S}(f, op::MulFun, A::AbstractArray{S}, region)
return reducedim_initarray(A, region, z, Tr)
end

reducedim_init{T}(f, op::MaxFun, A::AbstractArray{T}, region) = reducedim_initarray0(A, region, typemin(evaluate(f, zero(T))))
reducedim_init{T}(f, op::MinFun, A::AbstractArray{T}, region) = reducedim_initarray0(A, region, typemax(evaluate(f, zero(T))))
reducedim_init{T}(f, op::MaxFun, A::AbstractArray{T}, region) = reducedim_initarray0(A, region, typemin(f(zero(T))))
reducedim_init{T}(f, op::MinFun, A::AbstractArray{T}, region) = reducedim_initarray0(A, region, typemax(f(zero(T))))
reducedim_init{T}(f::Union(AbsFun,Abs2Fun), op::MaxFun, A::AbstractArray{T}, region) =
reducedim_initarray(A, region, zero(evaluate(f, zero(T))))
reducedim_initarray(A, region, zero(f(zero(T))))

reducedim_init(f, op::AndFun, A::AbstractArray, region) = reducedim_initarray(A, region, true)
reducedim_init(f, op::OrFun, A::AbstractArray, region) = reducedim_initarray(A, region, false)
Expand Down Expand Up @@ -173,16 +173,16 @@ end
@nloops N i d->(d>1? (1:size(A,d)) : (1:1)) d->(j_d = sizeR_d==1 ? 1 : i_d) begin
@inbounds r = (@nref N R j)
for i_1 = 1:sizA1
@inbounds v = evaluate(f, (@nref N A i))
r = evaluate(op, r, v)
@inbounds v = f(@nref N A i)
r = op(r, v)
end
@inbounds (@nref N R j) = r
end
else
# general implementation
@nloops N i A d->(j_d = sizeR_d==1 ? 1 : i_d) begin
@inbounds v = evaluate(f, (@nref N A i))
@inbounds (@nref N R j) = evaluate(op, (@nref N R j), v)
@inbounds v = f(@nref N A i)
@inbounds (@nref N R j) = op((@nref N R j), v)
end
end
return R
Expand Down
2 changes: 1 addition & 1 deletion base/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ varzm{T}(A::AbstractArray{T}, region; corrected::Bool=true) =
immutable CentralizedAbs2Fun{T<:Number} <: Func{1}
m::T
end
evaluate(f::CentralizedAbs2Fun, x) = abs2(x - f.m)
call(f::CentralizedAbs2Fun, x) = abs2(x - f.m)
centralize_sumabs2(A::AbstractArray, m::Number, ifirst::Int, ilast::Int) =
mapreduce_impl(CentralizedAbs2Fun(m), AddFun(), A, ifirst, ilast)

Expand Down