Skip to content

Commit

Permalink
remove inference stuff
Browse files Browse the repository at this point in the history
also do linalg

dont force unroll loop in reductions
  • Loading branch information
KristofferC committed Sep 21, 2018
1 parent 8609799 commit 7898e02
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 61 deletions.
70 changes: 33 additions & 37 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,13 @@ end
return :(zero(promote_op(*, eltype(a), eltype(b))))
end

expr = :(conj(a[1]) * b[1])
for j = 2:prod(S)
expr = :($expr + conj(a[$j]) * b[$j])
end

return quote
@_inline_meta
@inbounds return $expr
s = conj(a[1]) * b[1]
@inbounds @simd for j = 2:$(prod(S))
s += conj(a[j]) * b[j]
end
return s
end
end

Expand All @@ -244,15 +243,13 @@ end
if prod(S) == 0
return :(zero(promote_op(*, eltype(a), eltype(b))))
end

expr = :(a[1] * b[1])
for j = 2:prod(S)
expr = :($expr + a[$j] * b[$j])
end

return quote
@_inline_meta
@inbounds return $expr
s = a[1] * b[1]
@inbounds @simd for j = 2:$(prod(S))
s += a[j] * b[j]
end
return s
end
end

Expand All @@ -264,14 +261,13 @@ end
return :(zero(real(eltype(a))))
end

expr = :(abs2(a[1]))
for j = 2:prod(S)
expr = :($expr + abs2(a[$j]))
end

return quote
$(Expr(:meta, :inline))
@inbounds return sqrt($expr)
@_inline_meta
s = abs2(a[1])
@inbounds @simd for j = 2:$(prod(S))
s += abs2(a[j])
end
return sqrt(s)
end
end

Expand All @@ -283,28 +279,27 @@ _norm_p0(x) = x == 0 ? zero(x) : one(x)
return :(zero(real(eltype(a))))
end

expr = :(abs(a[1])^p)
for j = 2:prod(S)
expr = :($expr + abs(a[$j])^p)
end

expr_p1 = :(abs(a[1]))
for j = 2:prod(S)
expr_p1 = :($expr_p1 + abs(a[$j]))
end

return quote
$(Expr(:meta, :inline))
@_inline_meta
s = zero(real(eltype(a)))
if p == Inf
return mapreduce(abs, max, a; init=$(zero(real(eltype(a)))))
elseif p == 1
@inbounds return $expr_p1
s = abs(a[1])
@inbounds @simd for j = 2:$(prod(S))
s += abs(a[j])
end
return s
elseif p == 2
return norm(a)
elseif p == 0
return mapreduce(_norm_p0, +, a; init=$(zero(real(eltype(a)))))
else
@inbounds return ($expr)^(inv(p))
s = abs(a[1])^p
@inbounds @simd for j = 2:$(prod(S))
s += abs(a[j])^p
end
return s^(inv(p))
end
end
end
Expand All @@ -325,12 +320,13 @@ end
return :(zero(eltype(a)))
end

exprs = [:(a[$(LinearIndices(S)[i, i])]) for i = 1:S[1]]
total = reduce((ex1, ex2) -> :($ex1 + $ex2), exprs)

return quote
@_inline_meta
@inbounds return $total
s = a[1,1]
@inbounds @simd for 2 in 1:$(S[1])
s += a[i,i]
end
return s
end
end

Expand Down
43 changes: 21 additions & 22 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,13 @@ end


@generated function _map!(f, dest, ::Size{S}, a::StaticArray...) where {S}
exprs = Vector{Expr}(undef, prod(S))
for i 1:prod(S)
tmp = [:(a[$j][$i]) for j 1:length(a)]
exprs[i] = :(dest[$i] = f($(tmp...)))
end
tmp = [:(a[$j][i]) for j 1:length(a)]
return quote
@_inline_meta
@inbounds $(Expr(:block, exprs...))
@inbounds @simd for i 1:$(prod(S))
dest[i] = f($(tmp...))
end
return dest
end
end

Expand All @@ -66,28 +65,28 @@ end

@generated function _mapreduce(f, op, dims::Colon, nt::NamedTuple{()},
::Size{S}, a::StaticArray...) where {S}
tmp = [:(a[$j][1]) for j 1:length(a)]
expr = :(f($(tmp...)))
for i 2:prod(S)
tmp = [:(a[$j][$i]) for j 1:length(a)]
expr = :(op($expr, f($(tmp...))))
end
tmp = [:(a[$j][i]) for j 1:length(a)]
return quote
@_inline_meta
@inbounds return $expr
i = 1
@inbounds s = f($(tmp...))
@inbounds @simd for i = 2:$(prod(S))
s = op(s, f($(tmp...)))
end
return s
end
end

@generated function _mapreduce(f, op, dims::Colon, nt::NamedTuple{(:init,)},
::Size{S}, a::StaticArray...) where {S}
expr = :(nt.init)
for i 1:prod(S)
tmp = [:(a[$j][$i]) for j 1:length(a)]
expr = :(op($expr, f($(tmp...))))
end
::Size{S}, a::StaticArray...) where {S}
tmp = [:(a[$j][i]) for j 1:length(a)]
return quote
@_inline_meta
@inbounds return $expr
s = nt.init
@inbounds @simd for i = 1:$(prod(S))
s = op(s, f($(tmp...)))
end
return s
end
end

Expand All @@ -98,7 +97,7 @@ end
@inline _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S} =
_mapreduce(f, op, Val(D), nt, sz, a)


@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{()},
::Size{S}, a::StaticArray) where {S,D}
N = length(S)
Expand Down
4 changes: 2 additions & 2 deletions test/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ end

@test iszero(sz) == iszero(z)

@test sum(sa) === sum(a)
@test sum(sa) sum(a)
@test sum(abs2, sa) === sum(abs2, a)
@test sum(sa, dims=2) === RSArray2(sum(a, dims=2))
@test sum(sa, dims=Val(2)) === RSArray2(sum(a, dims=2))
Expand All @@ -85,7 +85,7 @@ end
@test any(sb, dims=Val(2)) === RSArray2(any(b, dims=2))
@test any(x->x>0, sa, dims=Val(2)) === RSArray2(any(x->x>0, a, dims=2))

@test mean(sa) === mean(a)
@test mean(sa) mean(a)
@test mean(abs2, sa) === mean(abs2, a)
@test mean(sa, dims=Val(2)) === RSArray2(mean(a, dims=2))
@test mean(abs2, sa, dims=Val(2)) === RSArray2(mean(abs2.(a), dims=2))
Expand Down

0 comments on commit 7898e02

Please sign in to comment.