Skip to content

Commit

Permalink
Perf: Specialize fn in rmap and rmaptype
Browse files Browse the repository at this point in the history
Also avoid use of `map` on `NamedTuple`s as it doesn't specialize `f`.
  • Loading branch information
kpamnany committed Jul 23, 2021
1 parent 1ce475c commit b1bd4a6
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions src/RecursiveApply/RecursiveApply.jl
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,29 @@ export ⊞, ⊠, ⊟
Recursively apply `fn` to each element of `X`
"""
rmap(fn, X) = fn(X)
rmap(fn, X, Y) = fn(X, Y)
rmap(fn, X::Tuple) = map(x -> rmap(fn, x), X)
rmap(fn, X::Tuple, Y::Tuple) = map((x, y) -> rmap(fn, x, y), X, Y)
rmap(fn, X::NamedTuple) = map(x -> rmap(fn, x), X)
rmap(fn, X::NamedTuple{names}, Y::NamedTuple{names}) where {names} =
map((x, y) -> rmap(fn, x, y), X, Y)
rmap(fn::F, X) where {F} = fn(X)
rmap(fn::F, X, Y) where {F} = fn(X, Y)
rmap(fn::F, X::Tuple) where {F} = map(x -> rmap(fn, x), X)
rmap(fn::F, X::Tuple, Y::Tuple) where {F} = map((x, y) -> rmap(fn, x, y), X, Y)
rmap(fn::F, X::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X)))
rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y)))

"""
rmaptype(fn, T)
The return type of `rmap(fn, X::T)`.
"""
rmaptype(fn, ::Type{T}) where {T} = fn(T)
rmaptype(fn, ::Type{T}) where {T <: Tuple} =
rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T)
rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} =
Tuple{map(fn, tuple(T.parameters...))...}
rmaptype(fn, ::Type{T}) where {T <: NamedTuple{names, tup}} where {names, tup} =
rmaptype(
fn::F,
::Type{T},
) where {F, T <: NamedTuple{names, tup}} where {names, tup} =
NamedTuple{names, rmaptype(fn, tup)}


"""
rmul(w, X)
w ⊠ X
Expand Down Expand Up @@ -66,7 +69,6 @@ const ⊟ = rsub

rdiv(X, w::Number) = rmap(x -> x / w, X)


"""
rmuladd(w, X, Y)
Expand All @@ -76,7 +78,6 @@ rmuladd(w::Number, X, Y) = rmap((x, y) -> muladd(w, x, y), X, Y)
rmuladd(X, w::Number, Y) = rmap((x, y) -> muladd(x, w, y), X, Y)
rmuladd(w::Number, x::Number, y::Number) = muladd(w, x, y)


"""
rmatmul1(W, S, i, j)
Expand Down

0 comments on commit b1bd4a6

Please sign in to comment.