Skip to content

Commit

Permalink
dont force unroll loop in reductions
Browse files Browse the repository at this point in the history
  • Loading branch information
KristofferC committed Sep 18, 2018
1 parent 81dd6ca commit d233823
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
43 changes: 21 additions & 22 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,13 @@ end


@generated function _map!(f, dest, ::Size{S}, a::StaticArray...) where {S}
exprs = Vector{Expr}(undef, prod(S))
for i 1:prod(S)
tmp = [:(a[$j][$i]) for j 1:length(a)]
exprs[i] = :(dest[$i] = f($(tmp...)))
end
tmp = [:(a[$j][i]) for j 1:length(a)]
return quote
@_inline_meta
@inbounds $(Expr(:block, exprs...))
@inbounds @simd for i 1:prod(S)
dest[i] = f($(tmp...))
end
return dest
end
end

Expand All @@ -66,28 +65,28 @@ end

@generated function _mapreduce(f, op, dims::Colon, nt::NamedTuple{()},
::Size{S}, a::StaticArray...) where {S}
tmp = [:(a[$j][1]) for j 1:length(a)]
expr = :(f($(tmp...)))
for i 2:prod(S)
tmp = [:(a[$j][$i]) for j 1:length(a)]
expr = :(op($expr, f($(tmp...))))
end
tmp = [:(a[$j][i]) for j 1:length(a)]
return quote
@_inline_meta
@inbounds return $expr
i = 1
@inbounds s = f($(tmp...))
@inbounds @simd for i = 2:$(prod(S))
s = op(s, f($(tmp...)))
end
return s
end
end

@generated function _mapreduce(f, op, dims::Colon, nt::NamedTuple{(:init,)},
::Size{S}, a::StaticArray...) where {S}
expr = :(nt.init)
for i 1:prod(S)
tmp = [:(a[$j][$i]) for j 1:length(a)]
expr = :(op($expr, f($(tmp...))))
end
::Size{S}, a::StaticArray...) where {S}
tmp = [:(a[$j][i]) for j 1:length(a)]
return quote
@_inline_meta
@inbounds return $expr
@inbounds s = nt.init
@inbounds @simd for i = 1:$(prod(S))
s = op(s, f($(tmp...)))
end
return s
end
end

Expand All @@ -98,7 +97,7 @@ end
@inline _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S} =
_mapreduce(f, op, Val(D), nt, sz, a)


@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{()},
::Size{S}, a::StaticArray) where {S,D}
N = length(S)
Expand Down
4 changes: 2 additions & 2 deletions test/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ end

@test iszero(sz) == iszero(z)

@test sum(sa) === sum(a)
@test sum(sa) sum(a)
@test sum(abs2, sa) === sum(abs2, a)
@test sum(sa, dims=2) === RSArray2(sum(a, dims=2))
@test sum(sa, dims=Val(2)) === RSArray2(sum(a, dims=2))
Expand All @@ -85,7 +85,7 @@ end
@test any(sb, dims=Val(2)) === RSArray2(any(b, dims=2))
@test any(x->x>0, sa, dims=Val(2)) === RSArray2(any(x->x>0, a, dims=2))

@test mean(sa) === mean(a)
@test mean(sa) mean(a)
@test mean(abs2, sa) === mean(abs2, a)
@test mean(sa, dims=Val(2)) === RSArray2(mean(a, dims=2))
@test mean(abs2, sa, dims=Val(2)) === RSArray2(mean(abs2.(a), dims=2))
Expand Down

0 comments on commit d233823

Please sign in to comment.